Yingtao-Zheng commited on
Commit
4a5bfab
·
1 Parent(s): 82d2ab7

Put all the models together (except UI)

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitignore +29 -8
  2. Dockerfile +27 -0
  3. README.md +87 -6
  4. app.py +1 -0
  5. checkpoints/hybrid_focus_config.json +10 -0
  6. MLP/models/meta_20260224_024200.npz β†’ checkpoints/meta_best.npz +2 -2
  7. MLP/models/mlp_20260224_024200.joblib β†’ checkpoints/mlp_best.pt +2 -2
  8. best_eye_cnn.pth β†’ checkpoints/model_best.joblib +2 -2
  9. MLP/models/scaler_20260224_024200.joblib β†’ checkpoints/scaler_best.joblib +2 -2
  10. checkpoints/xgboost_face_orientation_best.json +0 -0
  11. {data_preparation β†’ data}/CNN/eye_crops/val/open/.gitkeep +0 -0
  12. data/README.md +47 -0
  13. {data_preparation β†’ data}/collected_Abdelrahman/abdelrahman_20260306_023035.npz +0 -0
  14. {data_preparation β†’ data}/collected_Jarek/Jarek_20260225_012931.npz +0 -0
  15. {data_preparation β†’ data}/collected_Junhao/Junhao_20260303_113554.npz +0 -0
  16. {data_preparation β†’ data}/collected_Kexin/kexin2_20260305_180229.npz +0 -0
  17. {data_preparation β†’ data}/collected_Kexin/kexin_20260224_151043.npz +0 -0
  18. {data_preparation β†’ data}/collected_Langyuan/Langyuan_20260303_153145.npz +0 -0
  19. {data_preparation β†’ data}/collected_Mohamed/session_20260224_010131.npz +0 -0
  20. {data_preparation β†’ data}/collected_Yingtao/Yingtao_20260306_023937.npz +0 -0
  21. {data_preparation/collected_Ayten β†’ data/collected_ayten}/ayten_session_1.npz +0 -0
  22. {data_preparation/collected_Saba β†’ data/collected_saba}/saba_20260306_230710.npz +0 -0
  23. data_preparation/MLP/explore_collected_data.ipynb +0 -0
  24. data_preparation/MLP/train_mlp.ipynb +0 -0
  25. data_preparation/README.md +61 -27
  26. {models/geometric β†’ data_preparation}/__init__.py +0 -0
  27. data_preparation/data_exploration.ipynb +0 -0
  28. data_preparation/prepare_dataset.py +232 -0
  29. docker-compose.yml +5 -0
  30. eslint.config.js +29 -0
  31. evaluation/README.md +45 -2
  32. index.html +17 -0
  33. main.py +964 -0
  34. models/README.md +51 -8
  35. models/{attention/__init__.py β†’ __init__.py} +0 -0
  36. models/attention/classifier.py +0 -0
  37. models/attention/fusion.py +0 -0
  38. models/attention/train.py +0 -0
  39. models/cnn/notebooks/EyeCNN.ipynb +107 -0
  40. models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb +0 -0
  41. models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb +0 -0
  42. models/cnn/notebooks/README.md +1 -0
  43. models/{attention/collect_features.py β†’ collect_features.py} +26 -19
  44. models/eye_classifier.py +69 -0
  45. models/eye_crop.py +77 -0
  46. models/{geometric/eye_behaviour/eye_scorer.py β†’ eye_scorer.py} +7 -3
  47. models/{pretrained/face_mesh/face_mesh.py β†’ face_mesh.py} +6 -3
  48. models/geometric/eye_behaviour/__init__.py +0 -0
  49. models/geometric/face_orientation/__init__.py +0 -1
  50. models/{geometric/face_orientation/head_pose.py β†’ head_pose.py} +10 -1
.gitignore CHANGED
@@ -1,4 +1,26 @@
1
- __pycache__/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  *.py[cod]
3
  *$py.class
4
  *.so
@@ -9,12 +31,11 @@ env/
9
  .env
10
  *.egg-info/
11
  .eggs/
12
- dist/
13
  build/
14
- .idea/
15
- .vscode/
16
- *.swp
17
- *.swo
18
- docs/
19
- .DS_Store
20
  Thumbs.db
 
 
 
 
 
 
 
1
+ # Logs
2
+ logs
3
+ *.log
4
+ npm-debug.log*
5
+ yarn-debug.log*
6
+ yarn-error.log*
7
+ pnpm-debug.log*
8
+ lerna-debug.log*
9
+
10
+ node_modules/
11
+ dist/
12
+ dist-ssr/
13
+ *.local
14
+
15
+ # Editor directories and files
16
+ .vscode/
17
+ .idea/
18
+ .DS_Store
19
+ *.suo
20
+ *.ntvs*
21
+ *.njsproj
22
+ *.sln
23
+ *.sw?
24
  *.py[cod]
25
  *$py.class
26
  *.so
 
31
  .env
32
  *.egg-info/
33
  .eggs/
 
34
  build/
 
 
 
 
 
 
35
  Thumbs.db
