Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Commit Β·
4a5bfab
1
Parent(s): 82d2ab7
Put all the models together (expect UI)
Browse filesThis view is limited to 50 files because it contains too many changes. Β See raw diff
- .gitignore +29 -8
- Dockerfile +27 -0
- README.md +87 -6
- app.py +1 -0
- checkpoints/hybrid_focus_config.json +10 -0
- MLP/models/meta_20260224_024200.npz β checkpoints/meta_best.npz +2 -2
- MLP/models/mlp_20260224_024200.joblib β checkpoints/mlp_best.pt +2 -2
- best_eye_cnn.pth β checkpoints/model_best.joblib +2 -2
- MLP/models/scaler_20260224_024200.joblib β checkpoints/scaler_best.joblib +2 -2
- checkpoints/xgboost_face_orientation_best.json +0 -0
- {data_preparation β data}/CNN/eye_crops/val/open/.gitkeep +0 -0
- data/README.md +47 -0
- {data_preparation β data}/collected_Abdelrahman/abdelrahman_20260306_023035.npz +0 -0
- {data_preparation β data}/collected_Jarek/Jarek_20260225_012931.npz +0 -0
- {data_preparation β data}/collected_Junhao/Junhao_20260303_113554.npz +0 -0
- {data_preparation β data}/collected_Kexin/kexin2_20260305_180229.npz +0 -0
- {data_preparation β data}/collected_Kexin/kexin_20260224_151043.npz +0 -0
- {data_preparation β data}/collected_Langyuan/Langyuan_20260303_153145.npz +0 -0
- {data_preparation β data}/collected_Mohamed/session_20260224_010131.npz +0 -0
- {data_preparation β data}/collected_Yingtao/Yingtao_20260306_023937.npz +0 -0
- {data_preparation/collected_Ayten β data/collected_ayten}/ayten_session_1.npz +0 -0
- {data_preparation/collected_Saba β data/collected_saba}/saba_20260306_230710.npz +0 -0
- data_preparation/MLP/explore_collected_data.ipynb +0 -0
- data_preparation/MLP/train_mlp.ipynb +0 -0
- data_preparation/README.md +61 -27
- {models/geometric β data_preparation}/__init__.py +0 -0
- data_preparation/data_exploration.ipynb +0 -0
- data_preparation/prepare_dataset.py +232 -0
- docker-compose.yml +5 -0
- eslint.config.js +29 -0
- evaluation/README.md +45 -2
- index.html +17 -0
- main.py +964 -0
- models/README.md +51 -8
- models/{attention/__init__.py β __init__.py} +0 -0
- models/attention/classifier.py +0 -0
- models/attention/fusion.py +0 -0
- models/attention/train.py +0 -0
- models/cnn/notebooks/EyeCNN.ipynb +107 -0
- models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb +0 -0
- models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb +0 -0
- models/cnn/notebooks/README.md +1 -0
- models/{attention/collect_features.py β collect_features.py} +26 -19
- models/eye_classifier.py +69 -0
- models/eye_crop.py +77 -0
- models/{geometric/eye_behaviour/eye_scorer.py β eye_scorer.py} +7 -3
- models/{pretrained/face_mesh/face_mesh.py β face_mesh.py} +6 -3
- models/geometric/eye_behaviour/__init__.py +0 -0
- models/geometric/face_orientation/__init__.py +0 -1
- models/{geometric/face_orientation/head_pose.py β head_pose.py} +10 -1
.gitignore
CHANGED
|
@@ -1,4 +1,26 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
*.py[cod]
|
| 3 |
*$py.class
|
| 4 |
*.so
|
|
@@ -9,12 +31,11 @@ env/
|
|
| 9 |
.env
|
| 10 |
*.egg-info/
|
| 11 |
.eggs/
|
| 12 |
-
dist/
|
| 13 |
build/
|
| 14 |
-
.idea/
|
| 15 |
-
.vscode/
|
| 16 |
-
*.swp
|
| 17 |
-
*.swo
|
| 18 |
-
docs/
|
| 19 |
-
.DS_Store
|
| 20 |
Thumbs.db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Logs
|
| 2 |
+
logs
|
| 3 |
+
*.log
|
| 4 |
+
npm-debug.log*
|
| 5 |
+
yarn-debug.log*
|
| 6 |
+
yarn-error.log*
|
| 7 |
+
pnpm-debug.log*
|
| 8 |
+
lerna-debug.log*
|
| 9 |
+
|
| 10 |
+
node_modules/
|
| 11 |
+
dist/
|
| 12 |
+
dist-ssr/
|
| 13 |
+
*.local
|
| 14 |
+
|
| 15 |
+
# Editor directories and files
|
| 16 |
+
.vscode/
|
| 17 |
+
.idea/
|
| 18 |
+
.DS_Store
|
| 19 |
+
*.suo
|
| 20 |
+
*.ntvs*
|
| 21 |
+
*.njsproj
|
| 22 |
+
*.sln
|
| 23 |
+
*.sw?
|
| 24 |
*.py[cod]
|
| 25 |
*$py.class
|
| 26 |
*.so
|
|
|
|
| 31 |
.env
|
| 32 |
*.egg-info/
|
| 33 |
.eggs/
|
|
|
|
| 34 |
build/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
Thumbs.db
|
| 36 |
+
|
| 37 |
+
# Project specific
|
| 38 |
+
focus_guard.db
|
| 39 |
+
static/
|
| 40 |
+
__pycache__/
|
| 41 |
+
docs/
|
Dockerfile
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
RUN useradd -m -u 1000 user
|
| 4 |
+
ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
|
| 5 |
+
|
| 6 |
+
ENV PYTHONUNBUFFERED=1
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev build-essential nodejs npm && rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
COPY requirements.txt ./
|
| 13 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
|
| 15 |
+
COPY . .
|
| 16 |
+
|
| 17 |
+
RUN npm install && npm run build && mkdir -p /app/static && cp -R dist/* /app/static/
|
| 18 |
+
|
| 19 |
+
ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
|
| 20 |
+
RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"
|
| 21 |
+
|
| 22 |
+
RUN mkdir -p /app/data && chown -R user:user /app
|
| 23 |
+
|
| 24 |
+
USER user
|
| 25 |
+
EXPOSE 7860
|
| 26 |
+
|
| 27 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--log-level", "debug"]
|
README.md
CHANGED
|
@@ -1,10 +1,91 @@
|
|
| 1 |
# FocusGuard
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
|
| 6 |
-
- **models/** β face mesh, head pose, eye scorer, YOLO classifier, MLP training, attention feature collection
|
| 7 |
-
- **evaluation/** β metrics and run logs
|
| 8 |
-
- **ui/** β live demo (geometry+YOLO or MLP-only)
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# FocusGuard
|
| 2 |
|
| 3 |
+
Real-time webcam-based focus detection system combining geometric feature extraction with machine learning classification. The pipeline extracts 17 facial features (EAR, gaze, head pose, PERCLOS, blink rate, etc.) from MediaPipe landmarks and classifies attentiveness using MLP and XGBoost models. Served via a React + FastAPI web application with live WebSocket video.
|
| 4 |
|
| 5 |
+
## 1. Project Structure
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
```
|
| 8 |
+
βββ data/ Raw collected sessions (collected_<name>/*.npz)
|
| 9 |
+
βββ data_preparation/ Data loading, cleaning, and exploration
|
| 10 |
+
βββ notebooks/ Training notebooks (MLP, XGBoost) with LOPO evaluation
|
| 11 |
+
βββ models/ Feature extraction modules and training scripts
|
| 12 |
+
βββ checkpoints/ All saved weights (mlp_best.pt, xgboost_*_best.json, GRU, scalers)
|
| 13 |
+
βββ evaluation/ Training logs and metrics (JSON)
|
| 14 |
+
βββ ui/ Live OpenCV demo and inference pipeline
|
| 15 |
+
βββ src/ React/Vite frontend source
|
| 16 |
+
βββ static/ Built frontend (served by FastAPI)
|
| 17 |
+
βββ app.py / main.py FastAPI backend (API, WebSocket, DB)
|
| 18 |
+
βββ requirements.txt Python dependencies
|
| 19 |
+
βββ package.json Frontend dependencies
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
## 2. Setup
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
python -m venv venv
|
| 26 |
+
source venv/bin/activate
|
| 27 |
+
pip install -r requirements.txt
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
Frontend (only needed if modifying the React app):
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
npm install
|
| 34 |
+
npm run build
|
| 35 |
+
cp -r dist/* static/
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## 3. Running
|
| 39 |
+
|
| 40 |
+
**Web application (API + frontend):**
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
uvicorn app:app --host 0.0.0.0 --port 7860
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Open http://localhost:7860 in a browser.
|
| 47 |
+
|
| 48 |
+
**Live camera demo (OpenCV):**
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
python ui/live_demo.py
|
| 52 |
+
python ui/live_demo.py --xgb # XGBoost mode
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
**Training:**
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
python -m models.mlp.train # MLP
|
| 59 |
+
python -m models.xgboost.train # XGBoost
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
## 4. Dataset
|
| 63 |
+
|
| 64 |
+
- **9 participants**, each recorded via webcam with real-time labelling (focused / unfocused)
|
| 65 |
+
- **144,793 total samples**, 10 selected features, binary classification
|
| 66 |
+
- Collected using `python -m models.collect_features --name <name>`
|
| 67 |
+
- Stored as `.npz` files in `data/collected_<name>/`
|
| 68 |
+
|
| 69 |
+
## 5. Models
|
| 70 |
+
|
| 71 |
+
| Model | Test Accuracy | Test F1 | ROC-AUC |
|
| 72 |
+
|-------|--------------|---------|---------|
|
| 73 |
+
| XGBoost (600 trees, depth 8, lr 0.149) | 95.87% | 0.959 | 0.991 |
|
| 74 |
+
| MLP (64β32, 30 epochs, lr 1e-3) | 92.92% | 0.929 | 0.971 |
|
| 75 |
+
|
| 76 |
+
Both evaluated on a held-out 15% stratified test split. LOPO (Leave-One-Person-Out) cross-validation available in `notebooks/`.
|
| 77 |
+
|
| 78 |
+
## 6. Feature Pipeline
|
| 79 |
+
|
| 80 |
+
1. **Face mesh** β MediaPipe 478-landmark detection
|
| 81 |
+
2. **Head pose** β solvePnP β yaw, pitch, roll, face score, gaze offset, head deviation
|
| 82 |
+
3. **Eye scorer** β EAR (left/right/avg), horizontal/vertical gaze ratio, MAR
|
| 83 |
+
4. **Temporal tracking** β PERCLOS, blink rate, closure duration, yawn duration
|
| 84 |
+
5. **Classification** β 10-feature vector β MLP or XGBoost β focused / unfocused
|
| 85 |
+
|
| 86 |
+
## 7. Tech Stack
|
| 87 |
+
|
| 88 |
+
- **Backend:** Python, FastAPI, WebSocket, aiosqlite
|
| 89 |
+
- **Frontend:** React, Vite, TypeScript
|
| 90 |
+
- **ML:** PyTorch (MLP), XGBoost, scikit-learn
|
| 91 |
+
- **Vision:** MediaPipe, OpenCV
|
app.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from main import app
|
checkpoints/hybrid_focus_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"w_mlp": 0.6000000000000001,
|
| 3 |
+
"w_geo": 0.3999999999999999,
|
| 4 |
+
"threshold": 0.35,
|
| 5 |
+
"use_yawn_veto": true,
|
| 6 |
+
"geo_face_weight": 0.4,
|
| 7 |
+
"geo_eye_weight": 0.6,
|
| 8 |
+
"mar_yawn_threshold": 0.55,
|
| 9 |
+
"metric": "f1"
|
| 10 |
+
}
|
MLP/models/meta_20260224_024200.npz β checkpoints/meta_best.npz
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
|
| 3 |
+
size 4196
|
MLP/models/mlp_20260224_024200.joblib β checkpoints/mlp_best.pt
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c2f55129785b6882c304483aa5399f5bf6c9ed6e73dfec7ca6f36cd0436156c8
|
| 3 |
+
size 14497
|
best_eye_cnn.pth β checkpoints/model_best.joblib
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:183f2d4419e0eb1e58704e5a7312eb61e331523566d4dc551054a07b3aac7557
|
| 3 |
+
size 5775881
|
MLP/models/scaler_20260224_024200.joblib β checkpoints/scaler_best.joblib
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02ed6b4c0d99e0254c6a740a949da2384db58ec7d3e6df6432b9bfcd3a296c71
|
| 3 |
+
size 783
|
checkpoints/xgboost_face_orientation_best.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
{data_preparation β data}/CNN/eye_crops/val/open/.gitkeep
RENAMED
|
File without changes
|
data/README.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data/
|
| 2 |
+
|
| 3 |
+
Raw collected session data used for model training and evaluation.
|
| 4 |
+
|
| 5 |
+
## 1. Contents
|
| 6 |
+
|
| 7 |
+
Each `collected_<name>/` folder contains `.npz` files for one participant:
|
| 8 |
+
|
| 9 |
+
| Folder | Participant | Samples |
|
| 10 |
+
|--------|-------------|---------|
|
| 11 |
+
| `collected_Abdelrahman/` | Abdelrahman | 15,870 |
|
| 12 |
+
| `collected_Jarek/` | Jarek | 14,829 |
|
| 13 |
+
| `collected_Junhao/` | Junhao | 8,901 |
|
| 14 |
+
| `collected_Kexin/` | Kexin | 32,312 (2 sessions) |
|
| 15 |
+
| `collected_Langyuan/` | Langyuan | 15,749 |
|
| 16 |
+
| `collected_Mohamed/` | Mohamed | 13,218 |
|
| 17 |
+
| `collected_Yingtao/` | Yingtao | 17,591 |
|
| 18 |
+
| `collected_ayten/` | Ayten | 17,621 |
|
| 19 |
+
| `collected_saba/` | Saba | 8,702 |
|
| 20 |
+
| **Total** | **9 participants** | **144,793** |
|
| 21 |
+
|
| 22 |
+
## 2. File Format
|
| 23 |
+
|
| 24 |
+
Each `.npz` file contains:
|
| 25 |
+
|
| 26 |
+
| Key | Shape | Description |
|
| 27 |
+
|-----|-------|-------------|
|
| 28 |
+
| `features` | (N, 17) | 17-dimensional feature vectors (float32) |
|
| 29 |
+
| `labels` | (N,) | Binary labels: 0 = unfocused, 1 = focused |
|
| 30 |
+
| `feature_names` | (17,) | Column names for the 17 features |
|
| 31 |
+
|
| 32 |
+
## 3. Feature List
|
| 33 |
+
|
| 34 |
+
`ear_left`, `ear_right`, `ear_avg`, `h_gaze`, `v_gaze`, `mar`, `yaw`, `pitch`, `roll`, `s_face`, `s_eye`, `gaze_offset`, `head_deviation`, `perclos`, `blink_rate`, `closure_duration`, `yawn_duration`
|
| 35 |
+
|
| 36 |
+
10 of these are selected for training (see `data_preparation/prepare_dataset.py`).
|
| 37 |
+
|
| 38 |
+
## 4. Collection
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
python -m models.collect_features --name yourname
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
1. Webcam opens with live overlay
|
| 45 |
+
2. Press **1** = focused, **0** = unfocused (switch every 10β30 sec)
|
| 46 |
+
3. Press **p** to pause/resume
|
| 47 |
+
4. Press **q** to stop and save
|
{data_preparation β data}/collected_Abdelrahman/abdelrahman_20260306_023035.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Jarek/Jarek_20260225_012931.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Junhao/Junhao_20260303_113554.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Kexin/kexin2_20260305_180229.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Kexin/kexin_20260224_151043.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Langyuan/Langyuan_20260303_153145.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Mohamed/session_20260224_010131.npz
RENAMED
|
File without changes
|
{data_preparation β data}/collected_Yingtao/Yingtao_20260306_023937.npz
RENAMED
|
File without changes
|
{data_preparation/collected_Ayten β data/collected_ayten}/ayten_session_1.npz
RENAMED
|
File without changes
|
{data_preparation/collected_Saba β data/collected_saba}/saba_20260306_230710.npz
RENAMED
|
File without changes
|
data_preparation/MLP/explore_collected_data.ipynb
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_preparation/MLP/train_mlp.ipynb
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_preparation/README.md
CHANGED
|
@@ -1,41 +1,75 @@
|
|
| 1 |
-
#
|
| 2 |
|
| 3 |
-
|
| 4 |
|
| 5 |
-
##
|
| 6 |
-
Contains raw session files in `.npz` format.
|
| 7 |
-
Generated using:
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
- 17-dimensional feature vectors
|
| 13 |
-
- Corresponding labels
|
| 14 |
|
| 15 |
-
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
###
|
| 19 |
-
Contains notebooks for:
|
| 20 |
-
- Exploring collected data
|
| 21 |
-
- Training the sklearn MLP model (10 features)
|
| 22 |
|
| 23 |
-
|
| 24 |
-
.
|
| 25 |
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
-
Eye crop directory structure for CNN training (YOLO).
|
| 30 |
|
| 31 |
-
|
| 32 |
|
| 33 |
-
|
| 34 |
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_preparation/
|
| 2 |
|
| 3 |
+
Shared data loading, cleaning, and exploratory analysis.
|
| 4 |
|
| 5 |
+
## 1. Files
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
| File | Description |
|
| 8 |
+
|------|-------------|
|
| 9 |
+
| `prepare_dataset.py` | Central data loading module used by all training scripts and notebooks |
|
| 10 |
+
| `data_exploration.ipynb` | EDA notebook: feature distributions, class balance, correlations |
|
| 11 |
|
| 12 |
+
## 2. prepare_dataset.py
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
Provides a consistent pipeline for loading raw `.npz` data from `data/`:
|
| 15 |
|
| 16 |
+
| Function | Purpose |
|
| 17 |
+
|----------|---------|
|
| 18 |
+
| `load_all_pooled(model_name)` | Load all participants, clean, select features, concatenate |
|
| 19 |
+
| `load_per_person(model_name)` | Load grouped by person (for LOPO cross-validation) |
|
| 20 |
+
| `get_numpy_splits(model_name)` | Load + stratified 70/15/15 split + StandardScaler |
|
| 21 |
+
| `get_dataloaders(model_name)` | Same as above, wrapped in PyTorch DataLoaders |
|
| 22 |
+
| `_split_and_scale(features, labels, ...)` | Reusable split + optional scaling |
|
| 23 |
|
| 24 |
+
### Cleaning rules
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
- `yaw` clipped to [-45, 45], `pitch`/`roll` to [-30, 30]
|
| 27 |
+
- `ear_left`, `ear_right`, `ear_avg` clipped to [0, 0.85]
|
| 28 |
|
| 29 |
+
### Selected features (face_orientation)
|
| 30 |
|
| 31 |
+
`head_deviation`, `s_face`, `s_eye`, `h_gaze`, `pitch`, `ear_left`, `ear_avg`, `ear_right`, `gaze_offset`, `perclos`
|
|
|
|
| 32 |
|
| 33 |
+
## 3. data_exploration.ipynb
|
| 34 |
|
| 35 |
+
Run from this folder or from the project root. Covers:
|
| 36 |
|
| 37 |
+
1. Per-feature statistics (mean, std, min, max)
|
| 38 |
+
2. Class distribution (focused vs unfocused)
|
| 39 |
+
3. Feature histograms and box plots
|
| 40 |
+
4. Correlation matrix
|
| 41 |
|
| 42 |
+
## 4. How to run
|
| 43 |
+
|
| 44 |
+
`prepare_dataset.py` is a **library module**, not a standalone script. You donβt run it directly; you import it from code that needs data.
|
| 45 |
+
|
| 46 |
+
**From repo root:**
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
# Optional: quick test that loading works
|
| 50 |
+
python -c "
|
| 51 |
+
from data_preparation.prepare_dataset import load_all_pooled
|
| 52 |
+
X, y, names = load_all_pooled('face_orientation')
|
| 53 |
+
print(f'Loaded {X.shape[0]} samples, {X.shape[1]} features: {names}')
|
| 54 |
+
"
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
**Used by:**
|
| 58 |
+
|
| 59 |
+
- `python -m models.mlp.train`
|
| 60 |
+
- `python -m models.xgboost.train`
|
| 61 |
+
- `notebooks/mlp.ipynb`, `notebooks/xgboost.ipynb`
|
| 62 |
+
- `data_preparation/data_exploration.ipynb`
|
| 63 |
+
|
| 64 |
+
## 5. Usage (in code)
|
| 65 |
+
|
| 66 |
+
```python
|
| 67 |
+
from data_preparation.prepare_dataset import load_all_pooled, get_numpy_splits
|
| 68 |
+
|
| 69 |
+
# pooled data
|
| 70 |
+
X, y, names = load_all_pooled("face_orientation")
|
| 71 |
+
|
| 72 |
+
# ready-to-train splits
|
| 73 |
+
splits, n_features, n_classes, scaler = get_numpy_splits("face_orientation")
|
| 74 |
+
X_train, y_train = splits["X_train"], splits["y_train"]
|
| 75 |
+
```
|
{models/geometric β data_preparation}/__init__.py
RENAMED
|
File without changes
|
data_preparation/data_exploration.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data_preparation/prepare_dataset.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import glob
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.preprocessing import StandardScaler
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import Dataset, DataLoader
|
| 11 |
+
except ImportError: # pragma: no cover
|
| 12 |
+
torch = None
|
| 13 |
+
|
| 14 |
+
class Dataset: # type: ignore
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
class _MissingTorchDataLoader: # type: ignore
|
| 18 |
+
def __init__(self, *args, **kwargs):
|
| 19 |
+
raise ImportError(
|
| 20 |
+
"PyTorch not installed"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
DataLoader = _MissingTorchDataLoader # type: ignore
|
| 24 |
+
|
| 25 |
+
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
|
| 26 |
+
|
| 27 |
+
SELECTED_FEATURES = {
|
| 28 |
+
"face_orientation": [
|
| 29 |
+
'head_deviation', 's_face', 's_eye', 'h_gaze', 'pitch',
|
| 30 |
+
'ear_left', 'ear_avg', 'ear_right', 'gaze_offset', 'perclos'
|
| 31 |
+
],
|
| 32 |
+
"eye_behaviour": [
|
| 33 |
+
'ear_left', 'ear_right', 'ear_avg', 'mar',
|
| 34 |
+
'blink_rate', 'closure_duration', 'perclos', 'yawn_duration'
|
| 35 |
+
]
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class FeatureVectorDataset(Dataset):
|
| 40 |
+
def __init__(self, features: np.ndarray, labels: np.ndarray):
|
| 41 |
+
self.features = torch.tensor(features, dtype=torch.float32)
|
| 42 |
+
self.labels = torch.tensor(labels, dtype=torch.long)
|
| 43 |
+
|
| 44 |
+
def __len__(self):
|
| 45 |
+
return len(self.labels)
|
| 46 |
+
|
| 47 |
+
def __getitem__(self, idx):
|
| 48 |
+
return self.features[idx], self.labels[idx]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ββ Low-level helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
|
| 53 |
+
def _clean_npz(raw, names):
|
| 54 |
+
"""Apply clipping rules in-place. Shared by all loaders."""
|
| 55 |
+
for col, lo, hi in [('yaw', -45, 45), ('pitch', -30, 30), ('roll', -30, 30)]:
|
| 56 |
+
if col in names:
|
| 57 |
+
raw[:, names.index(col)] = np.clip(raw[:, names.index(col)], lo, hi)
|
| 58 |
+
for feat in ['ear_left', 'ear_right', 'ear_avg']:
|
| 59 |
+
if feat in names:
|
| 60 |
+
raw[:, names.index(feat)] = np.clip(raw[:, names.index(feat)], 0, 0.85)
|
| 61 |
+
return raw
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _load_one_npz(npz_path, target_features):
|
| 65 |
+
"""Load a single .npz file, clean and select features. Returns (X, y, selected_feature_names)."""
|
| 66 |
+
data = np.load(npz_path, allow_pickle=True)
|
| 67 |
+
raw = data['features'].astype(np.float32)
|
| 68 |
+
labels = data['labels'].astype(np.int64)
|
| 69 |
+
names = list(data['feature_names'])
|
| 70 |
+
raw = _clean_npz(raw, names)
|
| 71 |
+
selected = [f for f in target_features if f in names]
|
| 72 |
+
idx = [names.index(f) for f in selected]
|
| 73 |
+
return raw[:, idx], labels, selected
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ββ Public data loaders ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
|
| 78 |
+
def load_all_pooled(model_name: str = "face_orientation", data_dir: str = None):
|
| 79 |
+
"""Load all collected_*/*.npz, clean, select features, concatenate.
|
| 80 |
+
|
| 81 |
+
Returns (X_all, y_all, all_feature_names).
|
| 82 |
+
"""
|
| 83 |
+
data_dir = data_dir or DATA_DIR
|
| 84 |
+
target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
|
| 85 |
+
pattern = os.path.join(data_dir, "collected_*", "*.npz")
|
| 86 |
+
npz_files = sorted(glob.glob(pattern))
|
| 87 |
+
|
| 88 |
+
if not npz_files:
|
| 89 |
+
print("[DATA] Warning: No .npz files found. Falling back to synthetic.")
|
| 90 |
+
X, y = _generate_synthetic_data(model_name)
|
| 91 |
+
return X, y, target_features
|
| 92 |
+
|
| 93 |
+
all_X, all_y = [], []
|
| 94 |
+
all_names = None
|
| 95 |
+
for npz_path in npz_files:
|
| 96 |
+
X, y, names = _load_one_npz(npz_path, target_features)
|
| 97 |
+
if all_names is None:
|
| 98 |
+
all_names = names
|
| 99 |
+
all_X.append(X)
|
| 100 |
+
all_y.append(y)
|
| 101 |
+
print(f"[DATA] + {os.path.basename(npz_path)}: {X.shape[0]} samples")
|
| 102 |
+
|
| 103 |
+
X_all = np.concatenate(all_X, axis=0)
|
| 104 |
+
y_all = np.concatenate(all_y, axis=0)
|
| 105 |
+
print(f"[DATA] Loaded {len(npz_files)} file(s) for '{model_name}': "
|
| 106 |
+
f"{X_all.shape[0]} total samples, {X_all.shape[1]} features")
|
| 107 |
+
return X_all, y_all, all_names
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_per_person(model_name: str = "face_orientation", data_dir: str = None):
|
| 111 |
+
"""Load collected_*/*.npz grouped by person (folder name).
|
| 112 |
+
|
| 113 |
+
Returns dict { person_name: (X, y) } where X/y are per-person numpy arrays.
|
| 114 |
+
Also returns (X_all, y_all) as pooled data.
|
| 115 |
+
"""
|
| 116 |
+
data_dir = data_dir or DATA_DIR
|
| 117 |
+
target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
|
| 118 |
+
pattern = os.path.join(data_dir, "collected_*", "*.npz")
|
| 119 |
+
npz_files = sorted(glob.glob(pattern))
|
| 120 |
+
|
| 121 |
+
if not npz_files:
|
| 122 |
+
raise FileNotFoundError(f"No .npz files matching {pattern}")
|
| 123 |
+
|
| 124 |
+
by_person = {}
|
| 125 |
+
all_X, all_y = [], []
|
| 126 |
+
for npz_path in npz_files:
|
| 127 |
+
folder = os.path.basename(os.path.dirname(npz_path))
|
| 128 |
+
person = folder.replace("collected_", "", 1)
|
| 129 |
+
X, y, _ = _load_one_npz(npz_path, target_features)
|
| 130 |
+
all_X.append(X)
|
| 131 |
+
all_y.append(y)
|
| 132 |
+
if person not in by_person:
|
| 133 |
+
by_person[person] = []
|
| 134 |
+
by_person[person].append((X, y))
|
| 135 |
+
print(f"[DATA] + {person}/{os.path.basename(npz_path)}: {X.shape[0]} samples")
|
| 136 |
+
|
| 137 |
+
for person, chunks in by_person.items():
|
| 138 |
+
by_person[person] = (
|
| 139 |
+
np.concatenate([c[0] for c in chunks], axis=0),
|
| 140 |
+
np.concatenate([c[1] for c in chunks], axis=0),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
X_all = np.concatenate(all_X, axis=0)
|
| 144 |
+
y_all = np.concatenate(all_y, axis=0)
|
| 145 |
+
print(f"[DATA] {len(by_person)} persons, {X_all.shape[0]} total samples, {X_all.shape[1]} features")
|
| 146 |
+
return by_person, X_all, y_all
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def load_raw_npz(npz_path):
|
| 150 |
+
"""Load a single .npz without cleaning or feature selection. For exploration notebooks."""
|
| 151 |
+
data = np.load(npz_path, allow_pickle=True)
|
| 152 |
+
features = data['features'].astype(np.float32)
|
| 153 |
+
labels = data['labels'].astype(np.int64)
|
| 154 |
+
names = list(data['feature_names'])
|
| 155 |
+
return features, labels, names
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ββ Legacy helpers (used by models/mlp/train.py and models/xgboost/train.py) β
|
| 159 |
+
|
| 160 |
+
def _load_real_data(model_name: str):
|
| 161 |
+
X, y, _ = load_all_pooled(model_name)
|
| 162 |
+
return X, y
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _generate_synthetic_data(model_name: str):
|
| 166 |
+
target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
|
| 167 |
+
n = 500
|
| 168 |
+
d = len(target_features)
|
| 169 |
+
c = 2
|
| 170 |
+
rng = np.random.RandomState(42)
|
| 171 |
+
features = rng.randn(n, d).astype(np.float32)
|
| 172 |
+
labels = rng.randint(0, c, size=n).astype(np.int64)
|
| 173 |
+
print(f"[DATA] Using synthetic data for '{model_name}': {n} samples, {d} features, {c} classes")
|
| 174 |
+
return features, labels
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _split_and_scale(features, labels, split_ratios, seed, scale):
|
| 178 |
+
"""Split data into train/val/test (stratified) and optionally scale."""
|
| 179 |
+
test_ratio = split_ratios[2]
|
| 180 |
+
val_ratio = split_ratios[1] / (split_ratios[0] + split_ratios[1])
|
| 181 |
+
|
| 182 |
+
X_train_val, X_test, y_train_val, y_test = train_test_split(
|
| 183 |
+
features, labels, test_size=test_ratio, random_state=seed, stratify=labels,
|
| 184 |
+
)
|
| 185 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 186 |
+
X_train_val, y_train_val, test_size=val_ratio, random_state=seed, stratify=y_train_val,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
scaler = None
|
| 190 |
+
if scale:
|
| 191 |
+
scaler = StandardScaler()
|
| 192 |
+
X_train = scaler.fit_transform(X_train)
|
| 193 |
+
X_val = scaler.transform(X_val)
|
| 194 |
+
X_test = scaler.transform(X_test)
|
| 195 |
+
print("[DATA] Applied StandardScaler (fitted on training split)")
|
| 196 |
+
|
| 197 |
+
splits = {
|
| 198 |
+
"X_train": X_train, "y_train": y_train,
|
| 199 |
+
"X_val": X_val, "y_val": y_val,
|
| 200 |
+
"X_test": X_test, "y_test": y_test,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
print(f"[DATA] Split (stratified): train={len(y_train)}, val={len(y_val)}, test={len(y_test)}")
|
| 204 |
+
return splits, scaler
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
|
| 208 |
+
"""Return raw numpy arrays for non-PyTorch models (e.g. XGBoost)."""
|
| 209 |
+
features, labels = _load_real_data(model_name)
|
| 210 |
+
num_features = features.shape[1]
|
| 211 |
+
num_classes = int(labels.max()) + 1
|
| 212 |
+
splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
|
| 213 |
+
return splits, num_features, num_classes, scaler
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
|
| 217 |
+
"""Return PyTorch DataLoaders for neural-network models."""
|
| 218 |
+
features, labels = _load_real_data(model_name)
|
| 219 |
+
num_features = features.shape[1]
|
| 220 |
+
num_classes = int(labels.max()) + 1
|
| 221 |
+
splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
|
| 222 |
+
|
| 223 |
+
train_ds = FeatureVectorDataset(splits["X_train"], splits["y_train"])
|
| 224 |
+
val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
|
| 225 |
+
test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])
|
| 226 |
+
|
| 227 |
+
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
|
| 228 |
+
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
|
| 229 |
+
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
|
| 230 |
+
|
| 231 |
+
return train_loader, val_loader, test_loader, num_features, num_classes, scaler
|
| 232 |
+
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
focus-guard:
|
| 3 |
+
build: .
|
| 4 |
+
ports:
|
| 5 |
+
- "7860:7860"
|
eslint.config.js
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import js from '@eslint/js'
|
| 2 |
+
import globals from 'globals'
|
| 3 |
+
import reactHooks from 'eslint-plugin-react-hooks'
|
| 4 |
+
import reactRefresh from 'eslint-plugin-react-refresh'
|
| 5 |
+
import { defineConfig, globalIgnores } from 'eslint/config'
|
| 6 |
+
|
| 7 |
+
export default defineConfig([
|
| 8 |
+
globalIgnores(['dist']),
|
| 9 |
+
{
|
| 10 |
+
files: ['**/*.{js,jsx}'],
|
| 11 |
+
extends: [
|
| 12 |
+
js.configs.recommended,
|
| 13 |
+
reactHooks.configs.flat.recommended,
|
| 14 |
+
reactRefresh.configs.vite,
|
| 15 |
+
],
|
| 16 |
+
languageOptions: {
|
| 17 |
+
ecmaVersion: 2020,
|
| 18 |
+
globals: globals.browser,
|
| 19 |
+
parserOptions: {
|
| 20 |
+
ecmaVersion: 'latest',
|
| 21 |
+
ecmaFeatures: { jsx: true },
|
| 22 |
+
sourceType: 'module',
|
| 23 |
+
},
|
| 24 |
+
},
|
| 25 |
+
rules: {
|
| 26 |
+
'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
|
| 27 |
+
},
|
| 28 |
+
},
|
| 29 |
+
])
|
evaluation/README.md
CHANGED
|
@@ -1,3 +1,46 @@
|
|
| 1 |
-
# evaluation
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# evaluation/
|
| 2 |
|
| 3 |
+
Training logs and performance metrics.
|
| 4 |
+
|
| 5 |
+
## 1. Contents
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
logs/
|
| 9 |
+
βββ face_orientation_training_log.json # MLP (latest run)
|
| 10 |
+
βββ mlp_face_orientation_training_log.json # MLP (alternate)
|
| 11 |
+
βββ xgboost_face_orientation_training_log.json # XGBoost
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## 2. Log Format
|
| 15 |
+
|
| 16 |
+
Each JSON file records the full training history:
|
| 17 |
+
|
| 18 |
+
**MLP logs:**
|
| 19 |
+
```json
|
| 20 |
+
{
|
| 21 |
+
"config": { "epochs": 30, "lr": 0.001, "batch_size": 32, ... },
|
| 22 |
+
"history": {
|
| 23 |
+
"train_loss": [0.287, 0.260, ...],
|
| 24 |
+
"val_loss": [0.256, 0.245, ...],
|
| 25 |
+
"train_acc": [0.889, 0.901, ...],
|
| 26 |
+
"val_acc": [0.905, 0.909, ...]
|
| 27 |
+
},
|
| 28 |
+
"test": { "accuracy": 0.929, "f1": 0.929, "roc_auc": 0.971 }
|
| 29 |
+
}
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
**XGBoost logs:**
|
| 33 |
+
```json
|
| 34 |
+
{
|
| 35 |
+
"config": { "n_estimators": 600, "max_depth": 8, "learning_rate": 0.149, ... },
|
| 36 |
+
"train_losses": [0.577, ...],
|
| 37 |
+
"val_losses": [0.576, ...],
|
| 38 |
+
"test": { "accuracy": 0.959, "f1": 0.959, "roc_auc": 0.991 }
|
| 39 |
+
}
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## 3. Generated By
|
| 43 |
+
|
| 44 |
+
- `python -m models.mlp.train` β writes MLP log
|
| 45 |
+
- `python -m models.xgboost.train` β writes XGBoost log
|
| 46 |
+
- Notebooks in `notebooks/` also save logs here
|
index.html
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!doctype html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8" />
|
| 6 |
+
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
| 7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 8 |
+
<title>Focus Guard</title>
|
| 9 |
+
<link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
|
| 10 |
+
</head>
|
| 11 |
+
|
| 12 |
+
<body>
|
| 13 |
+
<div id="root"></div>
|
| 14 |
+
<script type="module" src="/src/main.jsx"></script>
|
| 15 |
+
</body>
|
| 16 |
+
|
| 17 |
+
</html>
|
main.py
ADDED
|
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
|
| 2 |
+
from fastapi.staticfiles import StaticFiles
|
| 3 |
+
from fastapi.responses import FileResponse
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional, List, Any
|
| 7 |
+
import base64
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import aiosqlite
|
| 11 |
+
import json
|
| 12 |
+
from datetime import datetime, timedelta
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Callable
|
| 17 |
+
import asyncio
|
| 18 |
+
import concurrent.futures
|
| 19 |
+
import threading
|
| 20 |
+
|
| 21 |
+
from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
|
| 22 |
+
from av import VideoFrame
|
| 23 |
+
|
| 24 |
+
from mediapipe.tasks.python.vision import FaceLandmarksConnections
|
| 25 |
+
from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
|
| 26 |
+
from models.face_mesh import FaceMeshDetector
|
| 27 |
+
|
| 28 |
+
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================
|
| 29 |
+
|
| 30 |
+
_FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 31 |
+
_CYAN = (255, 255, 0)
|
| 32 |
+
_GREEN = (0, 255, 0)
|
| 33 |
+
_MAGENTA = (255, 0, 255)
|
| 34 |
+
_ORANGE = (0, 165, 255)
|
| 35 |
+
_RED = (0, 0, 255)
|
| 36 |
+
_WHITE = (255, 255, 255)
|
| 37 |
+
_LIGHT_GREEN = (144, 238, 144)
|
| 38 |
+
|
| 39 |
+
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
|
| 40 |
+
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
|
| 41 |
+
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
|
| 42 |
+
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
|
| 43 |
+
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
|
| 44 |
+
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
|
| 45 |
+
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
|
| 46 |
+
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
|
| 47 |
+
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _lm_px(lm, idx, w, h):
|
| 51 |
+
return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
|
| 55 |
+
for i in range(len(indices) - 1):
|
| 56 |
+
cv2.line(frame, _lm_px(lm, indices[i], w, h), _lm_px(lm, indices[i + 1], w, h), color, thickness, cv2.LINE_AA)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _draw_face_mesh(frame, lm, w, h):
|
| 60 |
+
"""Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines."""
|
| 61 |
+
# Tessellation (gray triangular grid, semi-transparent)
|
| 62 |
+
overlay = frame.copy()
|
| 63 |
+
for s, e in _TESSELATION_CONNS:
|
| 64 |
+
cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
|
| 65 |
+
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
| 66 |
+
# Contours
|
| 67 |
+
for s, e in _CONTOUR_CONNS:
|
| 68 |
+
cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
|
| 69 |
+
# Eyebrows
|
| 70 |
+
_draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
|
| 71 |
+
_draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
|
| 72 |
+
# Nose
|
| 73 |
+
_draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
|
| 74 |
+
# Lips
|
| 75 |
+
_draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
|
| 76 |
+
_draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
|
| 77 |
+
# Eyes
|
| 78 |
+
left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
|
| 79 |
+
cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
|
| 80 |
+
right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
|
| 81 |
+
cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
|
| 82 |
+
# EAR key points
|
| 83 |
+
for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
|
| 84 |
+
for idx in indices:
|
| 85 |
+
cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
|
| 86 |
+
# Irises + gaze lines
|
| 87 |
+
for iris_idx, eye_inner, eye_outer in [
|
| 88 |
+
(FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
|
| 89 |
+
(FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
|
| 90 |
+
]:
|
| 91 |
+
iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
|
| 92 |
+
center = iris_pts[0]
|
| 93 |
+
if len(iris_pts) >= 5:
|
| 94 |
+
radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
|
| 95 |
+
radius = max(int(np.mean(radii)), 2)
|
| 96 |
+
cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
|
| 97 |
+
cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
|
| 98 |
+
eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
|
| 99 |
+
eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
|
| 100 |
+
dx, dy = center[0] - eye_cx, center[1] - eye_cy
|
| 101 |
+
cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _draw_hud(frame, result, model_name):
|
| 105 |
+
"""Draw status bar and detail overlay matching live_demo.py."""
|
| 106 |
+
h, w = frame.shape[:2]
|
| 107 |
+
is_focused = result["is_focused"]
|
| 108 |
+
status = "FOCUSED" if is_focused else "NOT FOCUSED"
|
| 109 |
+
color = _GREEN if is_focused else _RED
|
| 110 |
+
|
| 111 |
+
# Top bar
|
| 112 |
+
cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
|
| 113 |
+
cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
|
| 114 |
+
cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
|
| 115 |
+
|
| 116 |
+
# Detail line
|
| 117 |
+
conf = result.get("mlp_prob", result.get("raw_score", 0.0))
|
| 118 |
+
mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
|
| 119 |
+
sf = result.get("s_face", 0)
|
| 120 |
+
se = result.get("s_eye", 0)
|
| 121 |
+
detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
|
| 122 |
+
cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
|
| 123 |
+
|
| 124 |
+
# Head pose (top right)
|
| 125 |
+
if result.get("yaw") is not None:
|
| 126 |
+
cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
|
| 127 |
+
(w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
|
| 128 |
+
|
| 129 |
+
# Yawn indicator
|
| 130 |
+
if result.get("is_yawning"):
|
| 131 |
+
cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
|
| 132 |
+
|
| 133 |
+
# Landmark indices used for face mesh drawing on client (union of all groups).
|
| 134 |
+
# Sending only these instead of all 478 saves ~60% of the landmarks payload.
|
| 135 |
+
_MESH_INDICES = sorted(set(
|
| 136 |
+
[10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
|
| 137 |
+
+ [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
|
| 138 |
+
+ [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
|
| 139 |
+
+ [468,469,470,471,472, 473,474,475,476,477] # irises
|
| 140 |
+
+ [70,63,105,66,107,55,65,52,53,46] # left eyebrow
|
| 141 |
+
+ [300,293,334,296,336,285,295,282,283,276] # right eyebrow
|
| 142 |
+
+ [6,197,195,5,4,1,19,94,2] # nose bridge
|
| 143 |
+
+ [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
|
| 144 |
+
+ [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
|
| 145 |
+
+ [33,160,158,133,153,145] # left EAR key points
|
| 146 |
+
+ [362,385,387,263,373,380] # right EAR key points
|
| 147 |
+
))
|
| 148 |
+
# Build a lookup: original_index -> position in sparse array, so client can reconstruct.
|
| 149 |
+
_MESH_INDEX_SET = set(_MESH_INDICES)
|
| 150 |
+
|
| 151 |
+
# Initialize FastAPI app
|
| 152 |
+
app = FastAPI(title="Focus Guard API")
|
| 153 |
+
|
| 154 |
+
# Add CORS middleware
|
| 155 |
+
app.add_middleware(
|
| 156 |
+
CORSMiddleware,
|
| 157 |
+
allow_origins=["*"],
|
| 158 |
+
allow_credentials=True,
|
| 159 |
+
allow_methods=["*"],
|
| 160 |
+
allow_headers=["*"],
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
# Global variables
|
| 164 |
+
db_path = "focus_guard.db"
|
| 165 |
+
pcs = set()
|
| 166 |
+
_cached_model_name = "mlp" # in-memory cache, updated via /api/settings
|
| 167 |
+
|
| 168 |
+
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
|
| 169 |
+
if pc.iceGatheringState == "complete":
|
| 170 |
+
return
|
| 171 |
+
done = asyncio.Event()
|
| 172 |
+
|
| 173 |
+
@pc.on("icegatheringstatechange")
|
| 174 |
+
def _on_state_change():
|
| 175 |
+
if pc.iceGatheringState == "complete":
|
| 176 |
+
done.set()
|
| 177 |
+
|
| 178 |
+
await done.wait()
|
| 179 |
+
|
| 180 |
+
# ================ DATABASE MODELS ================
|
| 181 |
+
|
| 182 |
+
async def init_database():
|
| 183 |
+
"""Initialize SQLite database with required tables"""
|
| 184 |
+
async with aiosqlite.connect(db_path) as db:
|
| 185 |
+
# FocusSessions table
|
| 186 |
+
await db.execute("""
|
| 187 |
+
CREATE TABLE IF NOT EXISTS focus_sessions (
|
| 188 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 189 |
+
start_time TIMESTAMP NOT NULL,
|
| 190 |
+
end_time TIMESTAMP,
|
| 191 |
+
duration_seconds INTEGER DEFAULT 0,
|
| 192 |
+
focus_score REAL DEFAULT 0.0,
|
| 193 |
+
total_frames INTEGER DEFAULT 0,
|
| 194 |
+
focused_frames INTEGER DEFAULT 0,
|
| 195 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 196 |
+
)
|
| 197 |
+
""")
|
| 198 |
+
|
| 199 |
+
# FocusEvents table
|
| 200 |
+
await db.execute("""
|
| 201 |
+
CREATE TABLE IF NOT EXISTS focus_events (
|
| 202 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 203 |
+
session_id INTEGER NOT NULL,
|
| 204 |
+
timestamp TIMESTAMP NOT NULL,
|
| 205 |
+
is_focused BOOLEAN NOT NULL,
|
| 206 |
+
confidence REAL NOT NULL,
|
| 207 |
+
detection_data TEXT,
|
| 208 |
+
FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
|
| 209 |
+
)
|
| 210 |
+
""")
|
| 211 |
+
|
| 212 |
+
# UserSettings table
|
| 213 |
+
await db.execute("""
|
| 214 |
+
CREATE TABLE IF NOT EXISTS user_settings (
|
| 215 |
+
id INTEGER PRIMARY KEY CHECK (id = 1),
|
| 216 |
+
sensitivity INTEGER DEFAULT 6,
|
| 217 |
+
notification_enabled BOOLEAN DEFAULT 1,
|
| 218 |
+
notification_threshold INTEGER DEFAULT 30,
|
| 219 |
+
frame_rate INTEGER DEFAULT 30,
|
| 220 |
+
model_name TEXT DEFAULT 'mlp'
|
| 221 |
+
)
|
| 222 |
+
""")
|
| 223 |
+
|
| 224 |
+
# Insert default settings if not exists
|
| 225 |
+
await db.execute("""
|
| 226 |
+
INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
|
| 227 |
+
VALUES (1, 6, 1, 30, 30, 'mlp')
|
| 228 |
+
""")
|
| 229 |
+
|
| 230 |
+
await db.commit()
|
| 231 |
+
|
| 232 |
+
# ================ PYDANTIC MODELS ================
|
| 233 |
+
|
| 234 |
+
class SessionCreate(BaseModel):
|
| 235 |
+
pass
|
| 236 |
+
|
| 237 |
+
class SessionEnd(BaseModel):
|
| 238 |
+
session_id: int
|
| 239 |
+
|
| 240 |
+
class SettingsUpdate(BaseModel):
|
| 241 |
+
sensitivity: Optional[int] = None
|
| 242 |
+
notification_enabled: Optional[bool] = None
|
| 243 |
+
notification_threshold: Optional[int] = None
|
| 244 |
+
frame_rate: Optional[int] = None
|
| 245 |
+
model_name: Optional[str] = None
|
| 246 |
+
|
| 247 |
+
class VideoTransformTrack(VideoStreamTrack):
|
| 248 |
+
def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.track = track
|
| 251 |
+
self.session_id = session_id
|
| 252 |
+
self.get_channel = get_channel
|
| 253 |
+
self.last_inference_time = 0
|
| 254 |
+
self.min_inference_interval = 1 / 60
|
| 255 |
+
self.last_frame = None
|
| 256 |
+
|
| 257 |
+
async def recv(self):
|
| 258 |
+
frame = await self.track.recv()
|
| 259 |
+
img = frame.to_ndarray(format="bgr24")
|
| 260 |
+
if img is None:
|
| 261 |
+
return frame
|
| 262 |
+
|
| 263 |
+
# Normalize size for inference/drawing
|
| 264 |
+
img = cv2.resize(img, (640, 480))
|
| 265 |
+
|
| 266 |
+
now = datetime.now().timestamp()
|
| 267 |
+
do_infer = (now - self.last_inference_time) >= self.min_inference_interval
|
| 268 |
+
|
| 269 |
+
if do_infer:
|
| 270 |
+
self.last_inference_time = now
|
| 271 |
+
|
| 272 |
+
model_name = _cached_model_name
|
| 273 |
+
if model_name not in pipelines or pipelines.get(model_name) is None:
|
| 274 |
+
model_name = 'mlp'
|
| 275 |
+
active_pipeline = pipelines.get(model_name)
|
| 276 |
+
|
| 277 |
+
if active_pipeline is not None:
|
| 278 |
+
loop = asyncio.get_event_loop()
|
| 279 |
+
out = await loop.run_in_executor(
|
| 280 |
+
_inference_executor,
|
| 281 |
+
_process_frame_safe,
|
| 282 |
+
active_pipeline,
|
| 283 |
+
img,
|
| 284 |
+
model_name,
|
| 285 |
+
)
|
| 286 |
+
is_focused = out["is_focused"]
|
| 287 |
+
confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
|
| 288 |
+
metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
|
| 289 |
+
|
| 290 |
+
# Draw face mesh + HUD on the video frame
|
| 291 |
+
h_f, w_f = img.shape[:2]
|
| 292 |
+
lm = out.get("landmarks")
|
| 293 |
+
if lm is not None:
|
| 294 |
+
_draw_face_mesh(img, lm, w_f, h_f)
|
| 295 |
+
_draw_hud(img, out, model_name)
|
| 296 |
+
else:
|
| 297 |
+
is_focused = False
|
| 298 |
+
confidence = 0.0
|
| 299 |
+
metadata = {"model": model_name}
|
| 300 |
+
cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
|
| 301 |
+
cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)
|
| 302 |
+
|
| 303 |
+
if self.session_id:
|
| 304 |
+
await store_focus_event(self.session_id, is_focused, confidence, metadata)
|
| 305 |
+
|
| 306 |
+
channel = self.get_channel()
|
| 307 |
+
if channel and channel.readyState == "open":
|
| 308 |
+
try:
|
| 309 |
+
channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
|
| 310 |
+
except Exception:
|
| 311 |
+
pass
|
| 312 |
+
|
| 313 |
+
self.last_frame = img
|
| 314 |
+
elif self.last_frame is not None:
|
| 315 |
+
img = self.last_frame
|
| 316 |
+
|
| 317 |
+
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
| 318 |
+
new_frame.pts = frame.pts
|
| 319 |
+
new_frame.time_base = frame.time_base
|
| 320 |
+
return new_frame
|
| 321 |
+
|
| 322 |
+
# ================ DATABASE OPERATIONS ================
|
| 323 |
+
|
| 324 |
+
async def create_session():
|
| 325 |
+
async with aiosqlite.connect(db_path) as db:
|
| 326 |
+
cursor = await db.execute(
|
| 327 |
+
"INSERT INTO focus_sessions (start_time) VALUES (?)",
|
| 328 |
+
(datetime.now().isoformat(),)
|
| 329 |
+
)
|
| 330 |
+
await db.commit()
|
| 331 |
+
return cursor.lastrowid
|
| 332 |
+
|
| 333 |
+
async def end_session(session_id: int):
|
| 334 |
+
async with aiosqlite.connect(db_path) as db:
|
| 335 |
+
cursor = await db.execute(
|
| 336 |
+
"SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
|
| 337 |
+
(session_id,)
|
| 338 |
+
)
|
| 339 |
+
row = await cursor.fetchone()
|
| 340 |
+
|
| 341 |
+
if not row:
|
| 342 |
+
return None
|
| 343 |
+
|
| 344 |
+
start_time_str, total_frames, focused_frames = row
|
| 345 |
+
start_time = datetime.fromisoformat(start_time_str)
|
| 346 |
+
end_time = datetime.now()
|
| 347 |
+
duration = (end_time - start_time).total_seconds()
|
| 348 |
+
focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
|
| 349 |
+
|
| 350 |
+
await db.execute("""
|
| 351 |
+
UPDATE focus_sessions
|
| 352 |
+
SET end_time = ?, duration_seconds = ?, focus_score = ?
|
| 353 |
+
WHERE id = ?
|
| 354 |
+
""", (end_time.isoformat(), int(duration), focus_score, session_id))
|
| 355 |
+
|
| 356 |
+
await db.commit()
|
| 357 |
+
|
| 358 |
+
return {
|
| 359 |
+
'session_id': session_id,
|
| 360 |
+
'start_time': start_time_str,
|
| 361 |
+
'end_time': end_time.isoformat(),
|
| 362 |
+
'duration_seconds': int(duration),
|
| 363 |
+
'focus_score': round(focus_score, 3),
|
| 364 |
+
'total_frames': total_frames,
|
| 365 |
+
'focused_frames': focused_frames
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 369 |
+
async with aiosqlite.connect(db_path) as db:
|
| 370 |
+
await db.execute("""
|
| 371 |
+
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 372 |
+
VALUES (?, ?, ?, ?, ?)
|
| 373 |
+
""", (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 374 |
+
|
| 375 |
+
await db.execute("""
|
| 376 |
+
UPDATE focus_sessions
|
| 377 |
+
SET total_frames = total_frames + 1,
|
| 378 |
+
focused_frames = focused_frames + ?
|
| 379 |
+
WHERE id = ?
|
| 380 |
+
""", (1 if is_focused else 0, session_id))
|
| 381 |
+
await db.commit()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class _EventBuffer:
|
| 385 |
+
"""Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
|
| 386 |
+
|
| 387 |
+
def __init__(self, flush_interval: float = 2.0):
|
| 388 |
+
self._buf: list = []
|
| 389 |
+
self._lock = asyncio.Lock()
|
| 390 |
+
self._flush_interval = flush_interval
|
| 391 |
+
self._task: asyncio.Task | None = None
|
| 392 |
+
self._total_frames = 0
|
| 393 |
+
self._focused_frames = 0
|
| 394 |
+
|
| 395 |
+
def start(self):
|
| 396 |
+
if self._task is None:
|
| 397 |
+
self._task = asyncio.create_task(self._flush_loop())
|
| 398 |
+
|
| 399 |
+
async def stop(self):
|
| 400 |
+
if self._task:
|
| 401 |
+
self._task.cancel()
|
| 402 |
+
try:
|
| 403 |
+
await self._task
|
| 404 |
+
except asyncio.CancelledError:
|
| 405 |
+
pass
|
| 406 |
+
self._task = None
|
| 407 |
+
await self._flush()
|
| 408 |
+
|
| 409 |
+
def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
|
| 410 |
+
self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
|
| 411 |
+
self._total_frames += 1
|
| 412 |
+
if is_focused:
|
| 413 |
+
self._focused_frames += 1
|
| 414 |
+
|
| 415 |
+
async def _flush_loop(self):
|
| 416 |
+
while True:
|
| 417 |
+
await asyncio.sleep(self._flush_interval)
|
| 418 |
+
await self._flush()
|
| 419 |
+
|
| 420 |
+
async def _flush(self):
|
| 421 |
+
async with self._lock:
|
| 422 |
+
if not self._buf:
|
| 423 |
+
return
|
| 424 |
+
batch = self._buf[:]
|
| 425 |
+
total = self._total_frames
|
| 426 |
+
focused = self._focused_frames
|
| 427 |
+
self._buf.clear()
|
| 428 |
+
self._total_frames = 0
|
| 429 |
+
self._focused_frames = 0
|
| 430 |
+
|
| 431 |
+
if not batch:
|
| 432 |
+
return
|
| 433 |
+
|
| 434 |
+
session_id = batch[0][0]
|
| 435 |
+
try:
|
| 436 |
+
async with aiosqlite.connect(db_path) as db:
|
| 437 |
+
await db.executemany("""
|
| 438 |
+
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
|
| 439 |
+
VALUES (?, ?, ?, ?, ?)
|
| 440 |
+
""", batch)
|
| 441 |
+
await db.execute("""
|
| 442 |
+
UPDATE focus_sessions
|
| 443 |
+
SET total_frames = total_frames + ?,
|
| 444 |
+
focused_frames = focused_frames + ?
|
| 445 |
+
WHERE id = ?
|
| 446 |
+
""", (total, focused, session_id))
|
| 447 |
+
await db.commit()
|
| 448 |
+
except Exception as e:
|
| 449 |
+
print(f"[DB] Flush error: {e}")
|
| 450 |
+
|
| 451 |
+
# ================ STARTUP/SHUTDOWN ================
|
| 452 |
+
|
| 453 |
+
pipelines = {
|
| 454 |
+
"geometric": None,
|
| 455 |
+
"mlp": None,
|
| 456 |
+
"hybrid": None,
|
| 457 |
+
"xgboost": None,
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
# Thread pool for CPU-bound inference so the event loop stays responsive.
|
| 461 |
+
_inference_executor = concurrent.futures.ThreadPoolExecutor(
|
| 462 |
+
max_workers=4,
|
| 463 |
+
thread_name_prefix="inference",
|
| 464 |
+
)
|
| 465 |
+
# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
|
| 466 |
+
# multiple frames are processed in parallel by the thread pool.
|
| 467 |
+
_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def _process_frame_safe(pipeline, frame, model_name: str):
|
| 471 |
+
"""Run process_frame in executor with per-pipeline lock."""
|
| 472 |
+
with _pipeline_locks[model_name]:
|
| 473 |
+
return pipeline.process_frame(frame)
|
| 474 |
+
|
| 475 |
+
@app.on_event("startup")
|
| 476 |
+
async def startup_event():
|
| 477 |
+
global pipelines, _cached_model_name
|
| 478 |
+
print(" Starting Focus Guard API...")
|
| 479 |
+
await init_database()
|
| 480 |
+
# Load cached model name from DB
|
| 481 |
+
async with aiosqlite.connect(db_path) as db:
|
| 482 |
+
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 483 |
+
row = await cursor.fetchone()
|
| 484 |
+
if row:
|
| 485 |
+
_cached_model_name = row[0]
|
| 486 |
+
print("[OK] Database initialized")
|
| 487 |
+
|
| 488 |
+
try:
|
| 489 |
+
pipelines["geometric"] = FaceMeshPipeline()
|
| 490 |
+
print("[OK] FaceMeshPipeline (geometric) loaded")
|
| 491 |
+
except Exception as e:
|
| 492 |
+
print(f"[WARN] FaceMeshPipeline unavailable: {e}")
|
| 493 |
+
|
| 494 |
+
try:
|
| 495 |
+
pipelines["mlp"] = MLPPipeline()
|
| 496 |
+
print("[OK] MLPPipeline loaded")
|
| 497 |
+
except Exception as e:
|
| 498 |
+
print(f"[ERR] Failed to load MLPPipeline: {e}")
|
| 499 |
+
|
| 500 |
+
try:
|
| 501 |
+
pipelines["hybrid"] = HybridFocusPipeline()
|
| 502 |
+
print("[OK] HybridFocusPipeline loaded")
|
| 503 |
+
except Exception as e:
|
| 504 |
+
print(f"[WARN] HybridFocusPipeline unavailable: {e}")
|
| 505 |
+
|
| 506 |
+
try:
|
| 507 |
+
pipelines["xgboost"] = XGBoostPipeline()
|
| 508 |
+
print("[OK] XGBoostPipeline loaded")
|
| 509 |
+
except Exception as e:
|
| 510 |
+
print(f"[ERR] Failed to load XGBoostPipeline: {e}")
|
| 511 |
+
|
| 512 |
+
@app.on_event("shutdown")
|
| 513 |
+
async def shutdown_event():
|
| 514 |
+
_inference_executor.shutdown(wait=False)
|
| 515 |
+
print(" Shutting down Focus Guard API...")
|
| 516 |
+
|
| 517 |
+
# ================ WEBRTC SIGNALING ================
|
| 518 |
+
|
| 519 |
+
@app.post("/api/webrtc/offer")
|
| 520 |
+
async def webrtc_offer(offer: dict):
|
| 521 |
+
try:
|
| 522 |
+
print(f"Received WebRTC offer")
|
| 523 |
+
|
| 524 |
+
pc = RTCPeerConnection()
|
| 525 |
+
pcs.add(pc)
|
| 526 |
+
|
| 527 |
+
session_id = await create_session()
|
| 528 |
+
print(f"Created session: {session_id}")
|
| 529 |
+
|
| 530 |
+
channel_ref = {"channel": None}
|
| 531 |
+
|
| 532 |
+
@pc.on("datachannel")
|
| 533 |
+
def on_datachannel(channel):
|
| 534 |
+
print(f"Data channel opened")
|
| 535 |
+
channel_ref["channel"] = channel
|
| 536 |
+
|
| 537 |
+
@pc.on("track")
|
| 538 |
+
def on_track(track):
|
| 539 |
+
print(f"Received track: {track.kind}")
|
| 540 |
+
if track.kind == "video":
|
| 541 |
+
local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
|
| 542 |
+
pc.addTrack(local_track)
|
| 543 |
+
print(f"Video track added")
|
| 544 |
+
|
| 545 |
+
@track.on("ended")
|
| 546 |
+
async def on_ended():
|
| 547 |
+
print(f"Track ended")
|
| 548 |
+
|
| 549 |
+
@pc.on("connectionstatechange")
|
| 550 |
+
async def on_connectionstatechange():
|
| 551 |
+
print(f"Connection state changed: {pc.connectionState}")
|
| 552 |
+
if pc.connectionState in ("failed", "closed", "disconnected"):
|
| 553 |
+
try:
|
| 554 |
+
await end_session(session_id)
|
| 555 |
+
except Exception as e:
|
| 556 |
+
print(f"β Error ending session: {e}")
|
| 557 |
+
pcs.discard(pc)
|
| 558 |
+
await pc.close()
|
| 559 |
+
|
| 560 |
+
await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
|
| 561 |
+
print(f"Remote description set")
|
| 562 |
+
|
| 563 |
+
answer = await pc.createAnswer()
|
| 564 |
+
await pc.setLocalDescription(answer)
|
| 565 |
+
print(f"Answer created")
|
| 566 |
+
|
| 567 |
+
await _wait_for_ice_gathering(pc)
|
| 568 |
+
print(f"ICE gathering complete")
|
| 569 |
+
|
| 570 |
+
return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
|
| 571 |
+
|
| 572 |
+
except Exception as e:
|
| 573 |
+
print(f"WebRTC offer error: {e}")
|
| 574 |
+
import traceback
|
| 575 |
+
traceback.print_exc()
|
| 576 |
+
raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
|
| 577 |
+
|
| 578 |
+
# ================ WEBSOCKET ================
|
| 579 |
+
|
| 580 |
+
@app.websocket("/ws/video")
|
| 581 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 582 |
+
await websocket.accept()
|
| 583 |
+
session_id = None
|
| 584 |
+
frame_count = 0
|
| 585 |
+
running = True
|
| 586 |
+
event_buffer = _EventBuffer(flush_interval=2.0)
|
| 587 |
+
|
| 588 |
+
# Latest frame slot β only the most recent frame is kept, older ones are dropped.
|
| 589 |
+
# Using a dict so nested functions can mutate without nonlocal issues.
|
| 590 |
+
_slot = {"frame": None}
|
| 591 |
+
_frame_ready = asyncio.Event()
|
| 592 |
+
|
| 593 |
+
async def _receive_loop():
|
| 594 |
+
"""Receive messages as fast as possible. Binary = frame, text = control."""
|
| 595 |
+
nonlocal session_id, running
|
| 596 |
+
try:
|
| 597 |
+
while running:
|
| 598 |
+
msg = await websocket.receive()
|
| 599 |
+
msg_type = msg.get("type", "")
|
| 600 |
+
|
| 601 |
+
if msg_type == "websocket.disconnect":
|
| 602 |
+
running = False
|
| 603 |
+
_frame_ready.set()
|
| 604 |
+
return
|
| 605 |
+
|
| 606 |
+
# Binary message β JPEG frame (fast path, no base64)
|
| 607 |
+
raw_bytes = msg.get("bytes")
|
| 608 |
+
if raw_bytes is not None and len(raw_bytes) > 0:
|
| 609 |
+
_slot["frame"] = raw_bytes
|
| 610 |
+
_frame_ready.set()
|
| 611 |
+
continue
|
| 612 |
+
|
| 613 |
+
# Text message β JSON control command (or legacy base64 frame)
|
| 614 |
+
text = msg.get("text")
|
| 615 |
+
if not text:
|
| 616 |
+
continue
|
| 617 |
+
data = json.loads(text)
|
| 618 |
+
|
| 619 |
+
if data["type"] == "frame":
|
| 620 |
+
# Legacy base64 path (fallback)
|
| 621 |
+
_slot["frame"] = base64.b64decode(data["image"])
|
| 622 |
+
_frame_ready.set()
|
| 623 |
+
|
| 624 |
+
elif data["type"] == "start_session":
|
| 625 |
+
session_id = await create_session()
|
| 626 |
+
event_buffer.start()
|
| 627 |
+
for p in pipelines.values():
|
| 628 |
+
if p is not None and hasattr(p, "reset_session"):
|
| 629 |
+
p.reset_session()
|
| 630 |
+
await websocket.send_json({"type": "session_started", "session_id": session_id})
|
| 631 |
+
|
| 632 |
+
elif data["type"] == "end_session":
|
| 633 |
+
if session_id:
|
| 634 |
+
await event_buffer.stop()
|
| 635 |
+
summary = await end_session(session_id)
|
| 636 |
+
if summary:
|
| 637 |
+
await websocket.send_json({"type": "session_ended", "summary": summary})
|
| 638 |
+
session_id = None
|
| 639 |
+
except WebSocketDisconnect:
|
| 640 |
+
running = False
|
| 641 |
+
_frame_ready.set()
|
| 642 |
+
except Exception as e:
|
| 643 |
+
print(f"[WS] receive error: {e}")
|
| 644 |
+
running = False
|
| 645 |
+
_frame_ready.set()
|
| 646 |
+
|
| 647 |
+
async def _process_loop():
|
| 648 |
+
"""Process only the latest frame, dropping stale ones."""
|
| 649 |
+
nonlocal frame_count, running
|
| 650 |
+
loop = asyncio.get_event_loop()
|
| 651 |
+
while running:
|
| 652 |
+
await _frame_ready.wait()
|
| 653 |
+
_frame_ready.clear()
|
| 654 |
+
if not running:
|
| 655 |
+
return
|
| 656 |
+
|
| 657 |
+
# Grab latest frame and clear slot
|
| 658 |
+
raw = _slot["frame"]
|
| 659 |
+
_slot["frame"] = None
|
| 660 |
+
if raw is None:
|
| 661 |
+
continue
|
| 662 |
+
|
| 663 |
+
try:
|
| 664 |
+
nparr = np.frombuffer(raw, np.uint8)
|
| 665 |
+
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 666 |
+
if frame is None:
|
| 667 |
+
continue
|
| 668 |
+
frame = cv2.resize(frame, (640, 480))
|
| 669 |
+
|
| 670 |
+
model_name = _cached_model_name
|
| 671 |
+
if model_name not in pipelines or pipelines.get(model_name) is None:
|
| 672 |
+
model_name = "mlp"
|
| 673 |
+
active_pipeline = pipelines.get(model_name)
|
| 674 |
+
|
| 675 |
+
landmarks_list = None
|
| 676 |
+
if active_pipeline is not None:
|
| 677 |
+
out = await loop.run_in_executor(
|
| 678 |
+
_inference_executor,
|
| 679 |
+
_process_frame_safe,
|
| 680 |
+
active_pipeline,
|
| 681 |
+
frame,
|
| 682 |
+
model_name,
|
| 683 |
+
)
|
| 684 |
+
is_focused = out["is_focused"]
|
| 685 |
+
confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
|
| 686 |
+
|
| 687 |
+
lm = out.get("landmarks")
|
| 688 |
+
if lm is not None:
|
| 689 |
+
# Send all 478 landmarks as flat array for tessellation drawing
|
| 690 |
+
landmarks_list = [
|
| 691 |
+
[round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
|
| 692 |
+
for i in range(lm.shape[0])
|
| 693 |
+
]
|
| 694 |
+
|
| 695 |
+
if session_id:
|
| 696 |
+
event_buffer.add(session_id, is_focused, confidence, {
|
| 697 |
+
"s_face": out.get("s_face", 0.0),
|
| 698 |
+
"s_eye": out.get("s_eye", 0.0),
|
| 699 |
+
"mar": out.get("mar", 0.0),
|
| 700 |
+
"model": model_name,
|
| 701 |
+
})
|
| 702 |
+
else:
|
| 703 |
+
is_focused = False
|
| 704 |
+
confidence = 0.0
|
| 705 |
+
|
| 706 |
+
resp = {
|
| 707 |
+
"type": "detection",
|
| 708 |
+
"focused": is_focused,
|
| 709 |
+
"confidence": round(confidence, 3),
|
| 710 |
+
"model": model_name,
|
| 711 |
+
"fc": frame_count,
|
| 712 |
+
}
|
| 713 |
+
if active_pipeline is not None:
|
| 714 |
+
# Send detailed metrics for HUD
|
| 715 |
+
if out.get("yaw") is not None:
|
| 716 |
+
resp["yaw"] = round(out["yaw"], 1)
|
| 717 |
+
resp["pitch"] = round(out["pitch"], 1)
|
| 718 |
+
resp["roll"] = round(out["roll"], 1)
|
| 719 |
+
if out.get("mar") is not None:
|
| 720 |
+
resp["mar"] = round(out["mar"], 3)
|
| 721 |
+
resp["sf"] = round(out.get("s_face", 0), 3)
|
| 722 |
+
resp["se"] = round(out.get("s_eye", 0), 3)
|
| 723 |
+
if landmarks_list is not None:
|
| 724 |
+
resp["lm"] = landmarks_list
|
| 725 |
+
await websocket.send_json(resp)
|
| 726 |
+
frame_count += 1
|
| 727 |
+
except Exception as e:
|
| 728 |
+
print(f"[WS] process error: {e}")
|
| 729 |
+
|
| 730 |
+
try:
|
| 731 |
+
await asyncio.gather(_receive_loop(), _process_loop())
|
| 732 |
+
except Exception:
|
| 733 |
+
pass
|
| 734 |
+
finally:
|
| 735 |
+
running = False
|
| 736 |
+
if session_id:
|
| 737 |
+
await event_buffer.stop()
|
| 738 |
+
await end_session(session_id)
|
| 739 |
+
|
| 740 |
+
# ================ API ENDPOINTS ================
|
| 741 |
+
|
| 742 |
+
@app.post("/api/sessions/start")
|
| 743 |
+
async def api_start_session():
|
| 744 |
+
session_id = await create_session()
|
| 745 |
+
return {"session_id": session_id}
|
| 746 |
+
|
| 747 |
+
@app.post("/api/sessions/end")
|
| 748 |
+
async def api_end_session(data: SessionEnd):
|
| 749 |
+
summary = await end_session(data.session_id)
|
| 750 |
+
if not summary: raise HTTPException(status_code=404, detail="Session not found")
|
| 751 |
+
return summary
|
| 752 |
+
|
| 753 |
+
@app.get("/api/sessions")
|
| 754 |
+
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
|
| 755 |
+
async with aiosqlite.connect(db_path) as db:
|
| 756 |
+
db.row_factory = aiosqlite.Row
|
| 757 |
+
|
| 758 |
+
# NEW: If importing/exporting all, remove limit if special flag or high limit
|
| 759 |
+
# For simplicity: if limit is -1, return all
|
| 760 |
+
limit_clause = "LIMIT ? OFFSET ?"
|
| 761 |
+
params = []
|
| 762 |
+
|
| 763 |
+
base_query = "SELECT * FROM focus_sessions"
|
| 764 |
+
where_clause = ""
|
| 765 |
+
|
| 766 |
+
if filter == "today":
|
| 767 |
+
date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
| 768 |
+
where_clause = " WHERE start_time >= ?"
|
| 769 |
+
params.append(date_filter.isoformat())
|
| 770 |
+
elif filter == "week":
|
| 771 |
+
date_filter = datetime.now() - timedelta(days=7)
|
| 772 |
+
where_clause = " WHERE start_time >= ?"
|
| 773 |
+
params.append(date_filter.isoformat())
|
| 774 |
+
elif filter == "month":
|
| 775 |
+
date_filter = datetime.now() - timedelta(days=30)
|
| 776 |
+
where_clause = " WHERE start_time >= ?"
|
| 777 |
+
params.append(date_filter.isoformat())
|
| 778 |
+
elif filter == "all":
|
| 779 |
+
# Just ensure we only get completed sessions or all sessions
|
| 780 |
+
where_clause = " WHERE end_time IS NOT NULL"
|
| 781 |
+
|
| 782 |
+
query = f"{base_query}{where_clause} ORDER BY start_time DESC"
|
| 783 |
+
|
| 784 |
+
# Handle Limit for Exports
|
| 785 |
+
if limit == -1:
|
| 786 |
+
# No limit clause for export
|
| 787 |
+
pass
|
| 788 |
+
else:
|
| 789 |
+
query += f" {limit_clause}"
|
| 790 |
+
params.extend([limit, offset])
|
| 791 |
+
|
| 792 |
+
cursor = await db.execute(query, tuple(params))
|
| 793 |
+
rows = await cursor.fetchall()
|
| 794 |
+
return [dict(row) for row in rows]
|
| 795 |
+
|
| 796 |
+
# --- NEW: Import Endpoint ---
|
| 797 |
+
@app.post("/api/import")
|
| 798 |
+
async def import_sessions(sessions: List[dict]):
|
| 799 |
+
count = 0
|
| 800 |
+
try:
|
| 801 |
+
async with aiosqlite.connect(db_path) as db:
|
| 802 |
+
for session in sessions:
|
| 803 |
+
# Use .get() to handle potential missing fields from older versions or edits
|
| 804 |
+
await db.execute("""
|
| 805 |
+
INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
|
| 806 |
+
VALUES (?, ?, ?, ?, ?, ?, ?)
|
| 807 |
+
""", (
|
| 808 |
+
session.get('start_time'),
|
| 809 |
+
session.get('end_time'),
|
| 810 |
+
session.get('duration_seconds', 0),
|
| 811 |
+
session.get('focus_score', 0.0),
|
| 812 |
+
session.get('total_frames', 0),
|
| 813 |
+
session.get('focused_frames', 0),
|
| 814 |
+
session.get('created_at', session.get('start_time'))
|
| 815 |
+
))
|
| 816 |
+
count += 1
|
| 817 |
+
await db.commit()
|
| 818 |
+
return {"status": "success", "count": count}
|
| 819 |
+
except Exception as e:
|
| 820 |
+
print(f"Import Error: {e}")
|
| 821 |
+
return {"status": "error", "message": str(e)}
|
| 822 |
+
|
| 823 |
+
# --- NEW: Clear History Endpoint ---
|
| 824 |
+
@app.delete("/api/history")
|
| 825 |
+
async def clear_history():
|
| 826 |
+
try:
|
| 827 |
+
async with aiosqlite.connect(db_path) as db:
|
| 828 |
+
# Delete events first (foreign key good practice)
|
| 829 |
+
await db.execute("DELETE FROM focus_events")
|
| 830 |
+
await db.execute("DELETE FROM focus_sessions")
|
| 831 |
+
await db.commit()
|
| 832 |
+
return {"status": "success", "message": "History cleared"}
|
| 833 |
+
except Exception as e:
|
| 834 |
+
return {"status": "error", "message": str(e)}
|
| 835 |
+
|
| 836 |
+
@app.get("/api/sessions/{session_id}")
|
| 837 |
+
async def get_session(session_id: int):
|
| 838 |
+
async with aiosqlite.connect(db_path) as db:
|
| 839 |
+
db.row_factory = aiosqlite.Row
|
| 840 |
+
cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
|
| 841 |
+
row = await cursor.fetchone()
|
| 842 |
+
if not row: raise HTTPException(status_code=404, detail="Session not found")
|
| 843 |
+
session = dict(row)
|
| 844 |
+
cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
|
| 845 |
+
events = [dict(r) for r in await cursor.fetchall()]
|
| 846 |
+
session['events'] = events
|
| 847 |
+
return session
|
| 848 |
+
|
| 849 |
+
@app.get("/api/settings")
|
| 850 |
+
async def get_settings():
|
| 851 |
+
async with aiosqlite.connect(db_path) as db:
|
| 852 |
+
db.row_factory = aiosqlite.Row
|
| 853 |
+
cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
|
| 854 |
+
row = await cursor.fetchone()
|
| 855 |
+
if row: return dict(row)
|
| 856 |
+
else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
|
| 857 |
+
|
| 858 |
+
@app.put("/api/settings")
|
| 859 |
+
async def update_settings(settings: SettingsUpdate):
|
| 860 |
+
async with aiosqlite.connect(db_path) as db:
|
| 861 |
+
cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
|
| 862 |
+
exists = await cursor.fetchone()
|
| 863 |
+
if not exists:
|
| 864 |
+
await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
|
| 865 |
+
await db.commit()
|
| 866 |
+
|
| 867 |
+
updates = []
|
| 868 |
+
params = []
|
| 869 |
+
if settings.sensitivity is not None:
|
| 870 |
+
updates.append("sensitivity = ?")
|
| 871 |
+
params.append(max(1, min(10, settings.sensitivity)))
|
| 872 |
+
if settings.notification_enabled is not None:
|
| 873 |
+
updates.append("notification_enabled = ?")
|
| 874 |
+
params.append(settings.notification_enabled)
|
| 875 |
+
if settings.notification_threshold is not None:
|
| 876 |
+
updates.append("notification_threshold = ?")
|
| 877 |
+
params.append(max(5, min(300, settings.notification_threshold)))
|
| 878 |
+
if settings.frame_rate is not None:
|
| 879 |
+
updates.append("frame_rate = ?")
|
| 880 |
+
params.append(max(5, min(60, settings.frame_rate)))
|
| 881 |
+
if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
|
| 882 |
+
updates.append("model_name = ?")
|
| 883 |
+
params.append(settings.model_name)
|
| 884 |
+
global _cached_model_name
|
| 885 |
+
_cached_model_name = settings.model_name
|
| 886 |
+
|
| 887 |
+
if updates:
|
| 888 |
+
query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
|
| 889 |
+
await db.execute(query, params)
|
| 890 |
+
await db.commit()
|
| 891 |
+
return {"status": "success", "updated": len(updates) > 0}
|
| 892 |
+
|
| 893 |
+
@app.get("/api/stats/summary")
|
| 894 |
+
async def get_stats_summary():
|
| 895 |
+
async with aiosqlite.connect(db_path) as db:
|
| 896 |
+
cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 897 |
+
total_sessions = (await cursor.fetchone())[0]
|
| 898 |
+
cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 899 |
+
total_focus_time = (await cursor.fetchone())[0] or 0
|
| 900 |
+
cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
|
| 901 |
+
avg_focus_score = (await cursor.fetchone())[0] or 0.0
|
| 902 |
+
cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
|
| 903 |
+
dates = [row[0] for row in await cursor.fetchall()]
|
| 904 |
+
|
| 905 |
+
streak_days = 0
|
| 906 |
+
if dates:
|
| 907 |
+
current_date = datetime.now().date()
|
| 908 |
+
for i, date_str in enumerate(dates):
|
| 909 |
+
session_date = datetime.fromisoformat(date_str).date()
|
| 910 |
+
expected_date = current_date - timedelta(days=i)
|
| 911 |
+
if session_date == expected_date: streak_days += 1
|
| 912 |
+
else: break
|
| 913 |
+
return {
|
| 914 |
+
'total_sessions': total_sessions,
|
| 915 |
+
'total_focus_time': int(total_focus_time),
|
| 916 |
+
'avg_focus_score': round(avg_focus_score, 3),
|
| 917 |
+
'streak_days': streak_days
|
| 918 |
+
}
|
| 919 |
+
|
| 920 |
+
@app.get("/api/models")
|
| 921 |
+
async def get_available_models():
|
| 922 |
+
"""Return list of loaded model names and which is currently active."""
|
| 923 |
+
available = [name for name, p in pipelines.items() if p is not None]
|
| 924 |
+
async with aiosqlite.connect(db_path) as db:
|
| 925 |
+
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 926 |
+
row = await cursor.fetchone()
|
| 927 |
+
current = row[0] if row else "mlp"
|
| 928 |
+
if current not in available and available:
|
| 929 |
+
current = available[0]
|
| 930 |
+
return {"available": available, "current": current}
|
| 931 |
+
|
| 932 |
+
@app.get("/api/mesh-topology")
|
| 933 |
+
async def get_mesh_topology():
|
| 934 |
+
"""Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
|
| 935 |
+
return {"tessellation": _TESSELATION_CONNS}
|
| 936 |
+
|
| 937 |
+
@app.get("/health")
|
| 938 |
+
async def health_check():
|
| 939 |
+
available = [name for name, p in pipelines.items() if p is not None]
|
| 940 |
+
return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
|
| 941 |
+
|
| 942 |
+
# ================ STATIC FILES (SPA SUPPORT) ================
|
| 943 |
+
|
| 944 |
+
# Resolve static dir from this file so it works regardless of cwd
|
| 945 |
+
_STATIC_DIR = Path(__file__).resolve().parent / "static"
|
| 946 |
+
_ASSETS_DIR = _STATIC_DIR / "assets"
|
| 947 |
+
|
| 948 |
+
# 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
|
| 949 |
+
if _ASSETS_DIR.is_dir():
|
| 950 |
+
app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
|
| 951 |
+
|
| 952 |
+
# 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
|
| 953 |
+
@app.get("/{full_path:path}")
|
| 954 |
+
async def serve_react_app(full_path: str, request: Request):
|
| 955 |
+
if full_path.startswith("api") or full_path.startswith("ws"):
|
| 956 |
+
raise HTTPException(status_code=404, detail="Not Found")
|
| 957 |
+
# Don't serve HTML for asset paths; let them 404 so we don't break module script loading
|
| 958 |
+
if full_path.startswith("assets") or full_path.startswith("assets/"):
|
| 959 |
+
raise HTTPException(status_code=404, detail="Not Found")
|
| 960 |
+
|
| 961 |
+
index_path = _STATIC_DIR / "index.html"
|
| 962 |
+
if index_path.is_file():
|
| 963 |
+
return FileResponse(str(index_path))
|
| 964 |
+
return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
|
models/README.md
CHANGED
|
@@ -1,10 +1,53 @@
|
|
| 1 |
-
# models
|
| 2 |
|
| 3 |
-
|
| 4 |
-
- **mlp/** β PyTorch MLP on feature vectors (face_orientation / eye_behaviour); checkpoints under `mlp/face_orientation_model/`, `mlp/eye_behaviour_model/`
|
| 5 |
-
- **geometric/face_orientation/** β head pose (solvePnP). **geometric/eye_behaviour/** β EAR, gaze, MAR
|
| 6 |
-
- **pretrained/face_mesh/** β MediaPipe face landmarks (no training)
|
| 7 |
-
- **attention/** β webcam feature collection (17-d), stubs for train/classifier/fusion
|
| 8 |
-
- **prepare_dataset.py** β loads from `data_preparation/processed/` or synthetic; used by `mlp/train.py`
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# models/
|
| 2 |
|
| 3 |
+
Feature extraction modules and model training scripts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
## 1. Feature Extraction
|
| 6 |
+
|
| 7 |
+
Root-level modules form the real-time inference pipeline:
|
| 8 |
+
|
| 9 |
+
| Module | Input | Output |
|
| 10 |
+
|--------|-------|--------|
|
| 11 |
+
| `face_mesh.py` | BGR frame | 478 MediaPipe landmarks |
|
| 12 |
+
| `head_pose.py` | Landmarks, frame size | yaw, pitch, roll, face/eye score, gaze offset, head deviation |
|
| 13 |
+
| `eye_scorer.py` | Landmarks | EAR (left/right/avg), gaze ratio (h/v), MAR |
|
| 14 |
+
| `eye_crop.py` | Landmarks, frame | Cropped eye region images |
|
| 15 |
+
| `eye_classifier.py` | Eye crops or landmarks | Eye open/closed prediction (geometric fallback) |
|
| 16 |
+
| `collect_features.py` | BGR frame | 17-d feature vector + temporal features (PERCLOS, blink rate, etc.) |
|
| 17 |
+
|
| 18 |
+
## 2. Training Scripts
|
| 19 |
+
|
| 20 |
+
| Folder | Model | Command |
|
| 21 |
+
|--------|-------|---------|
|
| 22 |
+
| `mlp/` | PyTorch MLP (64β32, 2-class) | `python -m models.mlp.train` |
|
| 23 |
+
| `xgboost/` | XGBoost (600 trees, depth 8) | `python -m models.xgboost.train` |
|
| 24 |
+
|
| 25 |
+
### mlp/
|
| 26 |
+
|
| 27 |
+
- `train.py` β training loop with early stopping, ClearML opt-in
|
| 28 |
+
- `sweep.py` β hyperparameter search (Optuna: lr, batch_size)
|
| 29 |
+
- `eval_accuracy.py` β load checkpoint and print test metrics
|
| 30 |
+
- Saves to **`checkpoints/mlp_best.pt`**
|
| 31 |
+
|
| 32 |
+
### xgboost/
|
| 33 |
+
|
| 34 |
+
- `train.py` β training with eval-set logging
|
| 35 |
+
- `sweep.py` / `sweep_local.py` β hyperparameter search (Optuna + ClearML)
|
| 36 |
+
- `eval_accuracy.py` β load checkpoint and print test metrics
|
| 37 |
+
- Saves to **`checkpoints/xgboost_face_orientation_best.json`**
|
| 38 |
+
|
| 39 |
+
## 3. Data Loading
|
| 40 |
+
|
| 41 |
+
All training scripts import from `data_preparation.prepare_dataset`:
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
from data_preparation.prepare_dataset import get_numpy_splits # XGBoost
|
| 45 |
+
from data_preparation.prepare_dataset import get_dataloaders # MLP (PyTorch)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## 4. Results
|
| 49 |
+
|
| 50 |
+
| Model | Test Accuracy | F1 | ROC-AUC |
|
| 51 |
+
|-------|--------------|-----|---------|
|
| 52 |
+
| XGBoost | 95.87% | 0.959 | 0.991 |
|
| 53 |
+
| MLP | 92.92% | 0.929 | 0.971 |
|
models/{attention/__init__.py β __init__.py}
RENAMED
|
File without changes
|
models/attention/classifier.py
DELETED
|
File without changes
|
models/attention/fusion.py
DELETED
|
File without changes
|
models/attention/train.py
DELETED
|
File without changes
|
models/cnn/notebooks/EyeCNN.ipynb
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nbformat": 4,
|
| 3 |
+
"nbformat_minor": 0,
|
| 4 |
+
"metadata": {
|
| 5 |
+
"colab": {
|
| 6 |
+
"provenance": [],
|
| 7 |
+
"gpuType": "T4"
|
| 8 |
+
},
|
| 9 |
+
"kernelspec": {
|
| 10 |
+
"name": "python3",
|
| 11 |
+
"display_name": "Python 3"
|
| 12 |
+
},
|
| 13 |
+
"language_info": {
|
| 14 |
+
"name": "python"
|
| 15 |
+
},
|
| 16 |
+
"accelerator": "GPU"
|
| 17 |
+
},
|
| 18 |
+
"cells": [
|
| 19 |
+
{
|
| 20 |
+
"cell_type": "code",
|
| 21 |
+
"source": [
|
| 22 |
+
"import os\n",
|
| 23 |
+
"import torch\n",
|
| 24 |
+
"import torch.nn as nn\n",
|
| 25 |
+
"import torch.optim as optim\n",
|
| 26 |
+
"from torch.utils.data import DataLoader\n",
|
| 27 |
+
"from torchvision import datasets, transforms\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"from google.colab import drive\n",
|
| 30 |
+
"drive.mount('/content/drive')\n",
|
| 31 |
+
"!cp -r /content/drive/MyDrive/Dataset_clean /content/\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"#Verify structure\n",
|
| 34 |
+
"for split in ['train', 'val', 'test']:\n",
|
| 35 |
+
" path = f'/content/Dataset_clean/{split}'\n",
|
| 36 |
+
" classes = os.listdir(path)\n",
|
| 37 |
+
" total = sum(len(os.listdir(os.path.join(path, c))) for c in classes)\n",
|
| 38 |
+
" print(f'{split}: {total} images | classes: {classes}')"
|
| 39 |
+
],
|
| 40 |
+
"metadata": {
|
| 41 |
+
"colab": {
|
| 42 |
+
"base_uri": "https://localhost:8080/"
|
| 43 |
+
},
|
| 44 |
+
"id": "sE1F3em-V5go",
|
| 45 |
+
"outputId": "2c73a9a6-a198-468c-a2cc-253b2de7cc3f"
|
| 46 |
+
},
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"outputs": [
|
| 49 |
+
{
|
| 50 |
+
"output_type": "stream",
|
| 51 |
+
"name": "stdout",
|
| 52 |
+
"text": [
|
| 53 |
+
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
|
| 54 |
+
]
|
| 55 |
+
}
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {
|
| 62 |
+
"id": "nG2bh66rQ56G"
|
| 63 |
+
},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": [
|
| 66 |
+
"class EyeCNN(nn.Module):\n",
|
| 67 |
+
" def __init__(self, num_classes=2):\n",
|
| 68 |
+
" super(EyeCNN, self).__init__()\n",
|
| 69 |
+
" self.conv_layers = nn.Sequential(\n",
|
| 70 |
+
" nn.Conv2d(3, 32, 3, 1, 1),\n",
|
| 71 |
+
" nn.BatchNorm2d(32),\n",
|
| 72 |
+
" nn.ReLU(),\n",
|
| 73 |
+
" nn.MaxPool2d(2, 2),\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" nn.Conv2d(32, 64, 3, 1, 1),\n",
|
| 76 |
+
" nn.BatchNorm2d(64),\n",
|
| 77 |
+
" nn.ReLU(),\n",
|
| 78 |
+
" nn.MaxPool2d(2, 2),\n",
|
| 79 |
+
"\n",
|
| 80 |
+
" nn.Conv2d(64, 128, 3, 1, 1),\n",
|
| 81 |
+
" nn.BatchNorm2d(128),\n",
|
| 82 |
+
" nn.ReLU(),\n",
|
| 83 |
+
" nn.MaxPool2d(2, 2),\n",
|
| 84 |
+
"\n",
|
| 85 |
+
" nn.Conv2d(128, 256, 3, 1, 1),\n",
|
| 86 |
+
" nn.BatchNorm2d(256),\n",
|
| 87 |
+
" nn.ReLU(),\n",
|
| 88 |
+
" nn.MaxPool2d(2, 2)\n",
|
| 89 |
+
" )\n",
|
| 90 |
+
"\n",
|
| 91 |
+
" self.fc_layers = nn.Sequential(\n",
|
| 92 |
+
" nn.AdaptiveAvgPool2d((1, 1)),\n",
|
| 93 |
+
" nn.Flatten(),\n",
|
| 94 |
+
" nn.Linear(256, 512),\n",
|
| 95 |
+
" nn.ReLU(),\n",
|
| 96 |
+
" nn.Dropout(0.35),\n",
|
| 97 |
+
" nn.Linear(512, num_classes)\n",
|
| 98 |
+
" )\n",
|
| 99 |
+
"\n",
|
| 100 |
+
" def forward(self, x):\n",
|
| 101 |
+
" x = self.conv_layers(x)\n",
|
| 102 |
+
" x = self.fc_layers(x)\n",
|
| 103 |
+
" return x"
|
| 104 |
+
]
|
| 105 |
+
}
|
| 106 |
+
]
|
| 107 |
+
}
|
models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/cnn/notebooks/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# GAP Large Project
|
models/{attention/collect_features.py β collect_features.py}
RENAMED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# Usage: python -m models.attention.collect_features [--name alice] [--duration 600]
|
| 2 |
|
| 3 |
import argparse
|
| 4 |
import collections
|
|
@@ -10,13 +9,13 @@ import time
|
|
| 10 |
import cv2
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
-
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.
|
| 14 |
if _PROJECT_ROOT not in sys.path:
|
| 15 |
sys.path.insert(0, _PROJECT_ROOT)
|
| 16 |
|
| 17 |
-
from models.
|
| 18 |
-
from models.
|
| 19 |
-
from models.
|
| 20 |
|
| 21 |
FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 22 |
GREEN = (0, 255, 0)
|
|
@@ -38,7 +37,7 @@ assert NUM_FEATURES == 17
|
|
| 38 |
|
| 39 |
class TemporalTracker:
|
| 40 |
EAR_BLINK_THRESH = 0.21
|
| 41 |
-
MAR_YAWN_THRESH = 0.
|
| 42 |
PERCLOS_WINDOW = 60
|
| 43 |
BLINK_WINDOW_SEC = 30.0
|
| 44 |
|
|
@@ -86,25 +85,35 @@ class TemporalTracker:
|
|
| 86 |
return perclos, blink_rate, closure_dur, yawn_dur
|
| 87 |
|
| 88 |
|
| 89 |
-
def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal
|
| 90 |
-
|
|
|
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
| 94 |
ear_avg = (ear_left + ear_right) / 2.0
|
| 95 |
-
h_gaze, v_gaze = compute_gaze_ratio(landmarks)
|
| 96 |
-
mar = compute_mar(landmarks)
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
yaw = angles[0] if angles else 0.0
|
| 100 |
pitch = angles[1] if angles else 0.0
|
| 101 |
roll = angles[2] if angles else 0.0
|
| 102 |
|
| 103 |
-
s_face = head_pose.score(landmarks, w, h)
|
| 104 |
-
s_eye = eye_scorer.score(landmarks)
|
| 105 |
|
| 106 |
gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
|
| 107 |
-
head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)
|
| 108 |
|
| 109 |
perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
|
| 110 |
|
|
@@ -181,7 +190,7 @@ def main():
|
|
| 181 |
parser.add_argument("--duration", type=int, default=600,
|
| 182 |
help="Max recording time (seconds, default 10 min)")
|
| 183 |
parser.add_argument("--output-dir", type=str,
|
| 184 |
-
default=os.path.join(_PROJECT_ROOT, "
|
| 185 |
help="Where to save .npz files")
|
| 186 |
args = parser.parse_args()
|
| 187 |
|
|
@@ -238,13 +247,11 @@ def main():
|
|
| 238 |
landmarks = detector.process(frame)
|
| 239 |
face_ok = landmarks is not None
|
| 240 |
|
| 241 |
-
# record if labeling + face visible
|
| 242 |
if face_ok and label is not None:
|
| 243 |
vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
|
| 244 |
features_list.append(vec)
|
| 245 |
labels_list.append(label)
|
| 246 |
|
| 247 |
-
# count transitions
|
| 248 |
if prev_label is not None and label != prev_label:
|
| 249 |
transitions += 1
|
| 250 |
prev_label = label
|
|
|
|
|
|
|
| 1 |
|
| 2 |
import argparse
|
| 3 |
import collections
|
|
|
|
| 9 |
import cv2
|
| 10 |
import numpy as np
|
| 11 |
|
| 12 |
+
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
if _PROJECT_ROOT not in sys.path:
|
| 14 |
sys.path.insert(0, _PROJECT_ROOT)
|
| 15 |
|
| 16 |
+
from models.face_mesh import FaceMeshDetector
|
| 17 |
+
from models.head_pose import HeadPoseEstimator
|
| 18 |
+
from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
|
| 19 |
|
| 20 |
FONT = cv2.FONT_HERSHEY_SIMPLEX
|
| 21 |
GREEN = (0, 255, 0)
|
|
|
|
| 37 |
|
| 38 |
class TemporalTracker:
|
| 39 |
EAR_BLINK_THRESH = 0.21
|
| 40 |
+
MAR_YAWN_THRESH = 0.55
|
| 41 |
PERCLOS_WINDOW = 60
|
| 42 |
BLINK_WINDOW_SEC = 30.0
|
| 43 |
|
|
|
|
| 85 |
return perclos, blink_rate, closure_dur, yawn_dur
|
| 86 |
|
| 87 |
|
| 88 |
+
def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
|
| 89 |
+
*, _pre=None):
|
| 90 |
+
from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
|
| 91 |
|
| 92 |
+
p = _pre or {}
|
| 93 |
+
|
| 94 |
+
ear_left = p.get("ear_left", compute_ear(landmarks, _LEFT_EYE_EAR))
|
| 95 |
+
ear_right = p.get("ear_right", compute_ear(landmarks, _RIGHT_EYE_EAR))
|
| 96 |
ear_avg = (ear_left + ear_right) / 2.0
|
|
|
|
|
|
|
| 97 |
|
| 98 |
+
if "h_gaze" in p and "v_gaze" in p:
|
| 99 |
+
h_gaze, v_gaze = p["h_gaze"], p["v_gaze"]
|
| 100 |
+
else:
|
| 101 |
+
h_gaze, v_gaze = compute_gaze_ratio(landmarks)
|
| 102 |
+
|
| 103 |
+
mar = p.get("mar", compute_mar(landmarks))
|
| 104 |
+
|
| 105 |
+
angles = p.get("angles")
|
| 106 |
+
if angles is None:
|
| 107 |
+
angles = head_pose.estimate(landmarks, w, h)
|
| 108 |
yaw = angles[0] if angles else 0.0
|
| 109 |
pitch = angles[1] if angles else 0.0
|
| 110 |
roll = angles[2] if angles else 0.0
|
| 111 |
|
| 112 |
+
s_face = p.get("s_face", head_pose.score(landmarks, w, h))
|
| 113 |
+
s_eye = p.get("s_eye", eye_scorer.score(landmarks))
|
| 114 |
|
| 115 |
gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
|
| 116 |
+
head_deviation = math.sqrt(yaw ** 2 + pitch ** 2) # cleaned downstream
|
| 117 |
|
| 118 |
perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
|
| 119 |
|
|
|
|
| 190 |
parser.add_argument("--duration", type=int, default=600,
|
| 191 |
help="Max recording time (seconds, default 10 min)")
|
| 192 |
parser.add_argument("--output-dir", type=str,
|
| 193 |
+
default=os.path.join(_PROJECT_ROOT, "data", "collected_data"),
|
| 194 |
help="Where to save .npz files")
|
| 195 |
args = parser.parse_args()
|
| 196 |
|
|
|
|
| 247 |
landmarks = detector.process(frame)
|
| 248 |
face_ok = landmarks is not None
|
| 249 |
|
|
|
|
| 250 |
if face_ok and label is not None:
|
| 251 |
vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
|
| 252 |
features_list.append(vec)
|
| 253 |
labels_list.append(label)
|
| 254 |
|
|
|
|
| 255 |
if prev_label is not None and label != prev_label:
|
| 256 |
transitions += 1
|
| 257 |
prev_label = label
|
models/eye_classifier.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class EyeClassifier(ABC):
|
| 9 |
+
@property
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def name(self) -> str:
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
@abstractmethod
|
| 15 |
+
def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class GeometricOnlyClassifier(EyeClassifier):
|
| 20 |
+
@property
|
| 21 |
+
def name(self) -> str:
|
| 22 |
+
return "geometric"
|
| 23 |
+
|
| 24 |
+
def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
|
| 25 |
+
return 1.0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class YOLOv11Classifier(EyeClassifier):
|
| 29 |
+
def __init__(self, checkpoint_path: str, device: str = "cpu"):
|
| 30 |
+
from ultralytics import YOLO
|
| 31 |
+
|
| 32 |
+
self._model = YOLO(checkpoint_path)
|
| 33 |
+
self._device = device
|
| 34 |
+
|
| 35 |
+
names = self._model.names
|
| 36 |
+
self._attentive_idx = None
|
| 37 |
+
for idx, cls_name in names.items():
|
| 38 |
+
if cls_name in ("open", "attentive"):
|
| 39 |
+
self._attentive_idx = idx
|
| 40 |
+
break
|
| 41 |
+
if self._attentive_idx is None:
|
| 42 |
+
self._attentive_idx = max(names.keys())
|
| 43 |
+
print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")
|
| 44 |
+
|
| 45 |
+
@property
|
| 46 |
+
def name(self) -> str:
|
| 47 |
+
return "yolo"
|
| 48 |
+
|
| 49 |
+
def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
|
| 50 |
+
if not crops_bgr:
|
| 51 |
+
return 1.0
|
| 52 |
+
results = self._model.predict(crops_bgr, device=self._device, verbose=False)
|
| 53 |
+
scores = [float(r.probs.data[self._attentive_idx]) for r in results]
|
| 54 |
+
return sum(scores) / len(scores) if scores else 1.0
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_eye_classifier(
|
| 58 |
+
path: str | None = None,
|
| 59 |
+
backend: str = "yolo",
|
| 60 |
+
device: str = "cpu",
|
| 61 |
+
) -> EyeClassifier:
|
| 62 |
+
if path is None or backend == "geometric":
|
| 63 |
+
return GeometricOnlyClassifier()
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
return YOLOv11Classifier(path, device=device)
|
| 67 |
+
except ImportError:
|
| 68 |
+
print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
|
| 69 |
+
raise
|
models/eye_crop.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from models.face_mesh import FaceMeshDetector
|
| 5 |
+
|
| 6 |
+
LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
|
| 7 |
+
RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
|
| 8 |
+
|
| 9 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 10 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 11 |
+
|
| 12 |
+
CROP_SIZE = 96
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _bbox_from_landmarks(
|
| 16 |
+
landmarks: np.ndarray,
|
| 17 |
+
indices: list[int],
|
| 18 |
+
frame_w: int,
|
| 19 |
+
frame_h: int,
|
| 20 |
+
expand: float = 0.4,
|
| 21 |
+
) -> tuple[int, int, int, int]:
|
| 22 |
+
pts = landmarks[indices, :2]
|
| 23 |
+
px = pts[:, 0] * frame_w
|
| 24 |
+
py = pts[:, 1] * frame_h
|
| 25 |
+
|
| 26 |
+
x_min, x_max = px.min(), px.max()
|
| 27 |
+
y_min, y_max = py.min(), py.max()
|
| 28 |
+
w = x_max - x_min
|
| 29 |
+
h = y_max - y_min
|
| 30 |
+
cx = (x_min + x_max) / 2
|
| 31 |
+
cy = (y_min + y_max) / 2
|
| 32 |
+
|
| 33 |
+
size = max(w, h) * (1 + expand)
|
| 34 |
+
half = size / 2
|
| 35 |
+
|
| 36 |
+
x1 = int(max(cx - half, 0))
|
| 37 |
+
y1 = int(max(cy - half, 0))
|
| 38 |
+
x2 = int(min(cx + half, frame_w))
|
| 39 |
+
y2 = int(min(cy + half, frame_h))
|
| 40 |
+
|
| 41 |
+
return x1, y1, x2, y2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def extract_eye_crops(
|
| 45 |
+
frame: np.ndarray,
|
| 46 |
+
landmarks: np.ndarray,
|
| 47 |
+
expand: float = 0.4,
|
| 48 |
+
crop_size: int = CROP_SIZE,
|
| 49 |
+
) -> tuple[np.ndarray, np.ndarray, tuple, tuple]:
|
| 50 |
+
h, w = frame.shape[:2]
|
| 51 |
+
|
| 52 |
+
left_bbox = _bbox_from_landmarks(landmarks, LEFT_EYE_CONTOUR, w, h, expand)
|
| 53 |
+
right_bbox = _bbox_from_landmarks(landmarks, RIGHT_EYE_CONTOUR, w, h, expand)
|
| 54 |
+
|
| 55 |
+
left_crop = frame[left_bbox[1] : left_bbox[3], left_bbox[0] : left_bbox[2]]
|
| 56 |
+
right_crop = frame[right_bbox[1] : right_bbox[3], right_bbox[0] : right_bbox[2]]
|
| 57 |
+
|
| 58 |
+
if left_crop.size == 0:
|
| 59 |
+
left_crop = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
|
| 60 |
+
else:
|
| 61 |
+
left_crop = cv2.resize(left_crop, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
|
| 62 |
+
|
| 63 |
+
if right_crop.size == 0:
|
| 64 |
+
right_crop = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
|
| 65 |
+
else:
|
| 66 |
+
right_crop = cv2.resize(right_crop, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
|
| 67 |
+
|
| 68 |
+
return left_crop, right_crop, left_bbox, right_bbox
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def crop_to_tensor(crop_bgr: np.ndarray):
|
| 72 |
+
import torch
|
| 73 |
+
|
| 74 |
+
rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
| 75 |
+
for c in range(3):
|
| 76 |
+
rgb[:, :, c] = (rgb[:, :, c] - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
|
| 77 |
+
return torch.from_numpy(rgb.transpose(2, 0, 1))
|
models/{geometric/eye_behaviour/eye_scorer.py β eye_scorer.py}
RENAMED
|
@@ -95,7 +95,6 @@ def compute_gaze_ratio(landmarks: np.ndarray) -> tuple[float, float]:
|
|
| 95 |
|
| 96 |
|
| 97 |
def compute_mar(landmarks: np.ndarray) -> float:
|
| 98 |
-
# Mouth aspect ratio: high = mouth open (yawning / sleepy)
|
| 99 |
top = landmarks[_MOUTH_TOP, :2]
|
| 100 |
bottom = landmarks[_MOUTH_BOTTOM, :2]
|
| 101 |
left = landmarks[_MOUTH_LEFT, :2]
|
|
@@ -140,7 +139,10 @@ class EyeBehaviourScorer:
|
|
| 140 |
return 0.5 * (1.0 + math.cos(math.pi * t))
|
| 141 |
|
| 142 |
def score(self, landmarks: np.ndarray) -> float:
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
| 144 |
ear_s = self._ear_score(ear)
|
| 145 |
if ear_s < 0.3:
|
| 146 |
return ear_s
|
|
@@ -149,7 +151,9 @@ class EyeBehaviourScorer:
|
|
| 149 |
return ear_s * gaze_s
|
| 150 |
|
| 151 |
def detailed_score(self, landmarks: np.ndarray) -> dict:
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
ear_s = self._ear_score(ear)
|
| 154 |
h_ratio, v_ratio = compute_gaze_ratio(landmarks)
|
| 155 |
gaze_s = self._gaze_score(h_ratio, v_ratio)
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
def compute_mar(landmarks: np.ndarray) -> float:
|
|
|
|
| 98 |
top = landmarks[_MOUTH_TOP, :2]
|
| 99 |
bottom = landmarks[_MOUTH_BOTTOM, :2]
|
| 100 |
left = landmarks[_MOUTH_LEFT, :2]
|
|
|
|
| 139 |
return 0.5 * (1.0 + math.cos(math.pi * t))
|
| 140 |
|
| 141 |
def score(self, landmarks: np.ndarray) -> float:
|
| 142 |
+
left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
|
| 143 |
+
right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
|
| 144 |
+
# Use minimum EAR so closing ONE eye is enough to drop the score
|
| 145 |
+
ear = min(left_ear, right_ear)
|
| 146 |
ear_s = self._ear_score(ear)
|
| 147 |
if ear_s < 0.3:
|
| 148 |
return ear_s
|
|
|
|
| 151 |
return ear_s * gaze_s
|
| 152 |
|
| 153 |
def detailed_score(self, landmarks: np.ndarray) -> dict:
|
| 154 |
+
left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
|
| 155 |
+
right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
|
| 156 |
+
ear = min(left_ear, right_ear)
|
| 157 |
ear_s = self._ear_score(ear)
|
| 158 |
h_ratio, v_ratio = compute_gaze_ratio(landmarks)
|
| 159 |
gaze_s = self._gaze_score(h_ratio, v_ratio)
|
models/{pretrained/face_mesh/face_mesh.py β face_mesh.py}
RENAMED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
from urllib.request import urlretrieve
|
| 4 |
|
|
@@ -51,14 +52,16 @@ class FaceMeshDetector:
|
|
| 51 |
running_mode=RunningMode.VIDEO,
|
| 52 |
)
|
| 53 |
self._landmarker = FaceLandmarker.create_from_options(options)
|
| 54 |
-
self.
|
|
|
|
| 55 |
|
| 56 |
def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
|
| 57 |
# BGR in -> (478,3) norm x,y,z or None
|
| 58 |
rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
|
| 59 |
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
|
| 60 |
-
self.
|
| 61 |
-
|
|
|
|
| 62 |
|
| 63 |
if not result.face_landmarks:
|
| 64 |
return None
|
|
|
|
| 1 |
import os
|
| 2 |
+
import time
|
| 3 |
from pathlib import Path
|
| 4 |
from urllib.request import urlretrieve
|
| 5 |
|
|
|
|
| 52 |
running_mode=RunningMode.VIDEO,
|
| 53 |
)
|
| 54 |
self._landmarker = FaceLandmarker.create_from_options(options)
|
| 55 |
+
self._t0 = time.monotonic()
|
| 56 |
+
self._last_ts = 0
|
| 57 |
|
| 58 |
def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
|
| 59 |
# BGR in -> (478,3) norm x,y,z or None
|
| 60 |
rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
|
| 61 |
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
|
| 62 |
+
ts = max(int((time.monotonic() - self._t0) * 1000), self._last_ts + 1)
|
| 63 |
+
self._last_ts = ts
|
| 64 |
+
result = self._landmarker.detect_for_video(mp_image, ts)
|
| 65 |
|
| 66 |
if not result.face_landmarks:
|
| 67 |
return None
|
models/geometric/eye_behaviour/__init__.py
DELETED
|
File without changes
|
models/geometric/face_orientation/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
|
|
|
|
|
|
models/{geometric/face_orientation/head_pose.py β head_pose.py}
RENAMED
|
@@ -25,6 +25,8 @@ class HeadPoseEstimator:
|
|
| 25 |
self._camera_matrix = None
|
| 26 |
self._frame_size = None
|
| 27 |
self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
|
| 30 |
if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
|
|
@@ -39,6 +41,10 @@ class HeadPoseEstimator:
|
|
| 39 |
return self._camera_matrix
|
| 40 |
|
| 41 |
def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
image_points = np.array(
|
| 43 |
[
|
| 44 |
[landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
|
|
@@ -54,7 +60,10 @@ class HeadPoseEstimator:
|
|
| 54 |
self._dist_coeffs,
|
| 55 |
flags=cv2.SOLVEPNP_ITERATIVE,
|
| 56 |
)
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
def estimate(
|
| 60 |
self, landmarks: np.ndarray, frame_w: int, frame_h: int
|
|
|
|
| 25 |
self._camera_matrix = None
|
| 26 |
self._frame_size = None
|
| 27 |
self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
|
| 28 |
+
self._cache_key = None
|
| 29 |
+
self._cache_result = None
|
| 30 |
|
| 31 |
def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
|
| 32 |
if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
|
|
|
|
| 41 |
return self._camera_matrix
|
| 42 |
|
| 43 |
def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
|
| 44 |
+
key = (landmarks.data.tobytes(), frame_w, frame_h)
|
| 45 |
+
if self._cache_key == key:
|
| 46 |
+
return self._cache_result
|
| 47 |
+
|
| 48 |
image_points = np.array(
|
| 49 |
[
|
| 50 |
[landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
|
|
|
|
| 60 |
self._dist_coeffs,
|
| 61 |
flags=cv2.SOLVEPNP_ITERATIVE,
|
| 62 |
)
|
| 63 |
+
result = (success, rvec, tvec, image_points)
|
| 64 |
+
self._cache_key = key
|
| 65 |
+
self._cache_result = result
|
| 66 |
+
return result
|
| 67 |
|
| 68 |
def estimate(
|
| 69 |
self, landmarks: np.ndarray, frame_w: int, frame_h: int
|