36
+
37
+ # Project specific
38
+ focus_guard.db
39
+ static/
40
+ __pycache__/
41
+ docs/
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ RUN useradd -m -u 1000 user
4
+ ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
5
+
6
+ ENV PYTHONUNBUFFERED=1
7
+
8
+ WORKDIR /app
9
+
10
+ RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev build-essential nodejs npm && rm -rf /var/lib/apt/lists/*
11
+
12
+ COPY requirements.txt ./
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ COPY . .
16
+
17
+ RUN npm install && npm run build && mkdir -p /app/static && cp -R dist/* /app/static/
18
+
19
+ ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
20
+ RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"
21
+
22
+ RUN mkdir -p /app/data && chown -R user:user /app
23
+
24
+ USER user
25
+ EXPOSE 7860
26
+
27
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--log-level", "debug"]
README.md CHANGED
@@ -1,10 +1,91 @@
1
  # FocusGuard
2
 
3
- Webcam-based focus detection: face mesh, head pose, eye (geometry or YOLO), plus an MLP trained on collected features.
4
 
5
- - **data_preparation/** β€” collect data, notebooks, processed/collected files
6
- - **models/** β€” face mesh, head pose, eye scorer, YOLO classifier, MLP training, attention feature collection
7
- - **evaluation/** β€” metrics and run logs
8
- - **ui/** β€” live demo (geometry+YOLO or MLP-only)
9
 
10
- Run from here: `pip install -r requirements.txt` then `python ui/live_demo.py` or `python ui/live_demo.py --mlp`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # FocusGuard
2
 
3
+ Real-time webcam-based focus detection system combining geometric feature extraction with machine learning classification. The pipeline extracts 17 facial features (EAR, gaze, head pose, PERCLOS, blink rate, etc.) from MediaPipe landmarks and classifies attentiveness using MLP and XGBoost models. Served via a React + FastAPI web application with live WebSocket video.
4
 
5
+ ## 1. Project Structure
 
 
 
6
 
7
+ ```
8
+ ├── data/ Raw collected sessions (collected_<name>/*.npz)
9
+ ├── data_preparation/ Data loading, cleaning, and exploration
10
+ ├── notebooks/ Training notebooks (MLP, XGBoost) with LOPO evaluation
11
+ ├── models/ Feature extraction modules and training scripts
12
+ ├── checkpoints/ All saved weights (mlp_best.pt, xgboost_*_best.json, GRU, scalers)
13
+ ├── evaluation/ Training logs and metrics (JSON)
14
+ ├── ui/ Live OpenCV demo and inference pipeline
15
+ ├── src/ React/Vite frontend source
16
+ ├── static/ Built frontend (served by FastAPI)
17
+ ├── app.py / main.py FastAPI backend (API, WebSocket, DB)
18
+ ├── requirements.txt Python dependencies
19
+ └── package.json Frontend dependencies
20
+ ```
21
+
22
+ ## 2. Setup
23
+
24
+ ```bash
25
+ python -m venv venv
26
+ source venv/bin/activate
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ Frontend (only needed if modifying the React app):
31
+
32
+ ```bash
33
+ npm install
34
+ npm run build
35
+ cp -r dist/* static/
36
+ ```
37
+
38
+ ## 3. Running
39
+
40
+ **Web application (API + frontend):**
41
+
42
+ ```bash
43
+ uvicorn app:app --host 0.0.0.0 --port 7860
44
+ ```
45
+
46
+ Open http://localhost:7860 in a browser.
47
+
48
+ **Live camera demo (OpenCV):**
49
+
50
+ ```bash
51
+ python ui/live_demo.py
52
+ python ui/live_demo.py --xgb # XGBoost mode
53
+ ```
54
+
55
+ **Training:**
56
+
57
+ ```bash
58
+ python -m models.mlp.train # MLP
59
+ python -m models.xgboost.train # XGBoost
60
+ ```
61
+
62
+ ## 4. Dataset
63
+
64
+ - **9 participants**, each recorded via webcam with real-time labelling (focused / unfocused)
65
+ - **144,793 total samples**, 10 selected features, binary classification
66
+ - Collected using `python -m models.collect_features --name <name>`
67
+ - Stored as `.npz` files in `data/collected_<name>/`
68
+
69
+ ## 5. Models
70
+
71
+ | Model | Test Accuracy | Test F1 | ROC-AUC |
72
+ |-------|--------------|---------|---------|
73
+ | XGBoost (600 trees, depth 8, lr 0.149) | 95.87% | 0.959 | 0.991 |
74
+ | MLP (64β†’32, 30 epochs, lr 1e-3) | 92.92% | 0.929 | 0.971 |
75
+
76
+ Both evaluated on a held-out 15% stratified test split. LOPO (Leave-One-Person-Out) cross-validation available in `notebooks/`.
77
+
78
+ ## 6. Feature Pipeline
79
+
80
+ 1. **Face mesh** — MediaPipe 478-landmark detection
81
+ 2. **Head pose** — solvePnP → yaw, pitch, roll, face score, gaze offset, head deviation
82
+ 3. **Eye scorer** — EAR (left/right/avg), horizontal/vertical gaze ratio, MAR
83
+ 4. **Temporal tracking** — PERCLOS, blink rate, closure duration, yawn duration
84
+ 5. **Classification** — 10-feature vector → MLP or XGBoost → focused / unfocused
85
+
86
+ ## 7. Tech Stack
87
+
88
+ - **Backend:** Python, FastAPI, WebSocket, aiosqlite
89
+ - **Frontend:** React, Vite, TypeScript
90
+ - **ML:** PyTorch (MLP), XGBoost, scikit-learn
91
+ - **Vision:** MediaPipe, OpenCV
app.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from main import app
checkpoints/hybrid_focus_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "w_mlp": 0.6000000000000001,
3
+ "w_geo": 0.3999999999999999,
4
+ "threshold": 0.35,
5
+ "use_yawn_veto": true,
6
+ "geo_face_weight": 0.4,
7
+ "geo_eye_weight": 0.6,
8
+ "mar_yawn_threshold": 0.55,
9
+ "metric": "f1"
10
+ }
MLP/models/meta_20260224_024200.npz β†’ checkpoints/meta_best.npz RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:769bb62c7bf04aafd808e9b2623e795c2d92bcb933313ebf553d6fce5ebe7143
3
- size 1616
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
3
+ size 4196
MLP/models/mlp_20260224_024200.joblib β†’ checkpoints/mlp_best.pt RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a72933fcf2d0aed998c6303ea4298c04618d937c7f17bf492e76efcf3b4b54d7
3
- size 50484
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2f55129785b6882c304483aa5399f5bf6c9ed6e73dfec7ca6f36cd0436156c8
3
+ size 14497
best_eye_cnn.pth β†’ checkpoints/model_best.joblib RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3c3d85de013387e8583fe7218daabb83a8a6f46ca5bcacbf6fbf3619b688da8
3
- size 2103809
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:183f2d4419e0eb1e58704e5a7312eb61e331523566d4dc551054a07b3aac7557
3
+ size 5775881
MLP/models/scaler_20260224_024200.joblib β†’ checkpoints/scaler_best.joblib RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3f9ef3721cee28f1472886556e001d0f6ed0abe09011d979a70ca9bf447d453e
3
- size 823
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02ed6b4c0d99e0254c6a740a949da2384db58ec7d3e6df6432b9bfcd3a296c71
3
+ size 783
checkpoints/xgboost_face_orientation_best.json ADDED
The diff for this file is too large to render. See raw diff
 
{data_preparation β†’ data}/CNN/eye_crops/val/open/.gitkeep RENAMED
File without changes
data/README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data/
2
+
3
+ Raw collected session data used for model training and evaluation.
4
+
5
+ ## 1. Contents
6
+
7
+ Each `collected_<name>/` folder contains `.npz` files for one participant:
8
+
9
+ | Folder | Participant | Samples |
10
+ |--------|-------------|---------|
11
+ | `collected_Abdelrahman/` | Abdelrahman | 15,870 |
12
+ | `collected_Jarek/` | Jarek | 14,829 |
13
+ | `collected_Junhao/` | Junhao | 8,901 |
14
+ | `collected_Kexin/` | Kexin | 32,312 (2 sessions) |
15
+ | `collected_Langyuan/` | Langyuan | 15,749 |
16
+ | `collected_Mohamed/` | Mohamed | 13,218 |
17
+ | `collected_Yingtao/` | Yingtao | 17,591 |
18
+ | `collected_ayten/` | Ayten | 17,621 |
19
+ | `collected_saba/` | Saba | 8,702 |
20
+ | **Total** | **9 participants** | **144,793** |
21
+
22
+ ## 2. File Format
23
+
24
+ Each `.npz` file contains:
25
+
26
+ | Key | Shape | Description |
27
+ |-----|-------|-------------|
28
+ | `features` | (N, 17) | 17-dimensional feature vectors (float32) |
29
+ | `labels` | (N,) | Binary labels: 0 = unfocused, 1 = focused |
30
+ | `feature_names` | (17,) | Column names for the 17 features |
31
+
32
+ ## 3. Feature List
33
+
34
+ `ear_left`, `ear_right`, `ear_avg`, `h_gaze`, `v_gaze`, `mar`, `yaw`, `pitch`, `roll`, `s_face`, `s_eye`, `gaze_offset`, `head_deviation`, `perclos`, `blink_rate`, `closure_duration`, `yawn_duration`
35
+
36
+ 10 of these are selected for training (see `data_preparation/prepare_dataset.py`).
37
+
38
+ ## 4. Collection
39
+
40
+ ```bash
41
+ python -m models.collect_features --name yourname
42
+ ```
43
+
44
+ 1. Webcam opens with live overlay
45
+ 2. Press **1** = focused, **0** = unfocused (switch every 10–30 sec)
46
+ 3. Press **p** to pause/resume
47
+ 4. Press **q** to stop and save
{data_preparation β†’ data}/collected_Abdelrahman/abdelrahman_20260306_023035.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Jarek/Jarek_20260225_012931.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Junhao/Junhao_20260303_113554.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Kexin/kexin2_20260305_180229.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Kexin/kexin_20260224_151043.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Langyuan/Langyuan_20260303_153145.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Mohamed/session_20260224_010131.npz RENAMED
File without changes
{data_preparation β†’ data}/collected_Yingtao/Yingtao_20260306_023937.npz RENAMED
File without changes
{data_preparation/collected_Ayten β†’ data/collected_ayten}/ayten_session_1.npz RENAMED
File without changes
{data_preparation/collected_Saba β†’ data/collected_saba}/saba_20260306_230710.npz RENAMED
File without changes
data_preparation/MLP/explore_collected_data.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
data_preparation/MLP/train_mlp.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
data_preparation/README.md CHANGED
@@ -1,41 +1,75 @@
1
- # Data Preparation
2
 
3
- ## Folder Structure
4
 
5
- ### collected/
6
- Contains raw session files in `.npz` format.
7
- Generated using:
8
 
9
- python -m models.attention.collect_features
 
 
 
10
 
11
- Each session includes:
12
- - 17-dimensional feature vectors
13
- - Corresponding labels
14
 
15
- ---
16
 
 
 
 
 
 
 
 
17
 
18
- ### MLP/
19
- Contains notebooks for:
20
- - Exploring collected data
21
- - Training the sklearn MLP model (10 features)
22
 
23
- Trained models are saved to:
24
- ../MLP/models/
25
 
26
- ---
27
 
28
- ### CNN/
29
- Eye crop directory structure for CNN training (YOLO).
30
 
31
- ---
32
 
33
- ## Collecting Data
34
 
35
- **Step-by-step**
 
 
 
36
 
37
- 1. From repo root Install deps: `pip install -r requirements.txt`.
38
- 3. Run: `python -m models.attention.collect_features --name yourname`.
39
- 4. Webcam opens. Look at the camera; press **1** when focused, **0** when unfocused. Switch every 10–30 sec so you get both labels.
40
- 5. Press **p** to pause/resume.
41
- 6. Press **q** when done. One `.npz` is saved to `data_preparation/collected/` (17 features + labels).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_preparation/
2
 
3
+ Shared data loading, cleaning, and exploratory analysis.
4
 
5
+ ## 1. Files
 
 
6
 
7
+ | File | Description |
8
+ |------|-------------|
9
+ | `prepare_dataset.py` | Central data loading module used by all training scripts and notebooks |
10
+ | `data_exploration.ipynb` | EDA notebook: feature distributions, class balance, correlations |
11
 
12
+ ## 2. prepare_dataset.py
 
 
13
 
14
+ Provides a consistent pipeline for loading raw `.npz` data from `data/`:
15
 
16
+ | Function | Purpose |
17
+ |----------|---------|
18
+ | `load_all_pooled(model_name)` | Load all participants, clean, select features, concatenate |
19
+ | `load_per_person(model_name)` | Load grouped by person (for LOPO cross-validation) |
20
+ | `get_numpy_splits(model_name)` | Load + stratified 70/15/15 split + StandardScaler |
21
+ | `get_dataloaders(model_name)` | Same as above, wrapped in PyTorch DataLoaders |
22
+ | `_split_and_scale(features, labels, ...)` | Reusable split + optional scaling |
23
 
24
+ ### Cleaning rules
 
 
 
25
 
26
+ - `yaw` clipped to [-45, 45], `pitch`/`roll` to [-30, 30]
27
+ - `ear_left`, `ear_right`, `ear_avg` clipped to [0, 0.85]
28
 
29
+ ### Selected features (face_orientation)
30
 
31
+ `head_deviation`, `s_face`, `s_eye`, `h_gaze`, `pitch`, `ear_left`, `ear_avg`, `ear_right`, `gaze_offset`, `perclos`
 
32
 
33
+ ## 3. data_exploration.ipynb
34
 
35
+ Run from this folder or from the project root. Covers:
36
 
37
+ 1. Per-feature statistics (mean, std, min, max)
38
+ 2. Class distribution (focused vs unfocused)
39
+ 3. Feature histograms and box plots
40
+ 4. Correlation matrix
41
 
42
+ ## 4. How to run
43
+
44
+ `prepare_dataset.py` is a **library module**, not a standalone script. You don't run it directly; you import it from code that needs data.
45
+
46
+ **From repo root:**
47
+
48
+ ```bash
49
+ # Optional: quick test that loading works
50
+ python -c "
51
+ from data_preparation.prepare_dataset import load_all_pooled
52
+ X, y, names = load_all_pooled('face_orientation')
53
+ print(f'Loaded {X.shape[0]} samples, {X.shape[1]} features: {names}')
54
+ "
55
+ ```
56
+
57
+ **Used by:**
58
+
59
+ - `python -m models.mlp.train`
60
+ - `python -m models.xgboost.train`
61
+ - `notebooks/mlp.ipynb`, `notebooks/xgboost.ipynb`
62
+ - `data_preparation/data_exploration.ipynb`
63
+
64
+ ## 5. Usage (in code)
65
+
66
+ ```python
67
+ from data_preparation.prepare_dataset import load_all_pooled, get_numpy_splits
68
+
69
+ # pooled data
70
+ X, y, names = load_all_pooled("face_orientation")
71
+
72
+ # ready-to-train splits
73
+ splits, n_features, n_classes, scaler = get_numpy_splits("face_orientation")
74
+ X_train, y_train = splits["X_train"], splits["y_train"]
75
+ ```
{models/geometric β†’ data_preparation}/__init__.py RENAMED
File without changes
data_preparation/data_exploration.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
data_preparation/prepare_dataset.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import glob

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# PyTorch is optional: numpy-only consumers of this module can import it
# without torch installed. Stand-ins keep the module importable; constructing
# a DataLoader without torch raises a clear ImportError instead.
try:
    import torch
    from torch.utils.data import Dataset, DataLoader
except ImportError:  # pragma: no cover
    torch = None

    class Dataset:  # type: ignore
        pass

    class _MissingTorchDataLoader:  # type: ignore
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "PyTorch not installed"
            )

    DataLoader = _MissingTorchDataLoader  # type: ignore
24
+
25
# Root folder holding the collected_<name>/ session directories,
# resolved relative to this file so imports work from any cwd.
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

# Per-model feature subsets. Keys are the model names accepted by the
# loader functions; values are the ordered columns pulled from each .npz.
SELECTED_FEATURES = {
    "face_orientation": [
        "head_deviation", "s_face", "s_eye", "h_gaze", "pitch",
        "ear_left", "ear_avg", "ear_right", "gaze_offset", "perclos",
    ],
    "eye_behaviour": [
        "ear_left", "ear_right", "ear_avg", "mar",
        "blink_rate", "closure_duration", "perclos", "yawn_duration",
    ],
}
37
+
38
+
39
class FeatureVectorDataset(Dataset):
    """PyTorch Dataset over pre-extracted feature vectors and labels.

    Features are stored as float32 tensors and labels as long tensors,
    ready for a classifier head / CrossEntropyLoss.
    """

    def __init__(self, features: np.ndarray, labels: np.ndarray):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        # One sample per label.
        return len(self.labels)

    def __getitem__(self, idx):
        # A sample is the (feature_vector, label) pair at idx.
        return self.features[idx], self.labels[idx]
49
+
50
+
51
+ # ── Low-level helpers ────────────────────────────────────────────────────
52
+
53
+ def _clean_npz(raw, names):
54
+ """Apply clipping rules in-place. Shared by all loaders."""
55
+ for col, lo, hi in [('yaw', -45, 45), ('pitch', -30, 30), ('roll', -30, 30)]:
56
+ if col in names:
57
+ raw[:, names.index(col)] = np.clip(raw[:, names.index(col)], lo, hi)
58
+ for feat in ['ear_left', 'ear_right', 'ear_avg']:
59
+ if feat in names:
60
+ raw[:, names.index(feat)] = np.clip(raw[:, names.index(feat)], 0, 0.85)
61
+ return raw
62
+
63
+
64
def _load_one_npz(npz_path, target_features):
    """Load one .npz session file, clean it, and select feature columns.

    Returns (X, y, selected_feature_names), where the names are the subset
    of target_features actually present in the file, in requested order.
    """
    data = np.load(npz_path, allow_pickle=True)
    names = list(data['feature_names'])
    raw = _clean_npz(data['features'].astype(np.float32), names)
    labels = data['labels'].astype(np.int64)
    # Keep only requested columns that exist, preserving the requested order.
    selected = [feat for feat in target_features if feat in names]
    cols = [names.index(feat) for feat in selected]
    return raw[:, cols], labels, selected
74
+
75
+
76
+ # ── Public data loaders ──────────────────────────────────────────────────
77
+
78
def load_all_pooled(model_name: str = "face_orientation", data_dir: str = None):
    """Load all collected_*/*.npz, clean, select features, and concatenate.

    model_name -- key into SELECTED_FEATURES (falls back to face_orientation)
    data_dir   -- override for DATA_DIR (optional)
    Returns (X_all, y_all, all_feature_names). When no files are found, a
    synthetic dataset is generated so downstream code keeps running.
    """
    data_dir = data_dir or DATA_DIR
    target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
    npz_files = sorted(glob.glob(os.path.join(data_dir, "collected_*", "*.npz")))

    if not npz_files:
        print("[DATA] Warning: No .npz files found. Falling back to synthetic.")
        X, y = _generate_synthetic_data(model_name)
        return X, y, target_features

    chunks_X, chunks_y = [], []
    all_names = None
    for path in npz_files:
        X, y, names = _load_one_npz(path, target_features)
        if all_names is None:
            # Column names come from the first file; later files are assumed
            # to share the same selected columns.
            all_names = names
        chunks_X.append(X)
        chunks_y.append(y)
        print(f"[DATA] + {os.path.basename(path)}: {X.shape[0]} samples")

    X_all = np.concatenate(chunks_X, axis=0)
    y_all = np.concatenate(chunks_y, axis=0)
    print(f"[DATA] Loaded {len(npz_files)} file(s) for '{model_name}': "
          f"{X_all.shape[0]} total samples, {X_all.shape[1]} features")
    return X_all, y_all, all_names
108
+
109
+
110
def load_per_person(model_name: str = "face_orientation", data_dir: str = None):
    """Load collected_*/*.npz grouped by person (folder name).

    Returns (by_person, X_all, y_all):
      by_person -- dict { person_name: (X, y) } with per-person numpy arrays
      X_all/y_all -- the same data pooled across everyone
    Raises FileNotFoundError when no session files match.
    """
    data_dir = data_dir or DATA_DIR
    target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
    pattern = os.path.join(data_dir, "collected_*", "*.npz")
    npz_files = sorted(glob.glob(pattern))
    if not npz_files:
        raise FileNotFoundError(f"No .npz files matching {pattern}")

    by_person = {}
    pooled_X, pooled_y = [], []
    for path in npz_files:
        # The person name is the parent folder with its collected_ prefix removed.
        person = os.path.basename(os.path.dirname(path)).replace("collected_", "", 1)
        X, y, _ = _load_one_npz(path, target_features)
        pooled_X.append(X)
        pooled_y.append(y)
        by_person.setdefault(person, []).append((X, y))
        print(f"[DATA] + {person}/{os.path.basename(path)}: {X.shape[0]} samples")

    # Collapse each person's list of sessions into single (X, y) arrays.
    for person, chunks in by_person.items():
        by_person[person] = (
            np.concatenate([c[0] for c in chunks], axis=0),
            np.concatenate([c[1] for c in chunks], axis=0),
        )

    X_all = np.concatenate(pooled_X, axis=0)
    y_all = np.concatenate(pooled_y, axis=0)
    print(f"[DATA] {len(by_person)} persons, {X_all.shape[0]} total samples, {X_all.shape[1]} features")
    return by_person, X_all, y_all
147
+
148
+
149
def load_raw_npz(npz_path):
    """Load a single .npz without cleaning or feature selection.

    Intended for exploration notebooks that want the untouched 17-column data.
    Returns (features float32, labels int64, feature_names list).
    """
    data = np.load(npz_path, allow_pickle=True)
    return (
        data['features'].astype(np.float32),
        data['labels'].astype(np.int64),
        list(data['feature_names']),
    )
156
+
157
+
158
+ # ── Legacy helpers (used by models/mlp/train.py and models/xgboost/train.py) ─
159
+
160
def _load_real_data(model_name: str):
    """Legacy wrapper: pooled (X, y) without the feature-name list."""
    features, labels, _unused_names = load_all_pooled(model_name)
    return features, labels
163
+
164
+
165
def _generate_synthetic_data(model_name: str):
    """Produce a small random stand-in dataset when no real data is on disk.

    Returns (features, labels) shaped like the real pooled data for the
    requested model. Seeded RNG keeps the fallback reproducible.
    """
    target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
    n_samples, n_classes = 500, 2
    n_dims = len(target_features)
    rng = np.random.RandomState(42)
    features = rng.randn(n_samples, n_dims).astype(np.float32)
    labels = rng.randint(0, n_classes, size=n_samples).astype(np.int64)
    print(f"[DATA] Using synthetic data for '{model_name}': {n_samples} samples, {n_dims} features, {n_classes} classes")
    return features, labels
175
+
176
+
177
def _split_and_scale(features, labels, split_ratios, seed, scale):
    """Stratified train/val/test split with optional standardisation.

    split_ratios -- (train, val, test) fractions
    scale        -- when True, fit a StandardScaler on train and apply to all
    Returns (splits dict with X_/y_ train/val/test, fitted scaler or None).
    """
    # Carve off the test set first, then split the remainder into train/val;
    # the val fraction is rescaled relative to what is left.
    test_frac = split_ratios[2]
    val_frac = split_ratios[1] / (split_ratios[0] + split_ratios[1])

    X_rest, X_test, y_rest, y_test = train_test_split(
        features, labels, test_size=test_frac, random_state=seed, stratify=labels,
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=val_frac, random_state=seed, stratify=y_rest,
    )

    scaler = None
    if scale:
        # Fit on the training split only, to avoid leakage into val/test.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)
        print("[DATA] Applied StandardScaler (fitted on training split)")

    print(f"[DATA] Split (stratified): train={len(y_train)}, val={len(y_val)}, test={len(y_test)}")
    return {
        "X_train": X_train, "y_train": y_train,
        "X_val": X_val, "y_val": y_val,
        "X_test": X_test, "y_test": y_test,
    }, scaler
205
+
206
+
207
def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
    """Return plain numpy splits for non-PyTorch models (e.g. XGBoost).

    Returns (splits, num_features, num_classes, scaler).
    """
    features, labels = _load_real_data(model_name)
    splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
    # Class count assumes labels are 0..max contiguous integers.
    return splits, features.shape[1], int(labels.max()) + 1, scaler
214
+
215
+
216
def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
    """Return PyTorch DataLoaders for neural-network models.

    Returns (train_loader, val_loader, test_loader, num_features,
    num_classes, scaler). Only the training loader shuffles.
    """
    features, labels = _load_real_data(model_name)
    num_features = features.shape[1]
    num_classes = int(labels.max()) + 1
    splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)

    def _loader(part: str, shuffle: bool):
        # Wrap one split as Dataset + DataLoader.
        ds = FeatureVectorDataset(splits[f"X_{part}"], splits[f"y_{part}"])
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    return (
        _loader("train", True),
        _loader("val", False),
        _loader("test", False),
        num_features,
        num_classes,
        scaler,
    )
232
+
docker-compose.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Single-service stack: build the local Dockerfile and publish the app port.
services:
  focus-guard:
    build: .
    ports:
      - "7860:7860"
eslint.config.js ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// ESLint flat config for the React/Vite frontend.
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import { defineConfig, globalIgnores } from 'eslint/config'

export default defineConfig([
  // Never lint build output.
  globalIgnores(['dist']),
  {
    files: ['**/*.{js,jsx}'],
    extends: [
      js.configs.recommended,
      reactHooks.configs.flat.recommended,
      reactRefresh.configs.vite,
    ],
    languageOptions: {
      ecmaVersion: 2020,
      globals: globals.browser,
      parserOptions: {
        ecmaVersion: 'latest',
        ecmaFeatures: { jsx: true },
        sourceType: 'module',
      },
    },
    rules: {
      // Permit intentionally-unused bindings that start with a capital or _.
      'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
    },
  },
])
evaluation/README.md CHANGED
@@ -1,3 +1,46 @@
1
- # evaluation
2
 
3
- Place metrics scripts, run configs, and results here. Logs dir is used by `models.mlp.train` for training logs.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # evaluation/
2
 
3
+ Training logs and performance metrics.
4
+
5
+ ## 1. Contents
6
+
7
+ ```
8
+ logs/
9
+ ├── face_orientation_training_log.json # MLP (latest run)
10
+ ├── mlp_face_orientation_training_log.json # MLP (alternate)
11
+ └── xgboost_face_orientation_training_log.json # XGBoost
12
+ ```
13
+
14
+ ## 2. Log Format
15
+
16
+ Each JSON file records the full training history:
17
+
18
+ **MLP logs:**
19
+ ```json
20
+ {
21
+ "config": { "epochs": 30, "lr": 0.001, "batch_size": 32, ... },
22
+ "history": {
23
+ "train_loss": [0.287, 0.260, ...],
24
+ "val_loss": [0.256, 0.245, ...],
25
+ "train_acc": [0.889, 0.901, ...],
26
+ "val_acc": [0.905, 0.909, ...]
27
+ },
28
+ "test": { "accuracy": 0.929, "f1": 0.929, "roc_auc": 0.971 }
29
+ }
30
+ ```
31
+
32
+ **XGBoost logs:**
33
+ ```json
34
+ {
35
+ "config": { "n_estimators": 600, "max_depth": 8, "learning_rate": 0.149, ... },
36
+ "train_losses": [0.577, ...],
37
+ "val_losses": [0.576, ...],
38
+ "test": { "accuracy": 0.959, "f1": 0.959, "roc_auc": 0.991 }
39
+ }
40
+ ```
41
+
42
+ ## 3. Generated By
43
+
44
+ - `python -m models.mlp.train` β†’ writes MLP log
45
+ - `python -m models.xgboost.train` β†’ writes XGBoost log
46
+ - Notebooks in `notebooks/` also save logs here
index.html ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!doctype html>
<html lang="en">

<head>
  <meta charset="UTF-8" />
  <link rel="icon" type="image/svg+xml" href="/vite.svg" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>Focus Guard</title>
  <link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
</head>

<body>
  <!-- React mounts here; Vite serves/injects the bundle from src/main.jsx -->
  <div id="root"></div>
  <script type="module" src="/src/main.jsx"></script>
</body>

</html>
main.py ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import FileResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from typing import Optional, List, Any
7
+ import base64
8
+ import cv2
9
+ import numpy as np
10
+ import aiosqlite
11
+ import json
12
+ from datetime import datetime, timedelta
13
+ import math
14
+ import os
15
+ from pathlib import Path
16
+ from typing import Callable
17
+ import asyncio
18
+ import concurrent.futures
19
+ import threading
20
+
21
+ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
22
+ from av import VideoFrame
23
+
24
+ from mediapipe.tasks.python.vision import FaceLandmarksConnections
25
+ from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
26
+ from models.face_mesh import FaceMeshDetector
27
+
28
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================

# OpenCV drawing primitives take BGR color tuples.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
_CYAN = (255, 255, 0)
_GREEN = (0, 255, 0)
_MAGENTA = (255, 0, 255)
_ORANGE = (0, 165, 255)
_RED = (0, 0, 255)
_WHITE = (255, 255, 255)
_LIGHT_GREEN = (144, 238, 144)

# MediaPipe face-mesh connection lists flattened to (start, end) landmark index pairs.
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
# Landmark index groups for individual facial features (MediaPipe face-mesh topology).
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
# Six points per eye, marked on the overlay as EAR (eye aspect ratio) key points.
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
48
+
49
+
50
+ def _lm_px(lm, idx, w, h):
51
+ return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
52
+
53
+
54
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
    """Draw anti-aliased segments connecting consecutive landmark indices in place."""
    points = [_lm_px(lm, idx, w, h) for idx in indices]
    for start, end in zip(points, points[1:]):
        cv2.line(frame, start, end, color, thickness, cv2.LINE_AA)
57
+
58
+
59
def _draw_face_mesh(frame, lm, w, h):
    """Draw the full face-mesh overlay on `frame` in place.

    Renders: semi-transparent tessellation grid, contours, eyebrows, nose
    bridge, lips, eye outlines, EAR key points, iris circles, and gaze lines.

    Args:
        frame: BGR image (modified in place).
        lm: landmark array indexable as lm[idx, 0]/lm[idx, 1] with normalized coords.
        w, h: frame width/height in pixels, used to scale normalized coords.
    """
    # Tessellation (gray triangular grid, semi-transparent): drawn on a copy
    # and alpha-blended back so the grid does not overpower the video.
    overlay = frame.copy()
    for s, e in _TESSELATION_CONNS:
        cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
    # Contours
    for s, e in _CONTOUR_CONNS:
        cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
    # Eyebrows
    _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    # Nose
    _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
    # Lips
    _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
    _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
    # Eyes (closed polylines around each eye outline)
    left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
    right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
    # EAR key points
    for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
        for idx in indices:
            cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
    # Irises + gaze lines. Iris group convention: first point is the center,
    # the next four lie on the iris boundary (used to estimate the radius).
    for iris_idx, eye_inner, eye_outer in [
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    ]:
        iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
        center = iris_pts[0]
        if len(iris_pts) >= 5:
            radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
            radius = max(int(np.mean(radii)), 2)
            cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
            cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
        # Gaze line: iris offset from the eye center, extended 3x for visibility.
        eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
        eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
        dx, dy = center[0] - eye_cx, center[1] - eye_cy
        cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
102
+
103
+
104
def _draw_hud(frame, result, model_name):
    """Draw the status bar and metric overlay on `frame` in place.

    Layout matches live_demo.py: black top bar with FOCUSED/NOT FOCUSED and the
    model name, a detail line with confidence and component scores, head-pose
    angles in the top-right when available, and a YAWN indicator.

    Args:
        frame: BGR image (modified in place).
        result: detection dict; reads keys is_focused, mlp_prob/raw_score,
            mar, s_face, s_eye, yaw/pitch/roll, is_yawning.
        model_name: active model identifier, rendered uppercased.
    """
    h, w = frame.shape[:2]
    is_focused = result["is_focused"]
    status = "FOCUSED" if is_focused else "NOT FOCUSED"
    color = _GREEN if is_focused else _RED

    # Top bar
    cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
    cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
    cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)

    # Detail line: confidence falls back from mlp_prob to raw_score to 0.
    conf = result.get("mlp_prob", result.get("raw_score", 0.0))
    mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
    sf = result.get("s_face", 0)
    se = result.get("s_eye", 0)
    detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)

    # Head pose (top right)
    if result.get("yaw") is not None:
        cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
                    (w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)

    # Yawn indicator
    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
132
+
133
# Landmark indices used for face mesh drawing on client (union of all groups).
# Sending only these instead of all 478 saves ~60% of the landmarks payload.
# NOTE(review): the WebSocket handler below currently sends ALL landmarks;
# confirm whether this sparse set is still consumed by the client.
_MESH_INDICES = sorted(set(
    [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
    + [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
    + [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
    + [468,469,470,471,472, 473,474,475,476,477] # irises
    + [70,63,105,66,107,55,65,52,53,46] # left eyebrow
    + [300,293,334,296,336,285,295,282,283,276] # right eyebrow
    + [6,197,195,5,4,1,19,94,2] # nose bridge
    + [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
    + [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
    + [33,160,158,133,153,145] # left EAR key points
    + [362,385,387,263,373,380] # right EAR key points
))
# Build a lookup: original_index -> position in sparse array, so client can reconstruct.
_MESH_INDEX_SET = set(_MESH_INDICES)
150
+
151
# Initialize FastAPI app
app = FastAPI(title="Focus Guard API")

# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive; fine for a local demo, tighten for public deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
162
+
163
+ # Global variables
164
+ db_path = "focus_guard.db"
165
+ pcs = set()
166
+ _cached_model_name = "mlp" # in-memory cache, updated via /api/settings
167
+
168
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
    """Wait until the peer connection finishes gathering ICE candidates.

    Called before returning the SDP answer so all candidates are embedded in
    it (non-trickle ICE). Returns immediately if gathering already completed.
    """
    if pc.iceGatheringState == "complete":
        return
    done = asyncio.Event()

    @pc.on("icegatheringstatechange")
    def _on_state_change():
        if pc.iceGatheringState == "complete":
            done.set()

    await done.wait()
179
+
180
+ # ================ DATABASE MODELS ================
181
+
182
async def init_database():
    """Initialize SQLite database with required tables.

    Idempotent: uses CREATE TABLE IF NOT EXISTS and INSERT OR IGNORE, so it is
    safe to run on every startup. Creates focus_sessions, focus_events, and a
    singleton user_settings row (id fixed to 1).
    """
    async with aiosqlite.connect(db_path) as db:
        # FocusSessions table
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                start_time TIMESTAMP NOT NULL,
                end_time TIMESTAMP,
                duration_seconds INTEGER DEFAULT 0,
                focus_score REAL DEFAULT 0.0,
                total_frames INTEGER DEFAULT 0,
                focused_frames INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # FocusEvents table
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id INTEGER NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                is_focused BOOLEAN NOT NULL,
                confidence REAL NOT NULL,
                detection_data TEXT,
                FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
            )
        """)

        # UserSettings table (singleton row: id is constrained to 1)
        await db.execute("""
            CREATE TABLE IF NOT EXISTS user_settings (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                sensitivity INTEGER DEFAULT 6,
                notification_enabled BOOLEAN DEFAULT 1,
                notification_threshold INTEGER DEFAULT 30,
                frame_rate INTEGER DEFAULT 30,
                model_name TEXT DEFAULT 'mlp'
            )
        """)

        # Insert default settings if not exists
        await db.execute("""
            INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
            VALUES (1, 6, 1, 30, 30, 'mlp')
        """)

        await db.commit()
231
+
232
+ # ================ PYDANTIC MODELS ================
233
+
234
class SessionCreate(BaseModel):
    """Request body for creating a session (no fields required)."""
    pass
236
+
237
class SessionEnd(BaseModel):
    """Request body for ending a session."""
    session_id: int
239
+
240
class SettingsUpdate(BaseModel):
    """Partial settings update; any field left as None is unchanged."""
    sensitivity: Optional[int] = None
    notification_enabled: Optional[bool] = None
    notification_threshold: Optional[int] = None
    frame_rate: Optional[int] = None
    model_name: Optional[str] = None
246
+
247
class VideoTransformTrack(VideoStreamTrack):
    """aiortc video track that runs focus inference on incoming frames and
    returns an annotated copy (face mesh + HUD) to the peer.

    Detection results are logged against `session_id` and, when the peer's
    data channel is open, also pushed over it as JSON.
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        super().__init__()
        self.track = track                    # upstream (remote) video track
        self.session_id = session_id          # DB session to record events against
        self.get_channel = get_channel        # lazily resolves the data channel
        self.last_inference_time = 0
        self.min_inference_interval = 1 / 60  # cap inference at ~60 Hz
        self.last_frame = None                # last annotated frame, reused when skipping inference

    async def recv(self):
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame

        # Normalize size for inference/drawing
        img = cv2.resize(img, (640, 480))

        now = datetime.now().timestamp()
        do_infer = (now - self.last_inference_time) >= self.min_inference_interval

        if do_infer:
            self.last_inference_time = now

            # Fall back to 'mlp' if the selected model failed to load.
            model_name = _cached_model_name
            if model_name not in pipelines or pipelines.get(model_name) is None:
                model_name = 'mlp'
            active_pipeline = pipelines.get(model_name)

            if active_pipeline is not None:
                # Run CPU-bound inference off the event loop.
                loop = asyncio.get_running_loop()
                out = await loop.run_in_executor(
                    _inference_executor,
                    _process_frame_safe,
                    active_pipeline,
                    img,
                    model_name,
                )
                is_focused = out["is_focused"]
                confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
                metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}

                # Draw face mesh + HUD on the video frame
                h_f, w_f = img.shape[:2]
                lm = out.get("landmarks")
                if lm is not None:
                    _draw_face_mesh(img, lm, w_f, h_f)
                _draw_hud(img, out, model_name)
            else:
                is_focused = False
                confidence = 0.0
                metadata = {"model": model_name}
                cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
                cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)

            if self.session_id:
                await store_focus_event(self.session_id, is_focused, confidence, metadata)

            channel = self.get_channel()
            if channel and channel.readyState == "open":
                try:
                    # BUG FIX: the original referenced an undefined name
                    # `detections`, raising a NameError that was silently
                    # swallowed below — the data channel never received any
                    # messages. Send the detection metadata instead.
                    channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": metadata}))
                except Exception:
                    pass

            self.last_frame = img
        elif self.last_frame is not None:
            # Re-send the last annotated frame between inference ticks.
            img = self.last_frame

        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
321
+
322
+ # ================ DATABASE OPERATIONS ================
323
+
324
async def create_session():
    """Insert a new focus_sessions row stamped with the current time and
    return its autoincrement id."""
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (datetime.now().isoformat(),)
        )
        await db.commit()
        return cursor.lastrowid
332
+
333
async def end_session(session_id: int):
    """Finalize a session: record end_time, duration, and focus_score.

    focus_score is the fraction of frames classified as focused (0.0 when no
    frames were recorded). Returns a summary dict, or None if the session id
    does not exist.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,)
        )
        row = await cursor.fetchone()

        if not row:
            return None

        start_time_str, total_frames, focused_frames = row
        start_time = datetime.fromisoformat(start_time_str)
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        focus_score = focused_frames / total_frames if total_frames > 0 else 0.0

        await db.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (end_time.isoformat(), int(duration), focus_score, session_id))

        await db.commit()

        return {
            'session_id': session_id,
            'start_time': start_time_str,
            'end_time': end_time.isoformat(),
            'duration_seconds': int(duration),
            'focus_score': round(focus_score, 3),
            'total_frames': total_frames,
            'focused_frames': focused_frames
        }
367
+
368
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Persist one detection event and bump the session's frame counters.

    Used by the WebRTC path (one DB write per inferred frame); the WebSocket
    path batches writes through _EventBuffer instead.
    """
    async with aiosqlite.connect(db_path) as db:
        await db.execute("""
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
        """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))

        await db.execute("""
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
        """, (1 if is_focused else 0, session_id))
        await db.commit()
382
+
383
+
384
+ class _EventBuffer:
385
+ """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
386
+
387
+ def __init__(self, flush_interval: float = 2.0):
388
+ self._buf: list = []
389
+ self._lock = asyncio.Lock()
390
+ self._flush_interval = flush_interval
391
+ self._task: asyncio.Task | None = None
392
+ self._total_frames = 0
393
+ self._focused_frames = 0
394
+
395
+ def start(self):
396
+ if self._task is None:
397
+ self._task = asyncio.create_task(self._flush_loop())
398
+
399
+ async def stop(self):
400
+ if self._task:
401
+ self._task.cancel()
402
+ try:
403
+ await self._task
404
+ except asyncio.CancelledError:
405
+ pass
406
+ self._task = None
407
+ await self._flush()
408
+
409
+ def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
410
+ self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
411
+ self._total_frames += 1
412
+ if is_focused:
413
+ self._focused_frames += 1
414
+
415
+ async def _flush_loop(self):
416
+ while True:
417
+ await asyncio.sleep(self._flush_interval)
418
+ await self._flush()
419
+
420
+ async def _flush(self):
421
+ async with self._lock:
422
+ if not self._buf:
423
+ return
424
+ batch = self._buf[:]
425
+ total = self._total_frames
426
+ focused = self._focused_frames
427
+ self._buf.clear()
428
+ self._total_frames = 0
429
+ self._focused_frames = 0
430
+
431
+ if not batch:
432
+ return
433
+
434
+ session_id = batch[0][0]
435
+ try:
436
+ async with aiosqlite.connect(db_path) as db:
437
+ await db.executemany("""
438
+ INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
439
+ VALUES (?, ?, ?, ?, ?)
440
+ """, batch)
441
+ await db.execute("""
442
+ UPDATE focus_sessions
443
+ SET total_frames = total_frames + ?,
444
+ focused_frames = focused_frames + ?
445
+ WHERE id = ?
446
+ """, (total, focused, session_id))
447
+ await db.commit()
448
+ except Exception as e:
449
+ print(f"[DB] Flush error: {e}")
450
+
451
+ # ================ STARTUP/SHUTDOWN ================
452
+
453
# Loaded model pipelines; entries stay None when loading fails at startup.
pipelines = {
    "geometric": None,
    "mlp": None,
    "hybrid": None,
    "xgboost": None,
}

# Thread pool for CPU-bound inference so the event loop stays responsive.
_inference_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=4,
    thread_name_prefix="inference",
)
# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
# multiple frames are processed in parallel by the thread pool.
_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
468
+
469
+
470
def _process_frame_safe(pipeline, frame, model_name: str):
    """Run `pipeline.process_frame(frame)` under the model's lock.

    Executed on the inference thread pool; the per-model lock keeps two worker
    threads from mutating the same pipeline's shared state concurrently.
    """
    lock = _pipeline_locks[model_name]
    with lock:
        result = pipeline.process_frame(frame)
    return result
474
+
475
@app.on_event("startup")
async def startup_event():
    """Initialize the DB, restore the persisted model selection, and load pipelines.

    Each pipeline is loaded best-effort: a failure only disables that model
    (its entry stays None) rather than aborting startup.
    NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
    lifespan handlers — consider migrating.
    """
    global pipelines, _cached_model_name
    print(" Starting Focus Guard API...")
    await init_database()
    # Load cached model name from DB
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        if row:
            _cached_model_name = row[0]
    print("[OK] Database initialized")

    try:
        pipelines["geometric"] = FaceMeshPipeline()
        print("[OK] FaceMeshPipeline (geometric) loaded")
    except Exception as e:
        print(f"[WARN] FaceMeshPipeline unavailable: {e}")

    try:
        pipelines["mlp"] = MLPPipeline()
        print("[OK] MLPPipeline loaded")
    except Exception as e:
        print(f"[ERR] Failed to load MLPPipeline: {e}")

    try:
        pipelines["hybrid"] = HybridFocusPipeline()
        print("[OK] HybridFocusPipeline loaded")
    except Exception as e:
        print(f"[WARN] HybridFocusPipeline unavailable: {e}")

    try:
        pipelines["xgboost"] = XGBoostPipeline()
        print("[OK] XGBoostPipeline loaded")
    except Exception as e:
        print(f"[ERR] Failed to load XGBoostPipeline: {e}")
511
+
512
@app.on_event("shutdown")
async def shutdown_event():
    """Release the inference thread pool without waiting for in-flight work."""
    _inference_executor.shutdown(wait=False)
    print(" Shutting down Focus Guard API...")
516
+
517
+ # ================ WEBRTC SIGNALING ================
518
+
519
@app.post("/api/webrtc/offer")
async def webrtc_offer(offer: dict):
    """WebRTC signaling endpoint.

    Accepts the browser's SDP offer, creates a peer connection and a DB
    session, attaches the annotated VideoTransformTrack to the incoming video
    track, waits for full (non-trickle) ICE gathering, and returns the SDP
    answer plus the session id. The session is closed when the connection
    fails/closes/disconnects.
    """
    try:
        print(f"Received WebRTC offer")

        pc = RTCPeerConnection()
        pcs.add(pc)

        session_id = await create_session()
        print(f"Created session: {session_id}")

        # The data channel may arrive after the track; pass a getter so the
        # track always sees the latest value.
        channel_ref = {"channel": None}

        @pc.on("datachannel")
        def on_datachannel(channel):
            print(f"Data channel opened")
            channel_ref["channel"] = channel

        @pc.on("track")
        def on_track(track):
            print(f"Received track: {track.kind}")
            if track.kind == "video":
                local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
                pc.addTrack(local_track)
                print(f"Video track added")

                @track.on("ended")
                async def on_ended():
                    print(f"Track ended")

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            print(f"Connection state changed: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed", "disconnected"):
                try:
                    await end_session(session_id)
                except Exception as e:
                    print(f"⚠Error ending session: {e}")
                pcs.discard(pc)
                await pc.close()

        await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
        print(f"Remote description set")

        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        print(f"Answer created")

        await _wait_for_ice_gathering(pc)
        print(f"ICE gathering complete")

        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}

    except Exception as e:
        print(f"WebRTC offer error: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
577
+
578
+ # ================ WEBSOCKET ================
579
+
580
@app.websocket("/ws/video")
async def websocket_endpoint(websocket: WebSocket):
    """Main video-analysis WebSocket.

    Message protocol:
      - binary message               -> JPEG frame (fast path)
      - text {"type": "frame"}       -> base64 JPEG frame (legacy fallback)
      - text {"type": "start_session"} / {"type": "end_session"} -> session control

    Runs two concurrent loops: a receiver that keeps only the newest frame in
    a one-slot buffer, and a processor that runs inference on it — stale
    frames are dropped instead of queueing up behind slow inference.
    """
    await websocket.accept()
    session_id = None
    frame_count = 0
    running = True
    event_buffer = _EventBuffer(flush_interval=2.0)

    # Latest frame slot — only the most recent frame is kept, older ones are dropped.
    # Using a dict so nested functions can mutate without nonlocal issues.
    _slot = {"frame": None}
    _frame_ready = asyncio.Event()

    async def _receive_loop():
        """Receive messages as fast as possible. Binary = frame, text = control."""
        nonlocal session_id, running
        try:
            while running:
                msg = await websocket.receive()
                msg_type = msg.get("type", "")

                if msg_type == "websocket.disconnect":
                    running = False
                    _frame_ready.set()
                    return

                # Binary message → JPEG frame (fast path, no base64)
                raw_bytes = msg.get("bytes")
                if raw_bytes is not None and len(raw_bytes) > 0:
                    _slot["frame"] = raw_bytes
                    _frame_ready.set()
                    continue

                # Text message → JSON control command (or legacy base64 frame)
                text = msg.get("text")
                if not text:
                    continue
                data = json.loads(text)

                if data["type"] == "frame":
                    # Legacy base64 path (fallback)
                    _slot["frame"] = base64.b64decode(data["image"])
                    _frame_ready.set()

                elif data["type"] == "start_session":
                    session_id = await create_session()
                    event_buffer.start()
                    # Reset temporal state so a new session does not inherit
                    # smoothing from the previous one.
                    for p in pipelines.values():
                        if p is not None and hasattr(p, "reset_session"):
                            p.reset_session()
                    await websocket.send_json({"type": "session_started", "session_id": session_id})

                elif data["type"] == "end_session":
                    if session_id:
                        await event_buffer.stop()
                        summary = await end_session(session_id)
                        if summary:
                            await websocket.send_json({"type": "session_ended", "summary": summary})
                        session_id = None
        except WebSocketDisconnect:
            running = False
            _frame_ready.set()
        except Exception as e:
            print(f"[WS] receive error: {e}")
            running = False
            _frame_ready.set()

    async def _process_loop():
        """Process only the latest frame, dropping stale ones."""
        nonlocal frame_count, running
        loop = asyncio.get_event_loop()
        while running:
            await _frame_ready.wait()
            _frame_ready.clear()
            if not running:
                return

            # Grab latest frame and clear slot
            raw = _slot["frame"]
            _slot["frame"] = None
            if raw is None:
                continue

            try:
                nparr = np.frombuffer(raw, np.uint8)
                frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                if frame is None:
                    continue
                frame = cv2.resize(frame, (640, 480))

                # Fall back to 'mlp' if the selected model failed to load.
                model_name = _cached_model_name
                if model_name not in pipelines or pipelines.get(model_name) is None:
                    model_name = "mlp"
                active_pipeline = pipelines.get(model_name)

                landmarks_list = None
                if active_pipeline is not None:
                    # CPU-bound inference runs on the thread pool.
                    out = await loop.run_in_executor(
                        _inference_executor,
                        _process_frame_safe,
                        active_pipeline,
                        frame,
                        model_name,
                    )
                    is_focused = out["is_focused"]
                    confidence = out.get("mlp_prob", out.get("raw_score", 0.0))

                    lm = out.get("landmarks")
                    if lm is not None:
                        # Send all 478 landmarks as flat array for tessellation drawing
                        landmarks_list = [
                            [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                            for i in range(lm.shape[0])
                        ]

                    if session_id:
                        event_buffer.add(session_id, is_focused, confidence, {
                            "s_face": out.get("s_face", 0.0),
                            "s_eye": out.get("s_eye", 0.0),
                            "mar": out.get("mar", 0.0),
                            "model": model_name,
                        })
                else:
                    is_focused = False
                    confidence = 0.0

                resp = {
                    "type": "detection",
                    "focused": is_focused,
                    "confidence": round(confidence, 3),
                    "model": model_name,
                    "fc": frame_count,
                }
                if active_pipeline is not None:
                    # Send detailed metrics for HUD
                    if out.get("yaw") is not None:
                        resp["yaw"] = round(out["yaw"], 1)
                        resp["pitch"] = round(out["pitch"], 1)
                        resp["roll"] = round(out["roll"], 1)
                    if out.get("mar") is not None:
                        resp["mar"] = round(out["mar"], 3)
                    resp["sf"] = round(out.get("s_face", 0), 3)
                    resp["se"] = round(out.get("s_eye", 0), 3)
                if landmarks_list is not None:
                    resp["lm"] = landmarks_list
                await websocket.send_json(resp)
                frame_count += 1
            except Exception as e:
                print(f"[WS] process error: {e}")

    try:
        await asyncio.gather(_receive_loop(), _process_loop())
    except Exception:
        pass
    finally:
        # Always close the session even on abnormal disconnects.
        running = False
        if session_id:
            await event_buffer.stop()
            await end_session(session_id)
740
+ # ================ API ENDPOINTS ================
741
+
742
@app.post("/api/sessions/start")
async def api_start_session():
    """Create a new focus session and return its id."""
    session_id = await create_session()
    return {"session_id": session_id}
746
+
747
@app.post("/api/sessions/end")
async def api_end_session(data: SessionEnd):
    """Close the given session and return its summary; 404 if unknown."""
    summary = await end_session(data.session_id)
    if summary:
        return summary
    raise HTTPException(status_code=404, detail="Session not found")
752
+
753
@app.get("/api/sessions")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List focus sessions, newest first.

    Args:
        filter: "today" / "week" / "month" restrict by start_time;
            "all" returns only completed sessions (end_time set).
        limit: page size; pass -1 to return everything (used for export).
        offset: pagination offset (ignored when limit is -1).
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row

        # NEW: If importing/exporting all, remove limit if special flag or high limit
        # For simplicity: if limit is -1, return all
        limit_clause = "LIMIT ? OFFSET ?"
        params = []

        base_query = "SELECT * FROM focus_sessions"
        where_clause = ""

        if filter == "today":
            date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
            where_clause = " WHERE start_time >= ?"
            params.append(date_filter.isoformat())
        elif filter == "week":
            date_filter = datetime.now() - timedelta(days=7)
            where_clause = " WHERE start_time >= ?"
            params.append(date_filter.isoformat())
        elif filter == "month":
            date_filter = datetime.now() - timedelta(days=30)
            where_clause = " WHERE start_time >= ?"
            params.append(date_filter.isoformat())
        elif filter == "all":
            # Just ensure we only get completed sessions or all sessions
            where_clause = " WHERE end_time IS NOT NULL"

        query = f"{base_query}{where_clause} ORDER BY start_time DESC"

        # Handle Limit for Exports
        if limit == -1:
            # No limit clause for export
            pass
        else:
            query += f" {limit_clause}"
            params.extend([limit, offset])

        cursor = await db.execute(query, tuple(params))
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]
795
+
796
+ # --- NEW: Import Endpoint ---
797
# --- NEW: Import Endpoint ---
@app.post("/api/import")
async def import_sessions(sessions: List[dict]):
    """Bulk-insert exported session records.

    Returns {"status": "success", "count": N} or {"status": "error", ...}.
    Note: imported sessions get fresh autoincrement ids; their focus_events
    are not imported.
    """
    count = 0
    try:
        async with aiosqlite.connect(db_path) as db:
            for session in sessions:
                # Use .get() to handle potential missing fields from older versions or edits
                await db.execute("""
                    INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    session.get('start_time'),
                    session.get('end_time'),
                    session.get('duration_seconds', 0),
                    session.get('focus_score', 0.0),
                    session.get('total_frames', 0),
                    session.get('focused_frames', 0),
                    session.get('created_at', session.get('start_time'))
                ))
                count += 1
            await db.commit()
        return {"status": "success", "count": count}
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
822
+
823
+ # --- NEW: Clear History Endpoint ---
824
# --- NEW: Clear History Endpoint ---
@app.delete("/api/history")
async def clear_history():
    """Delete all focus events and sessions (irreversible)."""
    try:
        async with aiosqlite.connect(db_path) as db:
            # Delete events first (foreign key good practice)
            await db.execute("DELETE FROM focus_events")
            await db.execute("DELETE FROM focus_sessions")
            await db.commit()
        return {"status": "success", "message": "History cleared"}
    except Exception as e:
        return {"status": "error", "message": str(e)}
835
+
836
@app.get("/api/sessions/{session_id}")
async def get_session(session_id: int):
    """Return one session with its full chronological event list; 404 if unknown."""
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
        row = await cursor.fetchone()
        if not row: raise HTTPException(status_code=404, detail="Session not found")
        session = dict(row)
        cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
        events = [dict(r) for r in await cursor.fetchall()]
        session['events'] = events
        return session
848
+
849
@app.get("/api/settings")
async def get_settings():
    """Return the singleton user-settings row, or built-in defaults if absent."""
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
    if row is not None:
        return dict(row)
    # No settings row yet: mirror the defaults used elsewhere in the app
    return {
        'sensitivity': 6,
        'notification_enabled': True,
        'notification_threshold': 30,
        'frame_rate': 30,
        'model_name': 'mlp',
    }
857
+
858
@app.put("/api/settings")
async def update_settings(settings: SettingsUpdate):
    """Apply any provided settings fields, clamping numeric values into range.

    Bootstraps the singleton settings row on first use. Selecting a model is
    only honored when that model's pipeline is loaded, and it also refreshes
    the in-process cache so inference switches immediately.
    """
    global _cached_model_name

    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
        if await cursor.fetchone() is None:
            # Create the singleton row before applying a partial update
            await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
            await db.commit()

        columns = []
        values = []

        def _set(column, value):
            # Accumulate one SET clause + its bound parameter
            columns.append(f"{column} = ?")
            values.append(value)

        if settings.sensitivity is not None:
            _set("sensitivity", max(1, min(10, settings.sensitivity)))
        if settings.notification_enabled is not None:
            _set("notification_enabled", settings.notification_enabled)
        if settings.notification_threshold is not None:
            _set("notification_threshold", max(5, min(300, settings.notification_threshold)))
        if settings.frame_rate is not None:
            _set("frame_rate", max(5, min(60, settings.frame_rate)))
        if (settings.model_name is not None
                and settings.model_name in pipelines
                and pipelines[settings.model_name] is not None):
            _set("model_name", settings.model_name)
            _cached_model_name = settings.model_name

        if columns:
            await db.execute(
                f"UPDATE user_settings SET {', '.join(columns)} WHERE id = 1",
                values,
            )
            await db.commit()
        return {"status": "success", "updated": len(columns) > 0}
892
+
893
@app.get("/api/stats/summary")
async def get_stats_summary():
    """Aggregate stats over completed sessions plus a daily-usage streak.

    Improvement over the previous version: the three aggregate values
    (count, total duration, mean score) are fetched in ONE query instead of
    three separate round-trips over the same filtered table — identical
    results, fewer DB calls.

    The streak counts consecutive calendar days with at least one completed
    session, anchored at today.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            """
            SELECT COUNT(*),
                   COALESCE(SUM(duration_seconds), 0),
                   COALESCE(AVG(focus_score), 0.0)
            FROM focus_sessions
            WHERE end_time IS NOT NULL
            """
        )
        total_sessions, total_focus_time, avg_focus_score = await cursor.fetchone()

        cursor = await db.execute(
            "SELECT DISTINCT DATE(start_time) as session_date "
            "FROM focus_sessions WHERE end_time IS NOT NULL "
            "ORDER BY session_date DESC"
        )
        dates = [row[0] for row in await cursor.fetchall()]

    # Walk most-recent-first; entry i must land exactly i days before today.
    # NOTE(review): a streak that ended yesterday reports 0 — confirm intended.
    streak_days = 0
    today = datetime.now().date()
    for i, date_str in enumerate(dates):
        if datetime.fromisoformat(date_str).date() == today - timedelta(days=i):
            streak_days += 1
        else:
            break

    return {
        'total_sessions': total_sessions,
        'total_focus_time': int(total_focus_time),
        'avg_focus_score': round(avg_focus_score, 3),
        'streak_days': streak_days,
    }
919
+
920
@app.get("/api/models")
async def get_available_models():
    """Return list of loaded model names and which is currently active."""
    loaded = [name for name, pipeline in pipelines.items() if pipeline is not None]

    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()

    active = row[0] if row else "mlp"
    # If the stored choice is not actually loaded, fall back to the first one
    if loaded and active not in loaded:
        active = loaded[0]
    return {"available": loaded, "current": active}
931
+
932
@app.get("/api/mesh-topology")
async def get_mesh_topology():
    """Tessellation edge pairs for client-side face-mesh drawing.

    The client caches this response, so it is served as a static payload.
    """
    return {"tessellation": _TESSELATION_CONNS}
936
+
937
@app.get("/health")
async def health_check():
    """Liveness probe: which models loaded and whether the DB file exists."""
    loaded_models = [name for name, pipeline in pipelines.items() if pipeline is not None]
    return {
        "status": "healthy",
        "models_loaded": loaded_models,
        "database": os.path.exists(db_path),
    }
941
+
942
+ # ================ STATIC FILES (SPA SUPPORT) ================
943
+
944
# Locate the built frontend relative to this file so it works from any cwd
_STATIC_DIR = Path(__file__).resolve().parent / "static"
_ASSETS_DIR = _STATIC_DIR / "assets"

# Mount hashed JS/CSS bundles first so /assets/* bypasses the SPA catch-all
if _ASSETS_DIR.is_dir():
    app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
951
+
952
+ # 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
953
@app.get("/{full_path:path}")
async def serve_react_app(full_path: str, request: Request):
    """SPA catch-all: serve index.html for client routes.

    API/WebSocket paths and asset paths 404 here instead of receiving HTML,
    which would otherwise break module-script MIME types on the client.

    Fix: the previous version checked both `startswith("assets")` and
    `startswith("assets/")` — the second branch could never fire when the
    first was false, so the checks are folded into one tuple `startswith`.
    """
    if full_path.startswith(("api", "ws", "assets")):
        raise HTTPException(status_code=404, detail="Not Found")

    index_path = _STATIC_DIR / "index.html"
    if index_path.is_file():
        return FileResponse(str(index_path))
    return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
models/README.md CHANGED
@@ -1,10 +1,53 @@
1
- # models
2
 
3
- - **cnn/eye_attention/** β€” YOLO open/closed eye classifier, crop helper, train stub
4
- - **mlp/** β€” PyTorch MLP on feature vectors (face_orientation / eye_behaviour); checkpoints under `mlp/face_orientation_model/`, `mlp/eye_behaviour_model/`
5
- - **geometric/face_orientation/** β€” head pose (solvePnP). **geometric/eye_behaviour/** β€” EAR, gaze, MAR
6
- - **pretrained/face_mesh/** β€” MediaPipe face landmarks (no training)
7
- - **attention/** β€” webcam feature collection (17-d), stubs for train/classifier/fusion
8
- - **prepare_dataset.py** β€” loads from `data_preparation/processed/` or synthetic; used by `mlp/train.py`
9
 
10
- Run legacy MLP training: `python -m models.mlp.train`. The sklearn MLP used in the live demo is trained in `data_preparation/MLP/train_mlp.ipynb` and saved under `../MLP/models/`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/
2
 
3
+ Feature extraction modules and model training scripts.
 
 
 
 
 
4
 
5
+ ## 1. Feature Extraction
6
+
7
+ Root-level modules form the real-time inference pipeline:
8
+
9
+ | Module | Input | Output |
10
+ |--------|-------|--------|
11
+ | `face_mesh.py` | BGR frame | 478 MediaPipe landmarks |
12
+ | `head_pose.py` | Landmarks, frame size | yaw, pitch, roll, face/eye score, gaze offset, head deviation |
13
+ | `eye_scorer.py` | Landmarks | EAR (left/right/avg), gaze ratio (h/v), MAR |
14
+ | `eye_crop.py` | Landmarks, frame | Cropped eye region images |
15
+ | `eye_classifier.py` | Eye crops or landmarks | Eye open/closed prediction (geometric fallback) |
16
+ | `collect_features.py` | BGR frame | 17-d feature vector + temporal features (PERCLOS, blink rate, etc.) |
17
+
18
+ ## 2. Training Scripts
19
+
20
+ | Folder | Model | Command |
21
+ |--------|-------|---------|
22
+ | `mlp/` | PyTorch MLP (64β†’32, 2-class) | `python -m models.mlp.train` |
23
+ | `xgboost/` | XGBoost (600 trees, depth 8) | `python -m models.xgboost.train` |
24
+
25
+ ### mlp/
26
+
27
+ - `train.py` β€” training loop with early stopping, ClearML opt-in
28
+ - `sweep.py` β€” hyperparameter search (Optuna: lr, batch_size)
29
+ - `eval_accuracy.py` β€” load checkpoint and print test metrics
30
+ - Saves to **`checkpoints/mlp_best.pt`**
31
+
32
+ ### xgboost/
33
+
34
+ - `train.py` β€” training with eval-set logging
35
+ - `sweep.py` / `sweep_local.py` β€” hyperparameter search (Optuna + ClearML)
36
+ - `eval_accuracy.py` β€” load checkpoint and print test metrics
37
+ - Saves to **`checkpoints/xgboost_face_orientation_best.json`**
38
+
39
+ ## 3. Data Loading
40
+
41
+ All training scripts import from `data_preparation.prepare_dataset`:
42
+
43
+ ```python
44
+ from data_preparation.prepare_dataset import get_numpy_splits # XGBoost
45
+ from data_preparation.prepare_dataset import get_dataloaders # MLP (PyTorch)
46
+ ```
47
+
48
+ ## 4. Results
49
+
50
+ | Model | Test Accuracy | F1 | ROC-AUC |
51
+ |-------|--------------|-----|---------|
52
+ | XGBoost | 95.87% | 0.959 | 0.991 |
53
+ | MLP | 92.92% | 0.929 | 0.971 |
models/{attention/__init__.py β†’ __init__.py} RENAMED
File without changes
models/attention/classifier.py DELETED
File without changes
models/attention/fusion.py DELETED
File without changes
models/attention/train.py DELETED
File without changes
models/cnn/notebooks/EyeCNN.ipynb ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "source": [
22
+ "import os\n",
23
+ "import torch\n",
24
+ "import torch.nn as nn\n",
25
+ "import torch.optim as optim\n",
26
+ "from torch.utils.data import DataLoader\n",
27
+ "from torchvision import datasets, transforms\n",
28
+ "\n",
29
+ "from google.colab import drive\n",
30
+ "drive.mount('/content/drive')\n",
31
+ "!cp -r /content/drive/MyDrive/Dataset_clean /content/\n",
32
+ "\n",
33
+ "#Verify structure\n",
34
+ "for split in ['train', 'val', 'test']:\n",
35
+ " path = f'/content/Dataset_clean/{split}'\n",
36
+ " classes = os.listdir(path)\n",
37
+ " total = sum(len(os.listdir(os.path.join(path, c))) for c in classes)\n",
38
+ " print(f'{split}: {total} images | classes: {classes}')"
39
+ ],
40
+ "metadata": {
41
+ "colab": {
42
+ "base_uri": "https://localhost:8080/"
43
+ },
44
+ "id": "sE1F3em-V5go",
45
+ "outputId": "2c73a9a6-a198-468c-a2cc-253b2de7cc3f"
46
+ },
47
+ "execution_count": null,
48
+ "outputs": [
49
+ {
50
+ "output_type": "stream",
51
+ "name": "stdout",
52
+ "text": [
53
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
54
+ ]
55
+ }
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {
62
+ "id": "nG2bh66rQ56G"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "class EyeCNN(nn.Module):\n",
67
+ " def __init__(self, num_classes=2):\n",
68
+ " super(EyeCNN, self).__init__()\n",
69
+ " self.conv_layers = nn.Sequential(\n",
70
+ " nn.Conv2d(3, 32, 3, 1, 1),\n",
71
+ " nn.BatchNorm2d(32),\n",
72
+ " nn.ReLU(),\n",
73
+ " nn.MaxPool2d(2, 2),\n",
74
+ "\n",
75
+ " nn.Conv2d(32, 64, 3, 1, 1),\n",
76
+ " nn.BatchNorm2d(64),\n",
77
+ " nn.ReLU(),\n",
78
+ " nn.MaxPool2d(2, 2),\n",
79
+ "\n",
80
+ " nn.Conv2d(64, 128, 3, 1, 1),\n",
81
+ " nn.BatchNorm2d(128),\n",
82
+ " nn.ReLU(),\n",
83
+ " nn.MaxPool2d(2, 2),\n",
84
+ "\n",
85
+ " nn.Conv2d(128, 256, 3, 1, 1),\n",
86
+ " nn.BatchNorm2d(256),\n",
87
+ " nn.ReLU(),\n",
88
+ " nn.MaxPool2d(2, 2)\n",
89
+ " )\n",
90
+ "\n",
91
+ " self.fc_layers = nn.Sequential(\n",
92
+ " nn.AdaptiveAvgPool2d((1, 1)),\n",
93
+ " nn.Flatten(),\n",
94
+ " nn.Linear(256, 512),\n",
95
+ " nn.ReLU(),\n",
96
+ " nn.Dropout(0.35),\n",
97
+ " nn.Linear(512, num_classes)\n",
98
+ " )\n",
99
+ "\n",
100
+ " def forward(self, x):\n",
101
+ " x = self.conv_layers(x)\n",
102
+ " x = self.fc_layers(x)\n",
103
+ " return x"
104
+ ]
105
+ }
106
+ ]
107
+ }
models/cnn/notebooks/EyeCNN_Train_Evaluate_new.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/cnn/notebooks/EyeCNN_Training_Evaluate.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/cnn/notebooks/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # GAP Large Project
models/{attention/collect_features.py β†’ collect_features.py} RENAMED
@@ -1,4 +1,3 @@
1
- # Usage: python -m models.attention.collect_features [--name alice] [--duration 600]
2
 
3
  import argparse
4
  import collections
@@ -10,13 +9,13 @@ import time
10
  import cv2
11
  import numpy as np
12
 
13
- _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
  if _PROJECT_ROOT not in sys.path:
15
  sys.path.insert(0, _PROJECT_ROOT)
16
 
17
- from models.pretrained.face_mesh.face_mesh import FaceMeshDetector
18
- from models.geometric.face_orientation.head_pose import HeadPoseEstimator
19
- from models.geometric.eye_behaviour.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
20
 
21
  FONT = cv2.FONT_HERSHEY_SIMPLEX
22
  GREEN = (0, 255, 0)
@@ -38,7 +37,7 @@ assert NUM_FEATURES == 17
38
 
39
  class TemporalTracker:
40
  EAR_BLINK_THRESH = 0.21
41
- MAR_YAWN_THRESH = 0.04
42
  PERCLOS_WINDOW = 60
43
  BLINK_WINDOW_SEC = 30.0
44
 
@@ -86,25 +85,35 @@ class TemporalTracker:
86
  return perclos, blink_rate, closure_dur, yawn_dur
87
 
88
 
89
- def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal):
90
- from models.geometric.eye_behaviour.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
 
91
 
92
- ear_left = compute_ear(landmarks, _LEFT_EYE_EAR)
93
- ear_right = compute_ear(landmarks, _RIGHT_EYE_EAR)
 
 
94
  ear_avg = (ear_left + ear_right) / 2.0
95
- h_gaze, v_gaze = compute_gaze_ratio(landmarks)
96
- mar = compute_mar(landmarks)
97
 
98
- angles = head_pose.estimate(landmarks, w, h)
 
 
 
 
 
 
 
 
 
99
  yaw = angles[0] if angles else 0.0
100
  pitch = angles[1] if angles else 0.0
101
  roll = angles[2] if angles else 0.0
102
 
103
- s_face = head_pose.score(landmarks, w, h)
104
- s_eye = eye_scorer.score(landmarks)
105
 
106
  gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
107
- head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)
108
 
109
  perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
110
 
@@ -181,7 +190,7 @@ def main():
181
  parser.add_argument("--duration", type=int, default=600,
182
  help="Max recording time (seconds, default 10 min)")
183
  parser.add_argument("--output-dir", type=str,
184
- default=os.path.join(_PROJECT_ROOT, "data_preparation", "collected"),
185
  help="Where to save .npz files")
186
  args = parser.parse_args()
187
 
@@ -238,13 +247,11 @@ def main():
238
  landmarks = detector.process(frame)
239
  face_ok = landmarks is not None
240
 
241
- # record if labeling + face visible
242
  if face_ok and label is not None:
243
  vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
244
  features_list.append(vec)
245
  labels_list.append(label)
246
 
247
- # count transitions
248
  if prev_label is not None and label != prev_label:
249
  transitions += 1
250
  prev_label = label
 
 
1
 
2
  import argparse
3
  import collections
 
9
  import cv2
10
  import numpy as np
11
 
12
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
  if _PROJECT_ROOT not in sys.path:
14
  sys.path.insert(0, _PROJECT_ROOT)
15
 
16
+ from models.face_mesh import FaceMeshDetector
17
+ from models.head_pose import HeadPoseEstimator
18
+ from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
19
 
20
  FONT = cv2.FONT_HERSHEY_SIMPLEX
21
  GREEN = (0, 255, 0)
 
37
 
38
  class TemporalTracker:
39
  EAR_BLINK_THRESH = 0.21
40
+ MAR_YAWN_THRESH = 0.55
41
  PERCLOS_WINDOW = 60
42
  BLINK_WINDOW_SEC = 30.0
43
 
 
85
  return perclos, blink_rate, closure_dur, yawn_dur
86
 
87
 
88
+ def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
89
+ *, _pre=None):
90
+ from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
91
 
92
+ p = _pre or {}
93
+
94
+ ear_left = p.get("ear_left", compute_ear(landmarks, _LEFT_EYE_EAR))
95
+ ear_right = p.get("ear_right", compute_ear(landmarks, _RIGHT_EYE_EAR))
96
  ear_avg = (ear_left + ear_right) / 2.0
 
 
97
 
98
+ if "h_gaze" in p and "v_gaze" in p:
99
+ h_gaze, v_gaze = p["h_gaze"], p["v_gaze"]
100
+ else:
101
+ h_gaze, v_gaze = compute_gaze_ratio(landmarks)
102
+
103
+ mar = p.get("mar", compute_mar(landmarks))
104
+
105
+ angles = p.get("angles")
106
+ if angles is None:
107
+ angles = head_pose.estimate(landmarks, w, h)
108
  yaw = angles[0] if angles else 0.0
109
  pitch = angles[1] if angles else 0.0
110
  roll = angles[2] if angles else 0.0
111
 
112
+ s_face = p.get("s_face", head_pose.score(landmarks, w, h))
113
+ s_eye = p.get("s_eye", eye_scorer.score(landmarks))
114
 
115
  gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
116
+ head_deviation = math.sqrt(yaw ** 2 + pitch ** 2) # cleaned downstream
117
 
118
  perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)
119
 
 
190
  parser.add_argument("--duration", type=int, default=600,
191
  help="Max recording time (seconds, default 10 min)")
192
  parser.add_argument("--output-dir", type=str,
193
+ default=os.path.join(_PROJECT_ROOT, "data", "collected_data"),
194
  help="Where to save .npz files")
195
  args = parser.parse_args()
196
 
 
247
  landmarks = detector.process(frame)
248
  face_ok = landmarks is not None
249
 
 
250
  if face_ok and label is not None:
251
  vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
252
  features_list.append(vec)
253
  labels_list.append(label)
254
 
 
255
  if prev_label is not None and label != prev_label:
256
  transitions += 1
257
  prev_label = label
models/eye_classifier.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from abc import ABC, abstractmethod

import numpy as np


class EyeClassifier(ABC):
    """Interface for eye-state classifiers operating on BGR eye crops."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for this backend (e.g. 'yolo', 'geometric')."""

    @abstractmethod
    def predict_score(self, crops_bgr: "list[np.ndarray]") -> float:
        """Return an attentiveness score averaged over the crops."""


class EyeClassifierGeometric(EyeClassifier):
    """Placeholder name kept private; see alias below."""


class GeometricOnlyClassifier(EyeClassifier):
    """No-op backend: defers entirely to the geometric (EAR-based) scorer."""

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: "list[np.ndarray]") -> float:
        # Always report "fully open"; geometric pipeline does the real scoring
        return 1.0


class YOLOv11Classifier(EyeClassifier):
    """YOLO classification model scoring eye crops as open/attentive."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        # Lazy import: ultralytics is only needed for this backend
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        # Find the class index meaning "attentive"/"open"; if the model's
        # class names don't match, fall back to the highest index.
        class_names = self._model.names
        self._attentive_idx = next(
            (idx for idx, label in class_names.items() if label in ("open", "attentive")),
            None,
        )
        if self._attentive_idx is None:
            self._attentive_idx = max(class_names.keys())
        print(f"[YOLO] Classes: {class_names}, attentive_idx={self._attentive_idx}")

    @property
    def name(self) -> str:
        return "yolo"

    def predict_score(self, crops_bgr: "list[np.ndarray]") -> float:
        if not crops_bgr:
            return 1.0
        predictions = self._model.predict(crops_bgr, device=self._device, verbose=False)
        scores = [float(pred.probs.data[self._attentive_idx]) for pred in predictions]
        return sum(scores) / len(scores) if scores else 1.0


def load_eye_classifier(
    path: "str | None" = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build an EyeClassifier; falls back to geometric when no path is given.

    Raises ImportError (after logging a hint) when the YOLO backend is
    requested but ultralytics is not installed.
    """
    if path is None or backend == "geometric":
        return GeometricOnlyClassifier()

    try:
        return YOLOv11Classifier(path, device=device)
    except ImportError:
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise
models/eye_crop.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from models.face_mesh import FaceMeshDetector
5
+
6
+ LEFT_EYE_CONTOUR = FaceMeshDetector.LEFT_EYE_INDICES
7
+ RIGHT_EYE_CONTOUR = FaceMeshDetector.RIGHT_EYE_INDICES
8
+
9
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
+ IMAGENET_STD = (0.229, 0.224, 0.225)
11
+
12
+ CROP_SIZE = 96
13
+
14
+
15
+ def _bbox_from_landmarks(
16
+ landmarks: np.ndarray,
17
+ indices: list[int],
18
+ frame_w: int,
19
+ frame_h: int,
20
+ expand: float = 0.4,
21
+ ) -> tuple[int, int, int, int]:
22
+ pts = landmarks[indices, :2]
23
+ px = pts[:, 0] * frame_w
24
+ py = pts[:, 1] * frame_h
25
+
26
+ x_min, x_max = px.min(), px.max()
27
+ y_min, y_max = py.min(), py.max()
28
+ w = x_max - x_min
29
+ h = y_max - y_min
30
+ cx = (x_min + x_max) / 2
31
+ cy = (y_min + y_max) / 2
32
+
33
+ size = max(w, h) * (1 + expand)
34
+ half = size / 2
35
+
36
+ x1 = int(max(cx - half, 0))
37
+ y1 = int(max(cy - half, 0))
38
+ x2 = int(min(cx + half, frame_w))
39
+ y2 = int(min(cy + half, frame_h))
40
+
41
+ return x1, y1, x2, y2
42
+
43
+
44
def extract_eye_crops(
    frame: np.ndarray,
    landmarks: np.ndarray,
    expand: float = 0.4,
    crop_size: int = CROP_SIZE,
) -> "tuple[np.ndarray, np.ndarray, tuple, tuple]":
    """Cut fixed-size left/right eye patches out of a BGR frame.

    Returns (left_crop, right_crop, left_bbox, right_bbox); each crop is a
    crop_size x crop_size BGR image. A degenerate (empty) box yields a black
    placeholder so downstream batching never sees a zero-sized array.
    """
    frame_h, frame_w = frame.shape[:2]

    def _crop_one(contour):
        # One eye: bbox from its contour landmarks, then slice + resize
        bbox = _bbox_from_landmarks(landmarks, contour, frame_w, frame_h, expand)
        x1, y1, x2, y2 = bbox
        patch = frame[y1:y2, x1:x2]
        if patch.size == 0:
            # Clamping can collapse the box (e.g. eye outside the frame)
            patch = np.zeros((crop_size, crop_size, 3), dtype=np.uint8)
        else:
            patch = cv2.resize(patch, (crop_size, crop_size), interpolation=cv2.INTER_AREA)
        return patch, bbox

    left_crop, left_bbox = _crop_one(LEFT_EYE_CONTOUR)
    right_crop, right_bbox = _crop_one(RIGHT_EYE_CONTOUR)
    return left_crop, right_crop, left_bbox, right_bbox
69
+
70
+
71
def crop_to_tensor(crop_bgr: np.ndarray):
    """Convert a BGR uint8 crop into a normalized CHW float tensor.

    Converts BGR->RGB, scales to [0, 1], applies ImageNet per-channel
    mean/std, and returns a (3, H, W) torch tensor. Torch is imported
    lazily so this module stays usable without it.
    """
    import torch

    rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32)
    std = np.asarray(IMAGENET_STD, dtype=np.float32)
    normalized = (rgb - mean) / std
    return torch.from_numpy(normalized.transpose(2, 0, 1))
models/{geometric/eye_behaviour/eye_scorer.py β†’ eye_scorer.py} RENAMED
@@ -95,7 +95,6 @@ def compute_gaze_ratio(landmarks: np.ndarray) -> tuple[float, float]:
95
 
96
 
97
  def compute_mar(landmarks: np.ndarray) -> float:
98
- # Mouth aspect ratio: high = mouth open (yawning / sleepy)
99
  top = landmarks[_MOUTH_TOP, :2]
100
  bottom = landmarks[_MOUTH_BOTTOM, :2]
101
  left = landmarks[_MOUTH_LEFT, :2]
@@ -140,7 +139,10 @@ class EyeBehaviourScorer:
140
  return 0.5 * (1.0 + math.cos(math.pi * t))
141
 
142
  def score(self, landmarks: np.ndarray) -> float:
143
- ear = compute_avg_ear(landmarks)
 
 
 
144
  ear_s = self._ear_score(ear)
145
  if ear_s < 0.3:
146
  return ear_s
@@ -149,7 +151,9 @@ class EyeBehaviourScorer:
149
  return ear_s * gaze_s
150
 
151
  def detailed_score(self, landmarks: np.ndarray) -> dict:
152
- ear = compute_avg_ear(landmarks)
 
 
153
  ear_s = self._ear_score(ear)
154
  h_ratio, v_ratio = compute_gaze_ratio(landmarks)
155
  gaze_s = self._gaze_score(h_ratio, v_ratio)
 
95
 
96
 
97
  def compute_mar(landmarks: np.ndarray) -> float:
 
98
  top = landmarks[_MOUTH_TOP, :2]
99
  bottom = landmarks[_MOUTH_BOTTOM, :2]
100
  left = landmarks[_MOUTH_LEFT, :2]
 
139
  return 0.5 * (1.0 + math.cos(math.pi * t))
140
 
141
  def score(self, landmarks: np.ndarray) -> float:
142
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
143
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
144
+ # Use minimum EAR so closing ONE eye is enough to drop the score
145
+ ear = min(left_ear, right_ear)
146
  ear_s = self._ear_score(ear)
147
  if ear_s < 0.3:
148
  return ear_s
 
151
  return ear_s * gaze_s
152
 
153
  def detailed_score(self, landmarks: np.ndarray) -> dict:
154
+ left_ear = compute_ear(landmarks, _LEFT_EYE_EAR)
155
+ right_ear = compute_ear(landmarks, _RIGHT_EYE_EAR)
156
+ ear = min(left_ear, right_ear)
157
  ear_s = self._ear_score(ear)
158
  h_ratio, v_ratio = compute_gaze_ratio(landmarks)
159
  gaze_s = self._gaze_score(h_ratio, v_ratio)
models/{pretrained/face_mesh/face_mesh.py β†’ face_mesh.py} RENAMED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from pathlib import Path
3
  from urllib.request import urlretrieve
4
 
@@ -51,14 +52,16 @@ class FaceMeshDetector:
51
  running_mode=RunningMode.VIDEO,
52
  )
53
  self._landmarker = FaceLandmarker.create_from_options(options)
54
- self._frame_ts = 0 # ms, for video API
 
55
 
56
  def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
57
  # BGR in -> (478,3) norm x,y,z or None
58
  rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
59
  mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
60
- self._frame_ts += 33 # ~30fps
61
- result = self._landmarker.detect_for_video(mp_image, self._frame_ts)
 
62
 
63
  if not result.face_landmarks:
64
  return None
 
1
  import os
2
+ import time
3
  from pathlib import Path
4
  from urllib.request import urlretrieve
5
 
 
52
  running_mode=RunningMode.VIDEO,
53
  )
54
  self._landmarker = FaceLandmarker.create_from_options(options)
55
+ self._t0 = time.monotonic()
56
+ self._last_ts = 0
57
 
58
  def process(self, bgr_frame: np.ndarray) -> np.ndarray | None:
59
  # BGR in -> (478,3) norm x,y,z or None
60
  rgb = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
61
  mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)
62
+ ts = max(int((time.monotonic() - self._t0) * 1000), self._last_ts + 1)
63
+ self._last_ts = ts
64
+ result = self._landmarker.detect_for_video(mp_image, ts)
65
 
66
  if not result.face_landmarks:
67
  return None
models/geometric/eye_behaviour/__init__.py DELETED
File without changes
models/geometric/face_orientation/__init__.py DELETED
@@ -1 +0,0 @@
1
-
 
 
models/{geometric/face_orientation/head_pose.py β†’ head_pose.py} RENAMED
@@ -25,6 +25,8 @@ class HeadPoseEstimator:
25
  self._camera_matrix = None
26
  self._frame_size = None
27
  self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
 
 
28
 
29
  def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
30
  if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
@@ -39,6 +41,10 @@ class HeadPoseEstimator:
39
  return self._camera_matrix
40
 
41
  def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
 
 
 
 
42
  image_points = np.array(
43
  [
44
  [landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
@@ -54,7 +60,10 @@ class HeadPoseEstimator:
54
  self._dist_coeffs,
55
  flags=cv2.SOLVEPNP_ITERATIVE,
56
  )
57
- return success, rvec, tvec, image_points
 
 
 
58
 
59
  def estimate(
60
  self, landmarks: np.ndarray, frame_w: int, frame_h: int
 
25
  self._camera_matrix = None
26
  self._frame_size = None
27
  self._dist_coeffs = np.zeros((4, 1), dtype=np.float64)
28
+ self._cache_key = None
29
+ self._cache_result = None
30
 
31
  def _get_camera_matrix(self, frame_w: int, frame_h: int) -> np.ndarray:
32
  if self._camera_matrix is not None and self._frame_size == (frame_w, frame_h):
 
41
  return self._camera_matrix
42
 
43
  def _solve(self, landmarks: np.ndarray, frame_w: int, frame_h: int):
44
+ key = (landmarks.data.tobytes(), frame_w, frame_h)
45
+ if self._cache_key == key:
46
+ return self._cache_result
47
+
48
  image_points = np.array(
49
  [
50
  [landmarks[i, 0] * frame_w, landmarks[i, 1] * frame_h]
 
60
  self._dist_coeffs,
61
  flags=cv2.SOLVEPNP_ITERATIVE,
62
  )
63
+ result = (success, rvec, tvec, image_points)
64
+ self._cache_key = key
65
+ self._cache_result = result
66
+ return result
67
 
68
  def estimate(
69
  self, landmarks: np.ndarray, frame_w: int, frame_h: int