pull from hugging face integration test

#1
This view is limited to 50 files because it contains too many changes. See the raw diff here.
Files changed (50) hide show
  1. .gitignore +0 -41
  2. Dockerfile +0 -34
  3. README.md +3 -87
  4. api/history +0 -0
  5. api/import +0 -0
  6. api/sessions +0 -0
  7. app.py +0 -1
  8. checkpoints/hybrid_focus_config.json +0 -10
  9. checkpoints/meta_best.npz +0 -3
  10. checkpoints/mlp_best.pt +0 -3
  11. checkpoints/model_best.joblib +0 -3
  12. checkpoints/scaler_best.joblib +0 -3
  13. checkpoints/xgboost_face_orientation_best.json +0 -0
  14. data/CNN/eye_crops/val/open/.gitkeep +0 -1
  15. data/README.md +0 -47
  16. data/collected_Abdelrahman/abdelrahman_20260306_023035.npz +0 -3
  17. data/collected_Jarek/Jarek_20260225_012931.npz +0 -3
  18. data/collected_Junhao/Junhao_20260303_113554.npz +0 -3
  19. data/collected_Kexin/kexin2_20260305_180229.npz +0 -3
  20. data/collected_Kexin/kexin_20260224_151043.npz +0 -3
  21. data/collected_Langyuan/Langyuan_20260303_153145.npz +0 -3
  22. data/collected_Mohamed/session_20260224_010131.npz +0 -3
  23. data/collected_Yingtao/Yingtao_20260306_023937.npz +0 -3
  24. data/collected_ayten/ayten_session_1.npz +0 -3
  25. data/collected_saba/saba_20260306_230710.npz +0 -3
  26. data_preparation/README.md +0 -75
  27. data_preparation/__init__.py +0 -0
  28. data_preparation/data_exploration.ipynb +0 -0
  29. data_preparation/prepare_dataset.py +0 -232
  30. docker-compose.yml +0 -5
  31. download_l2cs_weights.py +0 -37
  32. eslint.config.js +0 -29
  33. evaluation/README.md +0 -46
  34. index.html +0 -17
  35. main.py +0 -1210
  36. models/L2CS-Net/.gitignore +0 -140
  37. models/L2CS-Net/LICENSE +0 -21
  38. models/L2CS-Net/README.md +0 -148
  39. models/L2CS-Net/demo.py +0 -87
  40. models/L2CS-Net/l2cs/__init__.py +0 -21
  41. models/L2CS-Net/l2cs/datasets.py +0 -157
  42. models/L2CS-Net/l2cs/model.py +0 -73
  43. models/L2CS-Net/l2cs/pipeline.py +0 -133
  44. models/L2CS-Net/l2cs/results.py +0 -11
  45. models/L2CS-Net/l2cs/utils.py +0 -145
  46. models/L2CS-Net/l2cs/vis.py +0 -64
  47. models/L2CS-Net/leave_one_out_eval.py +0 -54
  48. models/L2CS-Net/models/L2CSNet_gaze360.pkl +0 -3
  49. models/L2CS-Net/models/README.md +0 -1
  50. models/L2CS-Net/pyproject.toml +0 -44
.gitignore DELETED
@@ -1,41 +0,0 @@
1
- # Logs
2
- logs
3
- *.log
4
- npm-debug.log*
5
- yarn-debug.log*
6
- yarn-error.log*
7
- pnpm-debug.log*
8
- lerna-debug.log*
9
-
10
- node_modules/
11
- dist/
12
- dist-ssr/
13
- *.local
14
-
15
- # Editor directories and files
16
- .vscode/
17
- .idea/
18
- .DS_Store
19
- *.suo
20
- *.ntvs*
21
- *.njsproj
22
- *.sln
23
- *.sw?
24
- *.py[cod]
25
- *$py.class
26
- *.so
27
- .Python
28
- venv/
29
- .venv/
30
- env/
31
- .env
32
- *.egg-info/
33
- .eggs/
34
- build/
35
- Thumbs.db
36
-
37
- # Project specific
38
- focus_guard.db
39
- static/
40
- __pycache__/
41
- docs/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Dockerfile DELETED
@@ -1,34 +0,0 @@
1
- FROM python:3.10-slim
2
-
3
- RUN useradd -m -u 1000 user
4
- ENV HOME=/home/user PATH=/home/user/.local/bin:$PATH
5
-
6
- ENV PYTHONUNBUFFERED=1
7
-
8
- WORKDIR /app
9
-
10
- RUN apt-get update && apt-get install -y --no-install-recommends \
11
- libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 \
12
- ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev \
13
- libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev \
14
- build-essential nodejs npm git \
15
- && rm -rf /var/lib/apt/lists/*
16
-
17
- RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
18
-
19
- COPY requirements.txt ./
20
- RUN pip install --no-cache-dir -r requirements.txt
21
-
22
- COPY . .
23
-
24
- RUN npm install && npm run build && mkdir -p /app/static && cp -R dist/* /app/static/
25
-
26
- ENV FOCUSGUARD_CACHE_DIR=/app/.cache/focusguard
27
- RUN python -c "from models.face_mesh import _ensure_model; _ensure_model()"
28
-
29
- RUN mkdir -p /app/data && chown -R user:user /app
30
-
31
- USER user
32
- EXPOSE 7860
33
-
34
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860", "--log-level", "debug"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,94 +1,10 @@
1
  ---
2
- title: FocusGuard
 
3
  colorFrom: indigo
4
  colorTo: purple
5
  sdk: docker
6
  pinned: false
7
  ---
8
 
9
- # FocusGuard - Real-Time Focus Detection
10
-
11
- A web app that monitors whether you're focused on your screen using your webcam. Combines head pose estimation, eye behaviour analysis, and deep learning gaze tracking to detect attention in real time.
12
-
13
- ## How It Works
14
-
15
- 1. **Open the app** and click **Start** - your webcam feed appears with a face mesh overlay.
16
- 2. **Pick a model** from the selector bar (Geometric, XGBoost, L2CS, etc.).
17
- 3. The system analyses each frame and shows **FOCUSED** or **NOT FOCUSED** with a confidence score.
18
- 4. A timeline tracks your focus over time. Session history is saved for review.
19
-
20
- ## Models
21
-
22
- | Model | What it uses | Best for |
23
- |-------|-------------|----------|
24
- | **Geometric** | Head pose angles + eye aspect ratio (EAR) | Fast, no ML needed |
25
- | **XGBoost** | Trained classifier on head/eye features | Balanced accuracy/speed |
26
- | **MLP** | Neural network on same features | Higher accuracy |
27
- | **Hybrid** | Weighted MLP + Geometric ensemble | Best head-pose accuracy |
28
- | **L2CS** | Deep gaze estimation (ResNet50) | Detects eye-only gaze shifts |
29
-
30
- ## L2CS Gaze Tracking
31
-
32
- L2CS-Net predicts where your eyes are looking, not just where your head is pointed. This catches the scenario where your head faces the screen but your eyes wander.
33
-
34
- ### Standalone mode
35
- Select **L2CS** as the model - it handles everything.
36
-
37
- ### Boost mode
38
- Select any other model, then click the **GAZE** toggle. L2CS runs alongside the base model:
39
- - Base model handles head pose and eye openness (35% weight)
40
- - L2CS handles gaze direction (65% weight)
41
- - If L2CS detects gaze is clearly off-screen, it **vetoes** the base model regardless of score
42
-
43
- ### Calibration
44
- After enabling L2CS or Gaze Boost, click **Calibrate** while a session is running:
45
- 1. A fullscreen overlay shows 9 target dots (3x3 grid)
46
- 2. Look at each dot as the progress ring fills
47
- 3. The first dot (centre) sets your baseline gaze offset
48
- 4. After all 9 points, a polynomial model maps your gaze angles to screen coordinates
49
- 5. A cyan tracking dot appears on the video showing where you're looking
50
-
51
- ## Tech Stack
52
-
53
- - **Backend**: FastAPI + WebSocket, Python 3.10
54
- - **Frontend**: React + Vite
55
- - **Face detection**: MediaPipe Face Landmarker (478 landmarks)
56
- - **Gaze estimation**: L2CS-Net (ResNet50, Gaze360 weights)
57
- - **ML models**: XGBoost, PyTorch MLP
58
- - **Deployment**: Docker on Hugging Face Spaces
59
-
60
- ## Running Locally
61
-
62
- ```bash
63
- # install Python deps
64
- pip install -r requirements.txt
65
-
66
- # install frontend deps and build
67
- npm install && npm run build
68
-
69
- # start the server
70
- uvicorn main:app --port 8000
71
- ```
72
-
73
- Open `http://localhost:8000` in your browser.
74
-
75
- ## Project Structure
76
-
77
- ```
78
- main.py # FastAPI app, WebSocket handler, API endpoints
79
- ui/pipeline.py # All focus detection pipelines (Geometric, MLP, XGBoost, Hybrid, L2CS)
80
- models/
81
- face_mesh.py # MediaPipe face landmark detector
82
- head_pose.py # Head pose estimation from landmarks
83
- eye_scorer.py # EAR/eye behaviour scoring
84
- gaze_calibration.py # 9-point polynomial gaze calibration
85
- gaze_eye_fusion.py # Fuses calibrated gaze with eye openness
86
- L2CS-Net/ # In-tree L2CS-Net repo with Gaze360 weights
87
- src/
88
- components/
89
- FocusPageLocal.jsx # Main focus page (camera, controls, model selector)
90
- CalibrationOverlay.jsx # Fullscreen calibration UI
91
- utils/
92
- VideoManagerLocal.js # WebSocket client, frame capture, canvas rendering
93
- Dockerfile # Docker build for HF Spaces
94
- ```
 
1
  ---
2
+ title: IntegrationTest
3
+ emoji: 📚
4
  colorFrom: indigo
5
  colorTo: purple
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
api/history DELETED
File without changes
api/import DELETED
File without changes
api/sessions DELETED
File without changes
app.py DELETED
@@ -1 +0,0 @@
1
- from main import app
 
 
checkpoints/hybrid_focus_config.json DELETED
@@ -1,10 +0,0 @@
1
- {
2
- "w_mlp": 0.6000000000000001,
3
- "w_geo": 0.3999999999999999,
4
- "threshold": 0.35,
5
- "use_yawn_veto": true,
6
- "geo_face_weight": 0.4,
7
- "geo_eye_weight": 0.6,
8
- "mar_yawn_threshold": 0.55,
9
- "metric": "f1"
10
- }
 
 
 
 
 
 
 
 
 
 
 
checkpoints/meta_best.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5d78d1df5e25536a2c82c4b8f5fd0c26dd35f44b28fd59761634cbf78c7546f8
3
- size 4196
 
 
 
 
checkpoints/mlp_best.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c2f55129785b6882c304483aa5399f5bf6c9ed6e73dfec7ca6f36cd0436156c8
3
- size 14497
 
 
 
 
checkpoints/model_best.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:183f2d4419e0eb1e58704e5a7312eb61e331523566d4dc551054a07b3aac7557
3
- size 5775881
 
 
 
 
checkpoints/scaler_best.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:02ed6b4c0d99e0254c6a740a949da2384db58ec7d3e6df6432b9bfcd3a296c71
3
- size 783
 
 
 
 
checkpoints/xgboost_face_orientation_best.json DELETED
The diff for this file is too large to render. See raw diff
 
data/CNN/eye_crops/val/open/.gitkeep DELETED
@@ -1 +0,0 @@
1
-
 
 
data/README.md DELETED
@@ -1,47 +0,0 @@
1
- # data/
2
-
3
- Raw collected session data used for model training and evaluation.
4
-
5
- ## 1. Contents
6
-
7
- Each `collected_<name>/` folder contains `.npz` files for one participant:
8
-
9
- | Folder | Participant | Samples |
10
- |--------|-------------|---------|
11
- | `collected_Abdelrahman/` | Abdelrahman | 15,870 |
12
- | `collected_Jarek/` | Jarek | 14,829 |
13
- | `collected_Junhao/` | Junhao | 8,901 |
14
- | `collected_Kexin/` | Kexin | 32,312 (2 sessions) |
15
- | `collected_Langyuan/` | Langyuan | 15,749 |
16
- | `collected_Mohamed/` | Mohamed | 13,218 |
17
- | `collected_Yingtao/` | Yingtao | 17,591 |
18
- | `collected_ayten/` | Ayten | 17,621 |
19
- | `collected_saba/` | Saba | 8,702 |
20
- | **Total** | **9 participants** | **144,793** |
21
-
22
- ## 2. File Format
23
-
24
- Each `.npz` file contains:
25
-
26
- | Key | Shape | Description |
27
- |-----|-------|-------------|
28
- | `features` | (N, 17) | 17-dimensional feature vectors (float32) |
29
- | `labels` | (N,) | Binary labels: 0 = unfocused, 1 = focused |
30
- | `feature_names` | (17,) | Column names for the 17 features |
31
-
32
- ## 3. Feature List
33
-
34
- `ear_left`, `ear_right`, `ear_avg`, `h_gaze`, `v_gaze`, `mar`, `yaw`, `pitch`, `roll`, `s_face`, `s_eye`, `gaze_offset`, `head_deviation`, `perclos`, `blink_rate`, `closure_duration`, `yawn_duration`
35
-
36
- 10 of these are selected for training (see `data_preparation/prepare_dataset.py`).
37
-
38
- ## 4. Collection
39
-
40
- ```bash
41
- python -m models.collect_features --name yourname
42
- ```
43
-
44
- 1. Webcam opens with live overlay
45
- 2. Press **1** = focused, **0** = unfocused (switch every 10–30 sec)
46
- 3. Press **p** to pause/resume
47
- 4. Press **q** to stop and save
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data/collected_Abdelrahman/abdelrahman_20260306_023035.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e2c48532150182c8933d4595e0a0711365645b699647e99976575b7c2adffaf8
3
- size 1207980
 
 
 
 
data/collected_Jarek/Jarek_20260225_012931.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0fa68f4d587eee8d645b23b463a9f1c848b9bacc2adb68603d5fa9cd8cb744c7
3
- size 1128864
 
 
 
 
data/collected_Junhao/Junhao_20260303_113554.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ec321ee79800c04fdc0f999690d07970445aeca61f977bf6537880bbc996b5e5
3
- size 678336
 
 
 
 
data/collected_Kexin/kexin2_20260305_180229.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0e96fe17571fa1fcccc1b4bd0c8838270498883e4db6a608c4d4d4c3a8ac1d0d
3
- size 1129700
 
 
 
 
data/collected_Kexin/kexin_20260224_151043.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8d402ca4e66910a2e174c4f4beec5d7b3db6a04213d29673b227ce6ef04b39c4
3
- size 1329732
 
 
 
 
data/collected_Langyuan/Langyuan_20260303_153145.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c679cdba334b2f3f0953b7e44f7209056277c826e2b7b5cfcf2b8b750898400
3
- size 1198784
 
 
 
 
data/collected_Mohamed/session_20260224_010131.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a784f703c13b83911f47ec507d32c25942a07572314b8a77cbf40ca8cdff16f
3
- size 1006428
 
 
 
 
data/collected_Yingtao/Yingtao_20260306_023937.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a75af17e25dca5f06ea9e7443ea5fee9db638f68a5910e014ee7cb8b7ae80fd
3
- size 1338776
 
 
 
 
data/collected_ayten/ayten_session_1.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fbecdbffa1c1b03b3b0fb5f715dcb4ff885ecc67da4aff78e6952b8847a96014
3
- size 1341056
 
 
 
 
data/collected_saba/saba_20260306_230710.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:db1cab5ddcf9988856c5bdca1183c8eba4647365e675a1d8a200d12f6b5d2097
3
- size 663212
 
 
 
 
data_preparation/README.md DELETED
@@ -1,75 +0,0 @@
1
- # data_preparation/
2
-
3
- Shared data loading, cleaning, and exploratory analysis.
4
-
5
- ## 1. Files
6
-
7
- | File | Description |
8
- |------|-------------|
9
- | `prepare_dataset.py` | Central data loading module used by all training scripts and notebooks |
10
- | `data_exploration.ipynb` | EDA notebook: feature distributions, class balance, correlations |
11
-
12
- ## 2. prepare_dataset.py
13
-
14
- Provides a consistent pipeline for loading raw `.npz` data from `data/`:
15
-
16
- | Function | Purpose |
17
- |----------|---------|
18
- | `load_all_pooled(model_name)` | Load all participants, clean, select features, concatenate |
19
- | `load_per_person(model_name)` | Load grouped by person (for LOPO cross-validation) |
20
- | `get_numpy_splits(model_name)` | Load + stratified 70/15/15 split + StandardScaler |
21
- | `get_dataloaders(model_name)` | Same as above, wrapped in PyTorch DataLoaders |
22
- | `_split_and_scale(features, labels, ...)` | Reusable split + optional scaling |
23
-
24
- ### Cleaning rules
25
-
26
- - `yaw` clipped to [-45, 45], `pitch`/`roll` to [-30, 30]
27
- - `ear_left`, `ear_right`, `ear_avg` clipped to [0, 0.85]
28
-
29
- ### Selected features (face_orientation)
30
-
31
- `head_deviation`, `s_face`, `s_eye`, `h_gaze`, `pitch`, `ear_left`, `ear_avg`, `ear_right`, `gaze_offset`, `perclos`
32
-
33
- ## 3. data_exploration.ipynb
34
-
35
- Run from this folder or from the project root. Covers:
36
-
37
- 1. Per-feature statistics (mean, std, min, max)
38
- 2. Class distribution (focused vs unfocused)
39
- 3. Feature histograms and box plots
40
- 4. Correlation matrix
41
-
42
- ## 4. How to run
43
-
44
- `prepare_dataset.py` is a **library module**, not a standalone script. You don’t run it directly; you import it from code that needs data.
45
-
46
- **From repo root:**
47
-
48
- ```bash
49
- # Optional: quick test that loading works
50
- python -c "
51
- from data_preparation.prepare_dataset import load_all_pooled
52
- X, y, names = load_all_pooled('face_orientation')
53
- print(f'Loaded {X.shape[0]} samples, {X.shape[1]} features: {names}')
54
- "
55
- ```
56
-
57
- **Used by:**
58
-
59
- - `python -m models.mlp.train`
60
- - `python -m models.xgboost.train`
61
- - `notebooks/mlp.ipynb`, `notebooks/xgboost.ipynb`
62
- - `data_preparation/data_exploration.ipynb`
63
-
64
- ## 5. Usage (in code)
65
-
66
- ```python
67
- from data_preparation.prepare_dataset import load_all_pooled, get_numpy_splits
68
-
69
- # pooled data
70
- X, y, names = load_all_pooled("face_orientation")
71
-
72
- # ready-to-train splits
73
- splits, n_features, n_classes, scaler = get_numpy_splits("face_orientation")
74
- X_train, y_train = splits["X_train"], splits["y_train"]
75
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data_preparation/__init__.py DELETED
File without changes
data_preparation/data_exploration.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
data_preparation/prepare_dataset.py DELETED
@@ -1,232 +0,0 @@
1
- import os
2
- import glob
3
-
4
- import numpy as np
5
- from sklearn.preprocessing import StandardScaler
6
- from sklearn.model_selection import train_test_split
7
-
8
- try:
9
- import torch
10
- from torch.utils.data import Dataset, DataLoader
11
- except ImportError: # pragma: no cover
12
- torch = None
13
-
14
- class Dataset: # type: ignore
15
- pass
16
-
17
- class _MissingTorchDataLoader: # type: ignore
18
- def __init__(self, *args, **kwargs):
19
- raise ImportError(
20
- "PyTorch not installed"
21
- )
22
-
23
- DataLoader = _MissingTorchDataLoader # type: ignore
24
-
25
- DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
26
-
27
- SELECTED_FEATURES = {
28
- "face_orientation": [
29
- 'head_deviation', 's_face', 's_eye', 'h_gaze', 'pitch',
30
- 'ear_left', 'ear_avg', 'ear_right', 'gaze_offset', 'perclos'
31
- ],
32
- "eye_behaviour": [
33
- 'ear_left', 'ear_right', 'ear_avg', 'mar',
34
- 'blink_rate', 'closure_duration', 'perclos', 'yawn_duration'
35
- ]
36
- }
37
-
38
-
39
- class FeatureVectorDataset(Dataset):
40
- def __init__(self, features: np.ndarray, labels: np.ndarray):
41
- self.features = torch.tensor(features, dtype=torch.float32)
42
- self.labels = torch.tensor(labels, dtype=torch.long)
43
-
44
- def __len__(self):
45
- return len(self.labels)
46
-
47
- def __getitem__(self, idx):
48
- return self.features[idx], self.labels[idx]
49
-
50
-
51
- # ── Low-level helpers ────────────────────────────────────────────────────
52
-
53
- def _clean_npz(raw, names):
54
- """Apply clipping rules in-place. Shared by all loaders."""
55
- for col, lo, hi in [('yaw', -45, 45), ('pitch', -30, 30), ('roll', -30, 30)]:
56
- if col in names:
57
- raw[:, names.index(col)] = np.clip(raw[:, names.index(col)], lo, hi)
58
- for feat in ['ear_left', 'ear_right', 'ear_avg']:
59
- if feat in names:
60
- raw[:, names.index(feat)] = np.clip(raw[:, names.index(feat)], 0, 0.85)
61
- return raw
62
-
63
-
64
- def _load_one_npz(npz_path, target_features):
65
- """Load a single .npz file, clean and select features. Returns (X, y, selected_feature_names)."""
66
- data = np.load(npz_path, allow_pickle=True)
67
- raw = data['features'].astype(np.float32)
68
- labels = data['labels'].astype(np.int64)
69
- names = list(data['feature_names'])
70
- raw = _clean_npz(raw, names)
71
- selected = [f for f in target_features if f in names]
72
- idx = [names.index(f) for f in selected]
73
- return raw[:, idx], labels, selected
74
-
75
-
76
- # ── Public data loaders ──────────────────────────────────────────────────
77
-
78
- def load_all_pooled(model_name: str = "face_orientation", data_dir: str = None):
79
- """Load all collected_*/*.npz, clean, select features, concatenate.
80
-
81
- Returns (X_all, y_all, all_feature_names).
82
- """
83
- data_dir = data_dir or DATA_DIR
84
- target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
85
- pattern = os.path.join(data_dir, "collected_*", "*.npz")
86
- npz_files = sorted(glob.glob(pattern))
87
-
88
- if not npz_files:
89
- print("[DATA] Warning: No .npz files found. Falling back to synthetic.")
90
- X, y = _generate_synthetic_data(model_name)
91
- return X, y, target_features
92
-
93
- all_X, all_y = [], []
94
- all_names = None
95
- for npz_path in npz_files:
96
- X, y, names = _load_one_npz(npz_path, target_features)
97
- if all_names is None:
98
- all_names = names
99
- all_X.append(X)
100
- all_y.append(y)
101
- print(f"[DATA] + {os.path.basename(npz_path)}: {X.shape[0]} samples")
102
-
103
- X_all = np.concatenate(all_X, axis=0)
104
- y_all = np.concatenate(all_y, axis=0)
105
- print(f"[DATA] Loaded {len(npz_files)} file(s) for '{model_name}': "
106
- f"{X_all.shape[0]} total samples, {X_all.shape[1]} features")
107
- return X_all, y_all, all_names
108
-
109
-
110
- def load_per_person(model_name: str = "face_orientation", data_dir: str = None):
111
- """Load collected_*/*.npz grouped by person (folder name).
112
-
113
- Returns dict { person_name: (X, y) } where X/y are per-person numpy arrays.
114
- Also returns (X_all, y_all) as pooled data.
115
- """
116
- data_dir = data_dir or DATA_DIR
117
- target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
118
- pattern = os.path.join(data_dir, "collected_*", "*.npz")
119
- npz_files = sorted(glob.glob(pattern))
120
-
121
- if not npz_files:
122
- raise FileNotFoundError(f"No .npz files matching {pattern}")
123
-
124
- by_person = {}
125
- all_X, all_y = [], []
126
- for npz_path in npz_files:
127
- folder = os.path.basename(os.path.dirname(npz_path))
128
- person = folder.replace("collected_", "", 1)
129
- X, y, _ = _load_one_npz(npz_path, target_features)
130
- all_X.append(X)
131
- all_y.append(y)
132
- if person not in by_person:
133
- by_person[person] = []
134
- by_person[person].append((X, y))
135
- print(f"[DATA] + {person}/{os.path.basename(npz_path)}: {X.shape[0]} samples")
136
-
137
- for person, chunks in by_person.items():
138
- by_person[person] = (
139
- np.concatenate([c[0] for c in chunks], axis=0),
140
- np.concatenate([c[1] for c in chunks], axis=0),
141
- )
142
-
143
- X_all = np.concatenate(all_X, axis=0)
144
- y_all = np.concatenate(all_y, axis=0)
145
- print(f"[DATA] {len(by_person)} persons, {X_all.shape[0]} total samples, {X_all.shape[1]} features")
146
- return by_person, X_all, y_all
147
-
148
-
149
- def load_raw_npz(npz_path):
150
- """Load a single .npz without cleaning or feature selection. For exploration notebooks."""
151
- data = np.load(npz_path, allow_pickle=True)
152
- features = data['features'].astype(np.float32)
153
- labels = data['labels'].astype(np.int64)
154
- names = list(data['feature_names'])
155
- return features, labels, names
156
-
157
-
158
- # ── Legacy helpers (used by models/mlp/train.py and models/xgboost/train.py) ─
159
-
160
- def _load_real_data(model_name: str):
161
- X, y, _ = load_all_pooled(model_name)
162
- return X, y
163
-
164
-
165
- def _generate_synthetic_data(model_name: str):
166
- target_features = SELECTED_FEATURES.get(model_name, SELECTED_FEATURES["face_orientation"])
167
- n = 500
168
- d = len(target_features)
169
- c = 2
170
- rng = np.random.RandomState(42)
171
- features = rng.randn(n, d).astype(np.float32)
172
- labels = rng.randint(0, c, size=n).astype(np.int64)
173
- print(f"[DATA] Using synthetic data for '{model_name}': {n} samples, {d} features, {c} classes")
174
- return features, labels
175
-
176
-
177
- def _split_and_scale(features, labels, split_ratios, seed, scale):
178
- """Split data into train/val/test (stratified) and optionally scale."""
179
- test_ratio = split_ratios[2]
180
- val_ratio = split_ratios[1] / (split_ratios[0] + split_ratios[1])
181
-
182
- X_train_val, X_test, y_train_val, y_test = train_test_split(
183
- features, labels, test_size=test_ratio, random_state=seed, stratify=labels,
184
- )
185
- X_train, X_val, y_train, y_val = train_test_split(
186
- X_train_val, y_train_val, test_size=val_ratio, random_state=seed, stratify=y_train_val,
187
- )
188
-
189
- scaler = None
190
- if scale:
191
- scaler = StandardScaler()
192
- X_train = scaler.fit_transform(X_train)
193
- X_val = scaler.transform(X_val)
194
- X_test = scaler.transform(X_test)
195
- print("[DATA] Applied StandardScaler (fitted on training split)")
196
-
197
- splits = {
198
- "X_train": X_train, "y_train": y_train,
199
- "X_val": X_val, "y_val": y_val,
200
- "X_test": X_test, "y_test": y_test,
201
- }
202
-
203
- print(f"[DATA] Split (stratified): train={len(y_train)}, val={len(y_val)}, test={len(y_test)}")
204
- return splits, scaler
205
-
206
-
207
- def get_numpy_splits(model_name: str, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
208
- """Return raw numpy arrays for non-PyTorch models (e.g. XGBoost)."""
209
- features, labels = _load_real_data(model_name)
210
- num_features = features.shape[1]
211
- num_classes = int(labels.max()) + 1
212
- splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
213
- return splits, num_features, num_classes, scaler
214
-
215
-
216
- def get_dataloaders(model_name: str, batch_size: int = 32, split_ratios=(0.7, 0.15, 0.15), seed: int = 42, scale: bool = True):
217
- """Return PyTorch DataLoaders for neural-network models."""
218
- features, labels = _load_real_data(model_name)
219
- num_features = features.shape[1]
220
- num_classes = int(labels.max()) + 1
221
- splits, scaler = _split_and_scale(features, labels, split_ratios, seed, scale)
222
-
223
- train_ds = FeatureVectorDataset(splits["X_train"], splits["y_train"])
224
- val_ds = FeatureVectorDataset(splits["X_val"], splits["y_val"])
225
- test_ds = FeatureVectorDataset(splits["X_test"], splits["y_test"])
226
-
227
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
228
- val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
229
- test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
230
-
231
- return train_loader, val_loader, test_loader, num_features, num_classes, scaler
232
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docker-compose.yml DELETED
@@ -1,5 +0,0 @@
1
- services:
2
- focus-guard:
3
- build: .
4
- ports:
5
- - "7860:7860"
 
 
 
 
 
 
download_l2cs_weights.py DELETED
@@ -1,37 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Downloads L2CS-Net Gaze360 weights into checkpoints/
3
-
4
- import os
5
- import sys
6
-
7
- CHECKPOINTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
8
- DEST = os.path.join(CHECKPOINTS_DIR, "L2CSNet_gaze360.pkl")
9
- GDRIVE_ID = "1dL2Jokb19_SBSHAhKHOxJsmYs5-GoyLo"
10
-
11
-
12
- def main():
13
- if os.path.isfile(DEST):
14
- print(f"[OK] Weights already at {DEST}")
15
- return
16
-
17
- try:
18
- import gdown
19
- except ImportError:
20
- print("gdown not installed. Run: pip install gdown")
21
- sys.exit(1)
22
-
23
- os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
24
- print(f"Downloading L2CS-Net weights to {DEST} ...")
25
- gdown.download(f"https://drive.google.com/uc?id={GDRIVE_ID}", DEST, quiet=False)
26
-
27
- if os.path.isfile(DEST):
28
- print(f"[OK] Downloaded ({os.path.getsize(DEST) / 1024 / 1024:.1f} MB)")
29
- else:
30
- print("[ERR] Download failed. Manual download:")
31
- print(" https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd")
32
- print(f" Place L2CSNet_gaze360.pkl in {CHECKPOINTS_DIR}/")
33
- sys.exit(1)
34
-
35
-
36
- if __name__ == "__main__":
37
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eslint.config.js DELETED
@@ -1,29 +0,0 @@
1
- import js from '@eslint/js'
2
- import globals from 'globals'
3
- import reactHooks from 'eslint-plugin-react-hooks'
4
- import reactRefresh from 'eslint-plugin-react-refresh'
5
- import { defineConfig, globalIgnores } from 'eslint/config'
6
-
7
- export default defineConfig([
8
- globalIgnores(['dist']),
9
- {
10
- files: ['**/*.{js,jsx}'],
11
- extends: [
12
- js.configs.recommended,
13
- reactHooks.configs.flat.recommended,
14
- reactRefresh.configs.vite,
15
- ],
16
- languageOptions: {
17
- ecmaVersion: 2020,
18
- globals: globals.browser,
19
- parserOptions: {
20
- ecmaVersion: 'latest',
21
- ecmaFeatures: { jsx: true },
22
- sourceType: 'module',
23
- },
24
- },
25
- rules: {
26
- 'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
27
- },
28
- },
29
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
evaluation/README.md DELETED
@@ -1,46 +0,0 @@
1
- # evaluation/
2
-
3
- Training logs and performance metrics.
4
-
5
- ## 1. Contents
6
-
7
- ```
8
- logs/
9
- ├── face_orientation_training_log.json # MLP (latest run)
10
- ├── mlp_face_orientation_training_log.json # MLP (alternate)
11
- └── xgboost_face_orientation_training_log.json # XGBoost
12
- ```
13
-
14
- ## 2. Log Format
15
-
16
- Each JSON file records the full training history:
17
-
18
- **MLP logs:**
19
- ```json
20
- {
21
- "config": { "epochs": 30, "lr": 0.001, "batch_size": 32, ... },
22
- "history": {
23
- "train_loss": [0.287, 0.260, ...],
24
- "val_loss": [0.256, 0.245, ...],
25
- "train_acc": [0.889, 0.901, ...],
26
- "val_acc": [0.905, 0.909, ...]
27
- },
28
- "test": { "accuracy": 0.929, "f1": 0.929, "roc_auc": 0.971 }
29
- }
30
- ```
31
-
32
- **XGBoost logs:**
33
- ```json
34
- {
35
- "config": { "n_estimators": 600, "max_depth": 8, "learning_rate": 0.149, ... },
36
- "train_losses": [0.577, ...],
37
- "val_losses": [0.576, ...],
38
- "test": { "accuracy": 0.959, "f1": 0.959, "roc_auc": 0.991 }
39
- }
40
- ```
41
-
42
- ## 3. Generated By
43
-
44
- - `python -m models.mlp.train` → writes MLP log
45
- - `python -m models.xgboost.train` → writes XGBoost log
46
- - Notebooks in `notebooks/` also save logs here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
index.html DELETED
@@ -1,17 +0,0 @@
1
- <!doctype html>
2
- <html lang="en">
3
-
4
- <head>
5
- <meta charset="UTF-8" />
6
- <link rel="icon" type="image/svg+xml" href="/vite.svg" />
7
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
8
- <title>Focus Guard</title>
9
- <link href="https://fonts.googleapis.com/css2?family=Nunito:wght@400;700&display=swap" rel="stylesheet">
10
- </head>
11
-
12
- <body>
13
- <div id="root"></div>
14
- <script type="module" src="/src/main.jsx"></script>
15
- </body>
16
-
17
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,1210 +0,0 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
- from fastapi.staticfiles import StaticFiles
3
- from fastapi.responses import FileResponse
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from pydantic import BaseModel
6
- from typing import Optional, List, Any
7
- import base64
8
- import cv2
9
- import numpy as np
10
- import aiosqlite
11
- import json
12
- from datetime import datetime, timedelta
13
- import math
14
- import os
15
- from pathlib import Path
16
- from typing import Callable
17
- import asyncio
18
- import concurrent.futures
19
- import threading
20
-
21
- from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
22
- from av import VideoFrame
23
-
24
- from mediapipe.tasks.python.vision import FaceLandmarksConnections
25
- from ui.pipeline import (
26
- FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline,
27
- L2CSPipeline, is_l2cs_weights_available,
28
- )
29
- from models.face_mesh import FaceMeshDetector
30
-
31
- # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
32
-
33
- _FONT = cv2.FONT_HERSHEY_SIMPLEX
34
- _CYAN = (255, 255, 0)
35
- _GREEN = (0, 255, 0)
36
- _MAGENTA = (255, 0, 255)
37
- _ORANGE = (0, 165, 255)
38
- _RED = (0, 0, 255)
39
- _WHITE = (255, 255, 255)
40
- _LIGHT_GREEN = (144, 238, 144)
41
-
42
- _TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
43
- _CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
44
- _LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
45
- _RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
46
- _NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
47
- _LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
48
- _LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
49
- _LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
50
- _RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
51
-
52
-
53
- def _lm_px(lm, idx, w, h):
54
- return (int(lm[idx, 0] * w), int(lm[idx, 1] * h))
55
-
56
-
57
- def _draw_polyline(frame, lm, indices, w, h, color, thickness):
58
- for i in range(len(indices) - 1):
59
- cv2.line(frame, _lm_px(lm, indices[i], w, h), _lm_px(lm, indices[i + 1], w, h), color, thickness, cv2.LINE_AA)
60
-
61
-
62
- def _draw_face_mesh(frame, lm, w, h):
63
- """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, gaze lines."""
64
- # Tessellation (gray triangular grid, semi-transparent)
65
- overlay = frame.copy()
66
- for s, e in _TESSELATION_CONNS:
67
- cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
68
- cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
69
- # Contours
70
- for s, e in _CONTOUR_CONNS:
71
- cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
72
- # Eyebrows
73
- _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
74
- _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
75
- # Nose
76
- _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
77
- # Lips
78
- _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
79
- _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
80
- # Eyes
81
- left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
82
- cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
83
- right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
84
- cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
85
- # EAR key points
86
- for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
87
- for idx in indices:
88
- cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
89
- # Irises + gaze lines
90
- for iris_idx, eye_inner, eye_outer in [
91
- (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
92
- (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
93
- ]:
94
- iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
95
- center = iris_pts[0]
96
- if len(iris_pts) >= 5:
97
- radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
98
- radius = max(int(np.mean(radii)), 2)
99
- cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
100
- cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
101
- eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
102
- eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
103
- dx, dy = center[0] - eye_cx, center[1] - eye_cy
104
- cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
105
-
106
-
107
- def _draw_hud(frame, result, model_name):
108
- """Draw status bar and detail overlay matching live_demo.py."""
109
- h, w = frame.shape[:2]
110
- is_focused = result["is_focused"]
111
- status = "FOCUSED" if is_focused else "NOT FOCUSED"
112
- color = _GREEN if is_focused else _RED
113
-
114
- # Top bar
115
- cv2.rectangle(frame, (0, 0), (w, 55), (0, 0, 0), -1)
116
- cv2.putText(frame, status, (10, 28), _FONT, 0.8, color, 2, cv2.LINE_AA)
117
- cv2.putText(frame, model_name.upper(), (w - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)
118
-
119
- # Detail line
120
- conf = result.get("mlp_prob", result.get("raw_score", 0.0))
121
- mar_s = f" MAR:{result['mar']:.2f}" if result.get("mar") is not None else ""
122
- sf = result.get("s_face", 0)
123
- se = result.get("s_eye", 0)
124
- detail = f"conf:{conf:.2f} S_face:{sf:.2f} S_eye:{se:.2f}{mar_s}"
125
- cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)
126
-
127
- # Head pose (top right)
128
- if result.get("yaw") is not None:
129
- cv2.putText(frame, f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}",
130
- (w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)
131
-
132
- # Yawn indicator
133
- if result.get("is_yawning"):
134
- cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
135
-
136
- # Landmark indices used for face mesh drawing on client (union of all groups).
137
- # Sending only these instead of all 478 saves ~60% of the landmarks payload.
138
- _MESH_INDICES = sorted(set(
139
- [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] # face oval
140
- + [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246] # left eye
141
- + [362,382,381,380,374,373,390,249,263,466,388,387,386,385,384,398] # right eye
142
- + [468,469,470,471,472, 473,474,475,476,477] # irises
143
- + [70,63,105,66,107,55,65,52,53,46] # left eyebrow
144
- + [300,293,334,296,336,285,295,282,283,276] # right eyebrow
145
- + [6,197,195,5,4,1,19,94,2] # nose bridge
146
- + [61,146,91,181,84,17,314,405,321,375,291,409,270,269,267,0,37,39,40,185] # lips outer
147
- + [78,95,88,178,87,14,317,402,318,324,308,415,310,311,312,13,82,81,80,191] # lips inner
148
- + [33,160,158,133,153,145] # left EAR key points
149
- + [362,385,387,263,373,380] # right EAR key points
150
- ))
151
- # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
152
- _MESH_INDEX_SET = set(_MESH_INDICES)
153
-
154
- # Initialize FastAPI app
155
- app = FastAPI(title="Focus Guard API")
156
-
157
- # Add CORS middleware
158
- app.add_middleware(
159
- CORSMiddleware,
160
- allow_origins=["*"],
161
- allow_credentials=True,
162
- allow_methods=["*"],
163
- allow_headers=["*"],
164
- )
165
-
166
- # Global variables
167
- db_path = "focus_guard.db"
168
- pcs = set()
169
- _cached_model_name = "mlp" # in-memory cache, updated via /api/settings
170
- _l2cs_boost_enabled = False # when True, L2CS runs alongside the base model
171
-
172
- async def _wait_for_ice_gathering(pc: RTCPeerConnection):
173
- if pc.iceGatheringState == "complete":
174
- return
175
- done = asyncio.Event()
176
-
177
- @pc.on("icegatheringstatechange")
178
- def _on_state_change():
179
- if pc.iceGatheringState == "complete":
180
- done.set()
181
-
182
- await done.wait()
183
-
184
- # ================ DATABASE MODELS ================
185
-
186
- async def init_database():
187
- """Initialize SQLite database with required tables"""
188
- async with aiosqlite.connect(db_path) as db:
189
- # FocusSessions table
190
- await db.execute("""
191
- CREATE TABLE IF NOT EXISTS focus_sessions (
192
- id INTEGER PRIMARY KEY AUTOINCREMENT,
193
- start_time TIMESTAMP NOT NULL,
194
- end_time TIMESTAMP,
195
- duration_seconds INTEGER DEFAULT 0,
196
- focus_score REAL DEFAULT 0.0,
197
- total_frames INTEGER DEFAULT 0,
198
- focused_frames INTEGER DEFAULT 0,
199
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
200
- )
201
- """)
202
-
203
- # FocusEvents table
204
- await db.execute("""
205
- CREATE TABLE IF NOT EXISTS focus_events (
206
- id INTEGER PRIMARY KEY AUTOINCREMENT,
207
- session_id INTEGER NOT NULL,
208
- timestamp TIMESTAMP NOT NULL,
209
- is_focused BOOLEAN NOT NULL,
210
- confidence REAL NOT NULL,
211
- detection_data TEXT,
212
- FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
213
- )
214
- """)
215
-
216
- # UserSettings table
217
- await db.execute("""
218
- CREATE TABLE IF NOT EXISTS user_settings (
219
- id INTEGER PRIMARY KEY CHECK (id = 1),
220
- sensitivity INTEGER DEFAULT 6,
221
- notification_enabled BOOLEAN DEFAULT 1,
222
- notification_threshold INTEGER DEFAULT 30,
223
- frame_rate INTEGER DEFAULT 30,
224
- model_name TEXT DEFAULT 'mlp'
225
- )
226
- """)
227
-
228
- # Insert default settings if not exists
229
- await db.execute("""
230
- INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
231
- VALUES (1, 6, 1, 30, 30, 'mlp')
232
- """)
233
-
234
- await db.commit()
235
-
236
- # ================ PYDANTIC MODELS ================
237
-
238
- class SessionCreate(BaseModel):
239
- pass
240
-
241
- class SessionEnd(BaseModel):
242
- session_id: int
243
-
244
- class SettingsUpdate(BaseModel):
245
- sensitivity: Optional[int] = None
246
- notification_enabled: Optional[bool] = None
247
- notification_threshold: Optional[int] = None
248
- frame_rate: Optional[int] = None
249
- model_name: Optional[str] = None
250
- l2cs_boost: Optional[bool] = None
251
-
252
- class VideoTransformTrack(VideoStreamTrack):
253
- def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
254
- super().__init__()
255
- self.track = track
256
- self.session_id = session_id
257
- self.get_channel = get_channel
258
- self.last_inference_time = 0
259
- self.min_inference_interval = 1 / 60
260
- self.last_frame = None
261
-
262
- async def recv(self):
263
- frame = await self.track.recv()
264
- img = frame.to_ndarray(format="bgr24")
265
- if img is None:
266
- return frame
267
-
268
- # Normalize size for inference/drawing
269
- img = cv2.resize(img, (640, 480))
270
-
271
- now = datetime.now().timestamp()
272
- do_infer = (now - self.last_inference_time) >= self.min_inference_interval
273
-
274
- if do_infer:
275
- self.last_inference_time = now
276
-
277
- model_name = _cached_model_name
278
- if model_name == "l2cs" and pipelines.get("l2cs") is None:
279
- _ensure_l2cs()
280
- if model_name not in pipelines or pipelines.get(model_name) is None:
281
- model_name = 'mlp'
282
- active_pipeline = pipelines.get(model_name)
283
-
284
- if active_pipeline is not None:
285
- loop = asyncio.get_event_loop()
286
- out = await loop.run_in_executor(
287
- _inference_executor,
288
- _process_frame_safe,
289
- active_pipeline,
290
- img,
291
- model_name,
292
- )
293
- is_focused = out["is_focused"]
294
- confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
295
- metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
296
-
297
- # Draw face mesh + HUD on the video frame
298
- h_f, w_f = img.shape[:2]
299
- lm = out.get("landmarks")
300
- if lm is not None:
301
- _draw_face_mesh(img, lm, w_f, h_f)
302
- _draw_hud(img, out, model_name)
303
- else:
304
- is_focused = False
305
- confidence = 0.0
306
- metadata = {"model": model_name}
307
- cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
308
- cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)
309
-
310
- if self.session_id:
311
- await store_focus_event(self.session_id, is_focused, confidence, metadata)
312
-
313
- channel = self.get_channel()
314
- if channel and channel.readyState == "open":
315
- try:
316
- channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
317
- except Exception:
318
- pass
319
-
320
- self.last_frame = img
321
- elif self.last_frame is not None:
322
- img = self.last_frame
323
-
324
- new_frame = VideoFrame.from_ndarray(img, format="bgr24")
325
- new_frame.pts = frame.pts
326
- new_frame.time_base = frame.time_base
327
- return new_frame
328
-
329
- # ================ DATABASE OPERATIONS ================
330
-
331
- async def create_session():
332
- async with aiosqlite.connect(db_path) as db:
333
- cursor = await db.execute(
334
- "INSERT INTO focus_sessions (start_time) VALUES (?)",
335
- (datetime.now().isoformat(),)
336
- )
337
- await db.commit()
338
- return cursor.lastrowid
339
-
340
- async def end_session(session_id: int):
341
- async with aiosqlite.connect(db_path) as db:
342
- cursor = await db.execute(
343
- "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
344
- (session_id,)
345
- )
346
- row = await cursor.fetchone()
347
-
348
- if not row:
349
- return None
350
-
351
- start_time_str, total_frames, focused_frames = row
352
- start_time = datetime.fromisoformat(start_time_str)
353
- end_time = datetime.now()
354
- duration = (end_time - start_time).total_seconds()
355
- focus_score = focused_frames / total_frames if total_frames > 0 else 0.0
356
-
357
- await db.execute("""
358
- UPDATE focus_sessions
359
- SET end_time = ?, duration_seconds = ?, focus_score = ?
360
- WHERE id = ?
361
- """, (end_time.isoformat(), int(duration), focus_score, session_id))
362
-
363
- await db.commit()
364
-
365
- return {
366
- 'session_id': session_id,
367
- 'start_time': start_time_str,
368
- 'end_time': end_time.isoformat(),
369
- 'duration_seconds': int(duration),
370
- 'focus_score': round(focus_score, 3),
371
- 'total_frames': total_frames,
372
- 'focused_frames': focused_frames
373
- }
374
-
375
- async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
376
- async with aiosqlite.connect(db_path) as db:
377
- await db.execute("""
378
- INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
379
- VALUES (?, ?, ?, ?, ?)
380
- """, (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
381
-
382
- await db.execute("""
383
- UPDATE focus_sessions
384
- SET total_frames = total_frames + 1,
385
- focused_frames = focused_frames + ?
386
- WHERE id = ?
387
- """, (1 if is_focused else 0, session_id))
388
- await db.commit()
389
-
390
-
391
- class _EventBuffer:
392
- """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes."""
393
-
394
- def __init__(self, flush_interval: float = 2.0):
395
- self._buf: list = []
396
- self._lock = asyncio.Lock()
397
- self._flush_interval = flush_interval
398
- self._task: asyncio.Task | None = None
399
- self._total_frames = 0
400
- self._focused_frames = 0
401
-
402
- def start(self):
403
- if self._task is None:
404
- self._task = asyncio.create_task(self._flush_loop())
405
-
406
- async def stop(self):
407
- if self._task:
408
- self._task.cancel()
409
- try:
410
- await self._task
411
- except asyncio.CancelledError:
412
- pass
413
- self._task = None
414
- await self._flush()
415
-
416
- def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict):
417
- self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata)))
418
- self._total_frames += 1
419
- if is_focused:
420
- self._focused_frames += 1
421
-
422
- async def _flush_loop(self):
423
- while True:
424
- await asyncio.sleep(self._flush_interval)
425
- await self._flush()
426
-
427
- async def _flush(self):
428
- async with self._lock:
429
- if not self._buf:
430
- return
431
- batch = self._buf[:]
432
- total = self._total_frames
433
- focused = self._focused_frames
434
- self._buf.clear()
435
- self._total_frames = 0
436
- self._focused_frames = 0
437
-
438
- if not batch:
439
- return
440
-
441
- session_id = batch[0][0]
442
- try:
443
- async with aiosqlite.connect(db_path) as db:
444
- await db.executemany("""
445
- INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
446
- VALUES (?, ?, ?, ?, ?)
447
- """, batch)
448
- await db.execute("""
449
- UPDATE focus_sessions
450
- SET total_frames = total_frames + ?,
451
- focused_frames = focused_frames + ?
452
- WHERE id = ?
453
- """, (total, focused, session_id))
454
- await db.commit()
455
- except Exception as e:
456
- print(f"[DB] Flush error: {e}")
457
-
458
- # ================ STARTUP/SHUTDOWN ================
459
-
460
- pipelines = {
461
- "geometric": None,
462
- "mlp": None,
463
- "hybrid": None,
464
- "xgboost": None,
465
- "l2cs": None,
466
- }
467
-
468
- # Thread pool for CPU-bound inference so the event loop stays responsive.
469
- _inference_executor = concurrent.futures.ThreadPoolExecutor(
470
- max_workers=4,
471
- thread_name_prefix="inference",
472
- )
473
- # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
474
- # multiple frames are processed in parallel by the thread pool.
475
- _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")}
476
-
477
- _l2cs_load_lock = threading.Lock()
478
- _l2cs_error: str | None = None
479
-
480
-
481
- def _ensure_l2cs():
482
- # lazy-load L2CS on first use, double-checked locking
483
- global _l2cs_error
484
- if pipelines["l2cs"] is not None:
485
- return True
486
- with _l2cs_load_lock:
487
- if pipelines["l2cs"] is not None:
488
- return True
489
- if not is_l2cs_weights_available():
490
- _l2cs_error = "Weights not found"
491
- return False
492
- try:
493
- pipelines["l2cs"] = L2CSPipeline()
494
- _l2cs_error = None
495
- print("[OK] L2CSPipeline lazy-loaded")
496
- return True
497
- except Exception as e:
498
- _l2cs_error = str(e)
499
- print(f"[ERR] L2CS lazy-load failed: {e}")
500
- return False
501
-
502
-
503
- def _process_frame_safe(pipeline, frame, model_name):
504
- with _pipeline_locks[model_name]:
505
- return pipeline.process_frame(frame)
506
-
507
-
508
- _BOOST_BASE_W = 0.35
509
- _BOOST_L2CS_W = 0.65
510
- _BOOST_VETO = 0.38 # L2CS below this -> forced not-focused
511
-
512
-
513
- def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
514
- # run base model
515
- with _pipeline_locks[base_model_name]:
516
- base_out = base_pipeline.process_frame(frame)
517
-
518
- l2cs_pipe = pipelines.get("l2cs")
519
- if l2cs_pipe is None:
520
- base_out["boost_active"] = False
521
- return base_out
522
-
523
- # run L2CS
524
- with _pipeline_locks["l2cs"]:
525
- l2cs_out = l2cs_pipe.process_frame(frame)
526
-
527
- base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
528
- l2cs_score = l2cs_out.get("raw_score", 0.0)
529
-
530
- # veto: gaze clearly off-screen overrides base model
531
- if l2cs_score < _BOOST_VETO:
532
- fused_score = l2cs_score * 0.8
533
- is_focused = False
534
- else:
535
- fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
536
- is_focused = fused_score >= 0.52
537
-
538
- base_out["raw_score"] = fused_score
539
- base_out["is_focused"] = is_focused
540
- base_out["boost_active"] = True
541
- base_out["base_score"] = round(base_score, 3)
542
- base_out["l2cs_score"] = round(l2cs_score, 3)
543
-
544
- if l2cs_out.get("gaze_yaw") is not None:
545
- base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
546
- base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
547
-
548
- return base_out
549
-
550
- @app.on_event("startup")
551
- async def startup_event():
552
- global pipelines, _cached_model_name
553
- print(" Starting Focus Guard API...")
554
- await init_database()
555
- # Load cached model name from DB
556
- async with aiosqlite.connect(db_path) as db:
557
- cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
558
- row = await cursor.fetchone()
559
- if row:
560
- _cached_model_name = row[0]
561
- print("[OK] Database initialized")
562
-
563
- try:
564
- pipelines["geometric"] = FaceMeshPipeline()
565
- print("[OK] FaceMeshPipeline (geometric) loaded")
566
- except Exception as e:
567
- print(f"[WARN] FaceMeshPipeline unavailable: {e}")
568
-
569
- try:
570
- pipelines["mlp"] = MLPPipeline()
571
- print("[OK] MLPPipeline loaded")
572
- except Exception as e:
573
- print(f"[ERR] Failed to load MLPPipeline: {e}")
574
-
575
- try:
576
- pipelines["hybrid"] = HybridFocusPipeline()
577
- print("[OK] HybridFocusPipeline loaded")
578
- except Exception as e:
579
- print(f"[WARN] HybridFocusPipeline unavailable: {e}")
580
-
581
- try:
582
- pipelines["xgboost"] = XGBoostPipeline()
583
- print("[OK] XGBoostPipeline loaded")
584
- except Exception as e:
585
- print(f"[ERR] Failed to load XGBoostPipeline: {e}")
586
-
587
- if is_l2cs_weights_available():
588
- print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
589
- else:
590
- print("[WARN] L2CS weights not found — l2cs model unavailable")
591
-
592
- @app.on_event("shutdown")
593
- async def shutdown_event():
594
- _inference_executor.shutdown(wait=False)
595
- print(" Shutting down Focus Guard API...")
596
-
597
- # ================ WEBRTC SIGNALING ================
598
-
599
- @app.post("/api/webrtc/offer")
600
- async def webrtc_offer(offer: dict):
601
- try:
602
- print(f"Received WebRTC offer")
603
-
604
- pc = RTCPeerConnection()
605
- pcs.add(pc)
606
-
607
- session_id = await create_session()
608
- print(f"Created session: {session_id}")
609
-
610
- channel_ref = {"channel": None}
611
-
612
- @pc.on("datachannel")
613
- def on_datachannel(channel):
614
- print(f"Data channel opened")
615
- channel_ref["channel"] = channel
616
-
617
- @pc.on("track")
618
- def on_track(track):
619
- print(f"Received track: {track.kind}")
620
- if track.kind == "video":
621
- local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
622
- pc.addTrack(local_track)
623
- print(f"Video track added")
624
-
625
- @track.on("ended")
626
- async def on_ended():
627
- print(f"Track ended")
628
-
629
- @pc.on("connectionstatechange")
630
- async def on_connectionstatechange():
631
- print(f"Connection state changed: {pc.connectionState}")
632
- if pc.connectionState in ("failed", "closed", "disconnected"):
633
- try:
634
- await end_session(session_id)
635
- except Exception as e:
636
- print(f"⚠Error ending session: {e}")
637
- pcs.discard(pc)
638
- await pc.close()
639
-
640
- await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
641
- print(f"Remote description set")
642
-
643
- answer = await pc.createAnswer()
644
- await pc.setLocalDescription(answer)
645
- print(f"Answer created")
646
-
647
- await _wait_for_ice_gathering(pc)
648
- print(f"ICE gathering complete")
649
-
650
- return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
651
-
652
- except Exception as e:
653
- print(f"WebRTC offer error: {e}")
654
- import traceback
655
- traceback.print_exc()
656
- raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}")
657
-
658
- # ================ WEBSOCKET ================
659
-
660
- @app.websocket("/ws/video")
661
- async def websocket_endpoint(websocket: WebSocket):
662
- from models.gaze_calibration import GazeCalibration
663
- from models.gaze_eye_fusion import GazeEyeFusion
664
-
665
- await websocket.accept()
666
- session_id = None
667
- frame_count = 0
668
- running = True
669
- event_buffer = _EventBuffer(flush_interval=2.0)
670
-
671
- # Calibration state (per-connection)
672
- _cal: dict = {"cal": None, "collecting": False, "fusion": None}
673
-
674
- # Latest frame slot — only the most recent frame is kept, older ones are dropped.
675
- _slot = {"frame": None}
676
- _frame_ready = asyncio.Event()
677
-
678
- async def _receive_loop():
679
- """Receive messages as fast as possible. Binary = frame, text = control."""
680
- nonlocal session_id, running
681
- try:
682
- while running:
683
- msg = await websocket.receive()
684
- msg_type = msg.get("type", "")
685
-
686
- if msg_type == "websocket.disconnect":
687
- running = False
688
- _frame_ready.set()
689
- return
690
-
691
- # Binary message → JPEG frame (fast path, no base64)
692
- raw_bytes = msg.get("bytes")
693
- if raw_bytes is not None and len(raw_bytes) > 0:
694
- _slot["frame"] = raw_bytes
695
- _frame_ready.set()
696
- continue
697
-
698
- # Text message → JSON control command (or legacy base64 frame)
699
- text = msg.get("text")
700
- if not text:
701
- continue
702
- data = json.loads(text)
703
-
704
- if data["type"] == "frame":
705
- _slot["frame"] = base64.b64decode(data["image"])
706
- _frame_ready.set()
707
-
708
- elif data["type"] == "start_session":
709
- session_id = await create_session()
710
- event_buffer.start()
711
- for p in pipelines.values():
712
- if p is not None and hasattr(p, "reset_session"):
713
- p.reset_session()
714
- await websocket.send_json({"type": "session_started", "session_id": session_id})
715
-
716
- elif data["type"] == "end_session":
717
- if session_id:
718
- await event_buffer.stop()
719
- summary = await end_session(session_id)
720
- if summary:
721
- await websocket.send_json({"type": "session_ended", "summary": summary})
722
- session_id = None
723
-
724
- # ---- Calibration commands ----
725
- elif data["type"] == "calibration_start":
726
- loop = asyncio.get_event_loop()
727
- await loop.run_in_executor(_inference_executor, _ensure_l2cs)
728
- _cal["cal"] = GazeCalibration()
729
- _cal["collecting"] = True
730
- _cal["fusion"] = None
731
- cal = _cal["cal"]
732
- await websocket.send_json({
733
- "type": "calibration_started",
734
- "num_points": cal.num_points,
735
- "target": list(cal.current_target),
736
- "index": cal.current_index,
737
- })
738
-
739
- elif data["type"] == "calibration_next":
740
- cal = _cal.get("cal")
741
- if cal is not None:
742
- more = cal.advance()
743
- if more:
744
- await websocket.send_json({
745
- "type": "calibration_point",
746
- "target": list(cal.current_target),
747
- "index": cal.current_index,
748
- })
749
- else:
750
- _cal["collecting"] = False
751
- ok = cal.fit()
752
- if ok:
753
- _cal["fusion"] = GazeEyeFusion(cal)
754
- await websocket.send_json({"type": "calibration_done", "success": True})
755
- else:
756
- await websocket.send_json({"type": "calibration_done", "success": False, "error": "Not enough samples"})
757
-
758
- elif data["type"] == "calibration_cancel":
759
- _cal["cal"] = None
760
- _cal["collecting"] = False
761
- _cal["fusion"] = None
762
- await websocket.send_json({"type": "calibration_cancelled"})
763
-
764
- except WebSocketDisconnect:
765
- running = False
766
- _frame_ready.set()
767
- except Exception as e:
768
- print(f"[WS] receive error: {e}")
769
- running = False
770
- _frame_ready.set()
771
-
772
- async def _process_loop():
773
- """Process only the latest frame, dropping stale ones."""
774
- nonlocal frame_count, running
775
- loop = asyncio.get_event_loop()
776
- while running:
777
- await _frame_ready.wait()
778
- _frame_ready.clear()
779
- if not running:
780
- return
781
-
782
- raw = _slot["frame"]
783
- _slot["frame"] = None
784
- if raw is None:
785
- continue
786
-
787
- try:
788
- nparr = np.frombuffer(raw, np.uint8)
789
- frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
790
- if frame is None:
791
- continue
792
- frame = cv2.resize(frame, (640, 480))
793
-
794
- # During calibration collection, always use L2CS
795
- collecting = _cal.get("collecting", False)
796
- if collecting:
797
- if pipelines.get("l2cs") is None:
798
- await loop.run_in_executor(_inference_executor, _ensure_l2cs)
799
- use_model = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name
800
- else:
801
- use_model = _cached_model_name
802
-
803
- model_name = use_model
804
- if model_name == "l2cs" and pipelines.get("l2cs") is None:
805
- await loop.run_in_executor(_inference_executor, _ensure_l2cs)
806
- if model_name not in pipelines or pipelines.get(model_name) is None:
807
- model_name = "mlp"
808
- active_pipeline = pipelines.get(model_name)
809
-
810
- # L2CS boost: run L2CS alongside base model
811
- use_boost = (
812
- _l2cs_boost_enabled
813
- and model_name != "l2cs"
814
- and pipelines.get("l2cs") is not None
815
- and not collecting
816
- )
817
-
818
- landmarks_list = None
819
- out = None
820
- if active_pipeline is not None:
821
- if use_boost:
822
- out = await loop.run_in_executor(
823
- _inference_executor,
824
- _process_frame_with_l2cs_boost,
825
- active_pipeline,
826
- frame,
827
- model_name,
828
- )
829
- else:
830
- out = await loop.run_in_executor(
831
- _inference_executor,
832
- _process_frame_safe,
833
- active_pipeline,
834
- frame,
835
- model_name,
836
- )
837
- is_focused = out["is_focused"]
838
- confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
839
-
840
- lm = out.get("landmarks")
841
- if lm is not None:
842
- landmarks_list = [
843
- [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
844
- for i in range(lm.shape[0])
845
- ]
846
-
847
- # Calibration sample collection (L2CS gaze angles)
848
- if collecting and _cal.get("cal") is not None:
849
- pipe_yaw = out.get("gaze_yaw")
850
- pipe_pitch = out.get("gaze_pitch")
851
- if pipe_yaw is not None and pipe_pitch is not None:
852
- _cal["cal"].collect_sample(pipe_yaw, pipe_pitch)
853
-
854
- # Gaze fusion (when L2CS active + calibration fitted)
855
- fusion = _cal.get("fusion")
856
- if (
857
- fusion is not None
858
- and model_name == "l2cs"
859
- and out.get("gaze_yaw") is not None
860
- ):
861
- fuse = fusion.update(
862
- out["gaze_yaw"], out["gaze_pitch"], lm
863
- )
864
- is_focused = fuse["focused"]
865
- confidence = fuse["focus_score"]
866
-
867
- if session_id:
868
- metadata = {
869
- "s_face": out.get("s_face", 0.0),
870
- "s_eye": out.get("s_eye", 0.0),
871
- "mar": out.get("mar", 0.0),
872
- "model": model_name,
873
- }
874
- event_buffer.add(session_id, is_focused, confidence, metadata)
875
- else:
876
- is_focused = False
877
- confidence = 0.0
878
-
879
- resp = {
880
- "type": "detection",
881
- "focused": is_focused,
882
- "confidence": round(confidence, 3),
883
- "model": model_name,
884
- "fc": frame_count,
885
- }
886
- if out is not None:
887
- if out.get("yaw") is not None:
888
- resp["yaw"] = round(out["yaw"], 1)
889
- resp["pitch"] = round(out["pitch"], 1)
890
- resp["roll"] = round(out["roll"], 1)
891
- if out.get("mar") is not None:
892
- resp["mar"] = round(out["mar"], 3)
893
- resp["sf"] = round(out.get("s_face", 0), 3)
894
- resp["se"] = round(out.get("s_eye", 0), 3)
895
-
896
- # Gaze fusion fields (L2CS standalone or boost mode)
897
- fusion = _cal.get("fusion")
898
- has_gaze = out.get("gaze_yaw") is not None
899
- if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost):
900
- fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], out.get("landmarks"))
901
- resp["gaze_x"] = fuse["gaze_x"]
902
- resp["gaze_y"] = fuse["gaze_y"]
903
- resp["on_screen"] = fuse["on_screen"]
904
- if model_name == "l2cs":
905
- resp["focused"] = fuse["focused"]
906
- resp["confidence"] = round(fuse["focus_score"], 3)
907
-
908
- if out.get("boost_active"):
909
- resp["boost"] = True
910
- resp["base_score"] = out.get("base_score", 0)
911
- resp["l2cs_score"] = out.get("l2cs_score", 0)
912
-
913
- if landmarks_list is not None:
914
- resp["lm"] = landmarks_list
915
- await websocket.send_json(resp)
916
- frame_count += 1
917
- except Exception as e:
918
- print(f"[WS] process error: {e}")
919
-
920
- try:
921
- await asyncio.gather(_receive_loop(), _process_loop())
922
- except Exception:
923
- pass
924
- finally:
925
- running = False
926
- if session_id:
927
- await event_buffer.stop()
928
- await end_session(session_id)
929
-
930
- # ================ API ENDPOINTS ================
931
-
932
- @app.post("/api/sessions/start")
933
- async def api_start_session():
934
- session_id = await create_session()
935
- return {"session_id": session_id}
936
-
937
- @app.post("/api/sessions/end")
938
- async def api_end_session(data: SessionEnd):
939
- summary = await end_session(data.session_id)
940
- if not summary: raise HTTPException(status_code=404, detail="Session not found")
941
- return summary
942
-
943
- @app.get("/api/sessions")
944
- async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
945
- async with aiosqlite.connect(db_path) as db:
946
- db.row_factory = aiosqlite.Row
947
-
948
- # NEW: If importing/exporting all, remove limit if special flag or high limit
949
- # For simplicity: if limit is -1, return all
950
- limit_clause = "LIMIT ? OFFSET ?"
951
- params = []
952
-
953
- base_query = "SELECT * FROM focus_sessions"
954
- where_clause = ""
955
-
956
- if filter == "today":
957
- date_filter = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
958
- where_clause = " WHERE start_time >= ?"
959
- params.append(date_filter.isoformat())
960
- elif filter == "week":
961
- date_filter = datetime.now() - timedelta(days=7)
962
- where_clause = " WHERE start_time >= ?"
963
- params.append(date_filter.isoformat())
964
- elif filter == "month":
965
- date_filter = datetime.now() - timedelta(days=30)
966
- where_clause = " WHERE start_time >= ?"
967
- params.append(date_filter.isoformat())
968
- elif filter == "all":
969
- # Just ensure we only get completed sessions or all sessions
970
- where_clause = " WHERE end_time IS NOT NULL"
971
-
972
- query = f"{base_query}{where_clause} ORDER BY start_time DESC"
973
-
974
- # Handle Limit for Exports
975
- if limit == -1:
976
- # No limit clause for export
977
- pass
978
- else:
979
- query += f" {limit_clause}"
980
- params.extend([limit, offset])
981
-
982
- cursor = await db.execute(query, tuple(params))
983
- rows = await cursor.fetchall()
984
- return [dict(row) for row in rows]
985
-
986
- # --- NEW: Import Endpoint ---
987
- @app.post("/api/import")
988
- async def import_sessions(sessions: List[dict]):
989
- count = 0
990
- try:
991
- async with aiosqlite.connect(db_path) as db:
992
- for session in sessions:
993
- # Use .get() to handle potential missing fields from older versions or edits
994
- await db.execute("""
995
- INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
996
- VALUES (?, ?, ?, ?, ?, ?, ?)
997
- """, (
998
- session.get('start_time'),
999
- session.get('end_time'),
1000
- session.get('duration_seconds', 0),
1001
- session.get('focus_score', 0.0),
1002
- session.get('total_frames', 0),
1003
- session.get('focused_frames', 0),
1004
- session.get('created_at', session.get('start_time'))
1005
- ))
1006
- count += 1
1007
- await db.commit()
1008
- return {"status": "success", "count": count}
1009
- except Exception as e:
1010
- print(f"Import Error: {e}")
1011
- return {"status": "error", "message": str(e)}
1012
-
1013
- # --- NEW: Clear History Endpoint ---
1014
- @app.delete("/api/history")
1015
- async def clear_history():
1016
- try:
1017
- async with aiosqlite.connect(db_path) as db:
1018
- # Delete events first (foreign key good practice)
1019
- await db.execute("DELETE FROM focus_events")
1020
- await db.execute("DELETE FROM focus_sessions")
1021
- await db.commit()
1022
- return {"status": "success", "message": "History cleared"}
1023
- except Exception as e:
1024
- return {"status": "error", "message": str(e)}
1025
-
1026
- @app.get("/api/sessions/{session_id}")
1027
- async def get_session(session_id: int):
1028
- async with aiosqlite.connect(db_path) as db:
1029
- db.row_factory = aiosqlite.Row
1030
- cursor = await db.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
1031
- row = await cursor.fetchone()
1032
- if not row: raise HTTPException(status_code=404, detail="Session not found")
1033
- session = dict(row)
1034
- cursor = await db.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
1035
- events = [dict(r) for r in await cursor.fetchall()]
1036
- session['events'] = events
1037
- return session
1038
-
1039
- @app.get("/api/settings")
1040
- async def get_settings():
1041
- async with aiosqlite.connect(db_path) as db:
1042
- db.row_factory = aiosqlite.Row
1043
- cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
1044
- row = await cursor.fetchone()
1045
- result = dict(row) if row else {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
1046
- result['l2cs_boost'] = _l2cs_boost_enabled
1047
- return result
1048
-
1049
- @app.put("/api/settings")
1050
- async def update_settings(settings: SettingsUpdate):
1051
- async with aiosqlite.connect(db_path) as db:
1052
- cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
1053
- exists = await cursor.fetchone()
1054
- if not exists:
1055
- await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
1056
- await db.commit()
1057
-
1058
- updates = []
1059
- params = []
1060
- if settings.sensitivity is not None:
1061
- updates.append("sensitivity = ?")
1062
- params.append(max(1, min(10, settings.sensitivity)))
1063
- if settings.notification_enabled is not None:
1064
- updates.append("notification_enabled = ?")
1065
- params.append(settings.notification_enabled)
1066
- if settings.notification_threshold is not None:
1067
- updates.append("notification_threshold = ?")
1068
- params.append(max(5, min(300, settings.notification_threshold)))
1069
- if settings.frame_rate is not None:
1070
- updates.append("frame_rate = ?")
1071
- params.append(max(5, min(60, settings.frame_rate)))
1072
- if settings.model_name is not None and settings.model_name in pipelines:
1073
- if settings.model_name == "l2cs":
1074
- loop = asyncio.get_event_loop()
1075
- loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
1076
- if not loaded:
1077
- raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
1078
- elif pipelines[settings.model_name] is None:
1079
- raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
1080
- updates.append("model_name = ?")
1081
- params.append(settings.model_name)
1082
- global _cached_model_name
1083
- _cached_model_name = settings.model_name
1084
-
1085
- if settings.l2cs_boost is not None:
1086
- global _l2cs_boost_enabled
1087
- if settings.l2cs_boost:
1088
- loop = asyncio.get_event_loop()
1089
- loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
1090
- if not loaded:
1091
- raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")
1092
- _l2cs_boost_enabled = settings.l2cs_boost
1093
-
1094
- if updates:
1095
- query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
1096
- await db.execute(query, params)
1097
- await db.commit()
1098
- return {"status": "success", "updated": len(updates) > 0}
1099
-
1100
- @app.get("/api/stats/summary")
1101
- async def get_stats_summary():
1102
- async with aiosqlite.connect(db_path) as db:
1103
- cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
1104
- total_sessions = (await cursor.fetchone())[0]
1105
- cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
1106
- total_focus_time = (await cursor.fetchone())[0] or 0
1107
- cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
1108
- avg_focus_score = (await cursor.fetchone())[0] or 0.0
1109
- cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
1110
- dates = [row[0] for row in await cursor.fetchall()]
1111
-
1112
- streak_days = 0
1113
- if dates:
1114
- current_date = datetime.now().date()
1115
- for i, date_str in enumerate(dates):
1116
- session_date = datetime.fromisoformat(date_str).date()
1117
- expected_date = current_date - timedelta(days=i)
1118
- if session_date == expected_date: streak_days += 1
1119
- else: break
1120
- return {
1121
- 'total_sessions': total_sessions,
1122
- 'total_focus_time': int(total_focus_time),
1123
- 'avg_focus_score': round(avg_focus_score, 3),
1124
- 'streak_days': streak_days
1125
- }
1126
-
1127
- @app.get("/api/models")
1128
- async def get_available_models():
1129
- """Return model names, statuses, and which is currently active."""
1130
- statuses = {}
1131
- errors = {}
1132
- available = []
1133
- for name, p in pipelines.items():
1134
- if name == "l2cs":
1135
- if p is not None:
1136
- statuses[name] = "ready"
1137
- available.append(name)
1138
- elif is_l2cs_weights_available():
1139
- statuses[name] = "lazy"
1140
- available.append(name)
1141
- elif _l2cs_error:
1142
- statuses[name] = "error"
1143
- errors[name] = _l2cs_error
1144
- else:
1145
- statuses[name] = "unavailable"
1146
- elif p is not None:
1147
- statuses[name] = "ready"
1148
- available.append(name)
1149
- else:
1150
- statuses[name] = "unavailable"
1151
- async with aiosqlite.connect(db_path) as db:
1152
- cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
1153
- row = await cursor.fetchone()
1154
- current = row[0] if row else "mlp"
1155
- if current not in available and available:
1156
- current = available[0]
1157
- l2cs_boost_available = (
1158
- statuses.get("l2cs") in ("ready", "lazy") and current != "l2cs"
1159
- )
1160
- return {
1161
- "available": available,
1162
- "current": current,
1163
- "statuses": statuses,
1164
- "errors": errors,
1165
- "l2cs_boost": _l2cs_boost_enabled,
1166
- "l2cs_boost_available": l2cs_boost_available,
1167
- }
1168
-
1169
- @app.get("/api/l2cs/status")
1170
- async def l2cs_status():
1171
- """L2CS-specific status: weights available, loaded, and calibration info."""
1172
- loaded = pipelines.get("l2cs") is not None
1173
- return {
1174
- "weights_available": is_l2cs_weights_available(),
1175
- "loaded": loaded,
1176
- "error": _l2cs_error,
1177
- }
1178
-
1179
- @app.get("/api/mesh-topology")
1180
- async def get_mesh_topology():
1181
- """Return tessellation edge pairs for client-side face mesh drawing (cached by client)."""
1182
- return {"tessellation": _TESSELATION_CONNS}
1183
-
1184
- @app.get("/health")
1185
- async def health_check():
1186
- available = [name for name, p in pipelines.items() if p is not None]
1187
- return {"status": "healthy", "models_loaded": available, "database": os.path.exists(db_path)}
1188
-
1189
- # ================ STATIC FILES (SPA SUPPORT) ================
1190
-
1191
- FRONTEND_DIR = "dist" if os.path.exists("dist/index.html") else "static"
1192
-
1193
- assets_path = os.path.join(FRONTEND_DIR, "assets")
1194
- if os.path.exists(assets_path):
1195
- app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
1196
-
1197
- @app.get("/{full_path:path}")
1198
- async def serve_react_app(full_path: str, request: Request):
1199
- if full_path.startswith("api") or full_path.startswith("ws"):
1200
- raise HTTPException(status_code=404, detail="Not Found")
1201
-
1202
- file_path = os.path.join(FRONTEND_DIR, full_path)
1203
- if os.path.isfile(file_path):
1204
- return FileResponse(file_path)
1205
-
1206
- index_path = os.path.join(FRONTEND_DIR, "index.html")
1207
- if os.path.exists(index_path):
1208
- return FileResponse(index_path)
1209
- else:
1210
- return {"message": "React app not found. Please run npm run build."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/.gitignore DELETED
@@ -1,140 +0,0 @@
1
- # Ignore the test data - sensitive
2
- datasets/
3
- evaluation/
4
- output/
5
-
6
- # Ignore debugging configurations
7
- /.vscode
8
-
9
- # Byte-compiled / optimized / DLL files
10
- __pycache__/
11
- *.py[cod]
12
- *$py.class
13
-
14
- # C extensions
15
- *.so
16
-
17
- # Distribution / packaging
18
- .Python
19
- build/
20
- develop-eggs/
21
- dist/
22
- downloads/
23
- eggs/
24
- .eggs/
25
- lib/
26
- lib64/
27
- parts/
28
- sdist/
29
- var/
30
- wheels/
31
- pip-wheel-metadata/
32
- share/python-wheels/
33
- *.egg-info/
34
- .installed.cfg
35
- *.egg
36
- MANIFEST
37
-
38
- # PyInstaller
39
- # Usually these files are written by a python script from a template
40
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
- *.manifest
42
- *.spec
43
-
44
- # Installer logs
45
- pip-log.txt
46
- pip-delete-this-directory.txt
47
-
48
- # Unit test / coverage reports
49
- htmlcov/
50
- .tox/
51
- .nox/
52
- .coverage
53
- .coverage.*
54
- .cache
55
- nosetests.xml
56
- coverage.xml
57
- *.cover
58
- *.py,cover
59
- .hypothesis/
60
- .pytest_cache/
61
-
62
- # Translations
63
- *.mo
64
- *.pot
65
-
66
- # Django stuff:
67
- *.log
68
- local_settings.py
69
- db.sqlite3
70
- db.sqlite3-journal
71
-
72
- # Flask stuff:
73
- instance/
74
- .webassets-cache
75
-
76
- # Scrapy stuff:
77
- .scrapy
78
-
79
- # Sphinx documentation
80
- docs/_build/
81
-
82
- # PyBuilder
83
- target/
84
-
85
- # Jupyter Notebook
86
- .ipynb_checkpoints
87
-
88
- # IPython
89
- profile_default/
90
- ipython_config.py
91
-
92
- # pyenv
93
- .python-version
94
-
95
- # pipenv
96
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
- # install all needed dependencies.
100
- #Pipfile.lock
101
-
102
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow
103
- __pypackages__/
104
-
105
- # Celery stuff
106
- celerybeat-schedule
107
- celerybeat.pid
108
-
109
- # SageMath parsed files
110
- *.sage.py
111
-
112
- # Environments
113
- .env
114
- .venv
115
- env/
116
- venv/
117
- ENV/
118
- env.bak/
119
- venv.bak/
120
-
121
- # Spyder project settings
122
- .spyderproject
123
- .spyproject
124
-
125
- # Rope project settings
126
- .ropeproject
127
-
128
- # mkdocs documentation
129
- /site
130
-
131
- # mypy
132
- .mypy_cache/
133
- .dmypy.json
134
- dmypy.json
135
-
136
- # Pyre type checker
137
- .pyre/
138
-
139
- # Ignore other files
140
- my.secrets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2022 Ahmed Abdelrahman
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/README.md DELETED
@@ -1,148 +0,0 @@
1
-
2
-
3
-
4
- <p align="center">
5
- <img src="https://github.com/Ahmednull/Storage/blob/main/gaze.gif" alt="animated" />
6
- </p>
7
-
8
-
9
- ___
10
-
11
- # L2CS-Net
12
-
13
- The official PyTorch implementation of L2CS-Net for gaze estimation and tracking.
14
-
15
- ## Installation
16
- <img src="https://img.shields.io/badge/python%20-%2314354C.svg?&style=for-the-badge&logo=python&logoColor=white"/> <img src="https://img.shields.io/badge/PyTorch%20-%23EE4C2C.svg?&style=for-the-badge&logo=PyTorch&logoColor=white" />
17
-
18
- Install package with the following:
19
-
20
- ```
21
- pip install git+https://github.com/Ahmednull/L2CS-Net.git@main
22
- ```
23
-
24
- Or, you can git clone the repo and install with the following:
25
-
26
- ```
27
- pip install [-e] .
28
- ```
29
-
30
- Now you should be able to import the package with the following command:
31
-
32
- ```
33
- $ python
34
- >>> import l2cs
35
- ```
36
-
37
- ## Usage
38
-
39
- Detect face and predict gaze from webcam
40
-
41
- ```python
42
- from l2cs import Pipeline, render
43
- import cv2
44
-
45
- gaze_pipeline = Pipeline(
46
- weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
47
- arch='ResNet50',
48
- device=torch.device('cpu') # or 'gpu'
49
- )
50
-
51
- cap = cv2.VideoCapture(cam)
52
- _, frame = cap.read()
53
-
54
- # Process frame and visualize
55
- results = gaze_pipeline.step(frame)
56
- frame = render(frame, results)
57
- ```
58
-
59
- ## Demo
60
- * Download the pre-trained models from [here](https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd?usp=sharing) and Store it to *models/*.
61
- * Run:
62
- ```
63
- python demo.py \
64
- --snapshot models/L2CSNet_gaze360.pkl \
65
- --gpu 0 \
66
- --cam 0 \
67
- ```
68
- This means the demo will run using *L2CSNet_gaze360.pkl* pretrained model
69
-
70
- ## Community Contributions
71
-
72
- - [Gaze Detection and Eye Tracking: A How-To Guide](https://blog.roboflow.com/gaze-direction-position/): Use L2CS-Net through a HTTP interface with the open source Roboflow Inference project.
73
-
74
- ## MPIIGaze
75
- We provide the code for train and test MPIIGaze dataset with leave-one-person-out evaluation.
76
-
77
- ### Prepare datasets
78
- * Download **MPIIFaceGaze dataset** from [here](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation).
79
- * Apply data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
80
- * Store the dataset to *datasets/MPIIFaceGaze*.
81
-
82
- ### Train
83
- ```
84
- python train.py \
85
- --dataset mpiigaze \
86
- --snapshot output/snapshots \
87
- --gpu 0 \
88
- --num_epochs 50 \
89
- --batch_size 16 \
90
- --lr 0.00001 \
91
- --alpha 1 \
92
-
93
- ```
94
- This means the code will perform leave-one-person-out training automatically and store the models to *output/snapshots*.
95
-
96
- ### Test
97
- ```
98
- python test.py \
99
- --dataset mpiigaze \
100
- --snapshot output/snapshots/snapshot_folder \
101
- --evalpath evaluation/L2CS-mpiigaze \
102
- --gpu 0 \
103
- ```
104
- This means the code will perform leave-one-person-out testing automatically and store the results to *evaluation/L2CS-mpiigaze*.
105
-
106
- To get the average leave-one-person-out accuracy use:
107
- ```
108
- python leave_one_out_eval.py \
109
- --evalpath evaluation/L2CS-mpiigaze \
110
- --respath evaluation/L2CS-mpiigaze \
111
- ```
112
- This means the code will take the evaluation path and outputs the leave-one-out gaze accuracy to the *evaluation/L2CS-mpiigaze*.
113
-
114
- ## Gaze360
115
- We provide the code for train and test Gaze360 dataset with train-val-test evaluation.
116
-
117
- ### Prepare datasets
118
- * Download **Gaze360 dataset** from [here](http://gaze360.csail.mit.edu/download.php).
119
-
120
- * Apply data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
121
-
122
- * Store the dataset to *datasets/Gaze360*.
123
-
124
-
125
- ### Train
126
- ```
127
- python train.py \
128
- --dataset gaze360 \
129
- --snapshot output/snapshots \
130
- --gpu 0 \
131
- --num_epochs 50 \
132
- --batch_size 16 \
133
- --lr 0.00001 \
134
- --alpha 1 \
135
-
136
- ```
137
- This means the code will perform training and store the models to *output/snapshots*.
138
-
139
- ### Test
140
- ```
141
- python test.py \
142
- --dataset gaze360 \
143
- --snapshot output/snapshots/snapshot_folder \
144
- --evalpath evaluation/L2CS-gaze360 \
145
- --gpu 0 \
146
- ```
147
- This means the code will perform testing on snapshot_folder and store the results to *evaluation/L2CS-gaze360*.
148
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/demo.py DELETED
@@ -1,87 +0,0 @@
1
- import argparse
2
- import pathlib
3
- import numpy as np
4
- import cv2
5
- import time
6
-
7
- import torch
8
- import torch.nn as nn
9
- from torch.autograd import Variable
10
- from torchvision import transforms
11
- import torch.backends.cudnn as cudnn
12
- import torchvision
13
-
14
- from PIL import Image
15
- from PIL import Image, ImageOps
16
-
17
- from face_detection import RetinaFace
18
-
19
- from l2cs import select_device, draw_gaze, getArch, Pipeline, render
20
-
21
- CWD = pathlib.Path.cwd()
22
-
23
- def parse_args():
24
- """Parse input arguments."""
25
- parser = argparse.ArgumentParser(
26
- description='Gaze evalution using model pretrained with L2CS-Net on Gaze360.')
27
- parser.add_argument(
28
- '--device',dest='device', help='Device to run model: cpu or gpu:0',
29
- default="cpu", type=str)
30
- parser.add_argument(
31
- '--snapshot',dest='snapshot', help='Path of model snapshot.',
32
- default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl', type=str)
33
- parser.add_argument(
34
- '--cam',dest='cam_id', help='Camera device id to use [0]',
35
- default=0, type=int)
36
- parser.add_argument(
37
- '--arch',dest='arch',help='Network architecture, can be: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152',
38
- default='ResNet50', type=str)
39
-
40
- args = parser.parse_args()
41
- return args
42
-
43
- if __name__ == '__main__':
44
- args = parse_args()
45
-
46
- cudnn.enabled = True
47
- arch=args.arch
48
- cam = args.cam_id
49
- # snapshot_path = args.snapshot
50
-
51
- gaze_pipeline = Pipeline(
52
- weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
53
- arch='ResNet50',
54
- device = select_device(args.device, batch_size=1)
55
- )
56
-
57
- cap = cv2.VideoCapture(cam)
58
-
59
- # Check if the webcam is opened correctly
60
- if not cap.isOpened():
61
- raise IOError("Cannot open webcam")
62
-
63
- with torch.no_grad():
64
- while True:
65
-
66
- # Get frame
67
- success, frame = cap.read()
68
- start_fps = time.time()
69
-
70
- if not success:
71
- print("Failed to obtain frame")
72
- time.sleep(0.1)
73
-
74
- # Process frame
75
- results = gaze_pipeline.step(frame)
76
-
77
- # Visualize output
78
- frame = render(frame, results)
79
-
80
- myFPS = 1.0 / (time.time() - start_fps)
81
- cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20),cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1, cv2.LINE_AA)
82
-
83
- cv2.imshow("Demo",frame)
84
- if cv2.waitKey(1) & 0xFF == ord('q'):
85
- break
86
- success,frame = cap.read()
87
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/__init__.py DELETED
@@ -1,21 +0,0 @@
1
- from .utils import select_device, natural_keys, gazeto3d, angular, getArch
2
- from .vis import draw_gaze, render
3
- from .model import L2CS
4
- from .pipeline import Pipeline
5
- from .datasets import Gaze360, Mpiigaze
6
-
7
- __all__ = [
8
- # Classes
9
- 'L2CS',
10
- 'Pipeline',
11
- 'Gaze360',
12
- 'Mpiigaze',
13
- # Utils
14
- 'render',
15
- 'select_device',
16
- 'draw_gaze',
17
- 'natural_keys',
18
- 'gazeto3d',
19
- 'angular',
20
- 'getArch'
21
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/datasets.py DELETED
@@ -1,157 +0,0 @@
1
- import os
2
- import numpy as np
3
- import cv2
4
-
5
-
6
- import torch
7
- from torch.utils.data.dataset import Dataset
8
- from torchvision import transforms
9
- from PIL import Image, ImageFilter
10
-
11
-
12
- class Gaze360(Dataset):
13
- def __init__(self, path, root, transform, angle, binwidth, train=True):
14
- self.transform = transform
15
- self.root = root
16
- self.orig_list_len = 0
17
- self.angle = angle
18
- if train==False:
19
- angle=90
20
- self.binwidth=binwidth
21
- self.lines = []
22
- if isinstance(path, list):
23
- for i in path:
24
- with open(i) as f:
25
- print("here")
26
- line = f.readlines()
27
- line.pop(0)
28
- self.lines.extend(line)
29
- else:
30
- with open(path) as f:
31
- lines = f.readlines()
32
- lines.pop(0)
33
- self.orig_list_len = len(lines)
34
- for line in lines:
35
- gaze2d = line.strip().split(" ")[5]
36
- label = np.array(gaze2d.split(",")).astype("float")
37
- if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
38
- self.lines.append(line)
39
-
40
-
41
- print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines), angle))
42
-
43
- def __len__(self):
44
- return len(self.lines)
45
-
46
- def __getitem__(self, idx):
47
- line = self.lines[idx]
48
- line = line.strip().split(" ")
49
-
50
- face = line[0]
51
- lefteye = line[1]
52
- righteye = line[2]
53
- name = line[3]
54
- gaze2d = line[5]
55
- label = np.array(gaze2d.split(",")).astype("float")
56
- label = torch.from_numpy(label).type(torch.FloatTensor)
57
-
58
- pitch = label[0]* 180 / np.pi
59
- yaw = label[1]* 180 / np.pi
60
-
61
- img = Image.open(os.path.join(self.root, face))
62
-
63
- # fimg = cv2.imread(os.path.join(self.root, face))
64
- # fimg = cv2.resize(fimg, (448, 448))/255.0
65
- # fimg = fimg.transpose(2, 0, 1)
66
- # img=torch.from_numpy(fimg).type(torch.FloatTensor)
67
-
68
- if self.transform:
69
- img = self.transform(img)
70
-
71
- # Bin values
72
- bins = np.array(range(-1*self.angle, self.angle, self.binwidth))
73
- binned_pose = np.digitize([pitch, yaw], bins) - 1
74
-
75
- labels = binned_pose
76
- cont_labels = torch.FloatTensor([pitch, yaw])
77
-
78
-
79
- return img, labels, cont_labels, name
80
-
81
- class Mpiigaze(Dataset):
82
- def __init__(self, pathorg, root, transform, train, angle,fold=0):
83
- self.transform = transform
84
- self.root = root
85
- self.orig_list_len = 0
86
- self.lines = []
87
- path=pathorg.copy()
88
- if train==True:
89
- path.pop(fold)
90
- else:
91
- path=path[fold]
92
- if isinstance(path, list):
93
- for i in path:
94
- with open(i) as f:
95
- lines = f.readlines()
96
- lines.pop(0)
97
- self.orig_list_len += len(lines)
98
- for line in lines:
99
- gaze2d = line.strip().split(" ")[7]
100
- label = np.array(gaze2d.split(",")).astype("float")
101
- if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
102
- self.lines.append(line)
103
- else:
104
- with open(path) as f:
105
- lines = f.readlines()
106
- lines.pop(0)
107
- self.orig_list_len += len(lines)
108
- for line in lines:
109
- gaze2d = line.strip().split(" ")[7]
110
- label = np.array(gaze2d.split(",")).astype("float")
111
- if abs((label[0]*180/np.pi)) <= 42 and abs((label[1]*180/np.pi)) <= 42:
112
- self.lines.append(line)
113
-
114
- print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines),angle))
115
-
116
- def __len__(self):
117
- return len(self.lines)
118
-
119
- def __getitem__(self, idx):
120
- line = self.lines[idx]
121
- line = line.strip().split(" ")
122
-
123
- name = line[3]
124
- gaze2d = line[7]
125
- head2d = line[8]
126
- lefteye = line[1]
127
- righteye = line[2]
128
- face = line[0]
129
-
130
- label = np.array(gaze2d.split(",")).astype("float")
131
- label = torch.from_numpy(label).type(torch.FloatTensor)
132
-
133
-
134
- pitch = label[0]* 180 / np.pi
135
- yaw = label[1]* 180 / np.pi
136
-
137
- img = Image.open(os.path.join(self.root, face))
138
-
139
- # fimg = cv2.imread(os.path.join(self.root, face))
140
- # fimg = cv2.resize(fimg, (448, 448))/255.0
141
- # fimg = fimg.transpose(2, 0, 1)
142
- # img=torch.from_numpy(fimg).type(torch.FloatTensor)
143
-
144
- if self.transform:
145
- img = self.transform(img)
146
-
147
- # Bin values
148
- bins = np.array(range(-42, 42,3))
149
- binned_pose = np.digitize([pitch, yaw], bins) - 1
150
-
151
- labels = binned_pose
152
- cont_labels = torch.FloatTensor([pitch, yaw])
153
-
154
-
155
- return img, labels, cont_labels, name
156
-
157
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/model.py DELETED
@@ -1,73 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.autograd import Variable
4
- import math
5
- import torch.nn.functional as F
6
-
7
-
8
- class L2CS(nn.Module):
9
- def __init__(self, block, layers, num_bins):
10
- self.inplanes = 64
11
- super(L2CS, self).__init__()
12
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)
13
- self.bn1 = nn.BatchNorm2d(64)
14
- self.relu = nn.ReLU(inplace=True)
15
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
16
- self.layer1 = self._make_layer(block, 64, layers[0])
17
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
18
- self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
19
- self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
20
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
21
-
22
- self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins)
23
- self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins)
24
-
25
- # Vestigial layer from previous experiments
26
- self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
27
-
28
- for m in self.modules():
29
- if isinstance(m, nn.Conv2d):
30
- n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
31
- m.weight.data.normal_(0, math.sqrt(2. / n))
32
- elif isinstance(m, nn.BatchNorm2d):
33
- m.weight.data.fill_(1)
34
- m.bias.data.zero_()
35
-
36
- def _make_layer(self, block, planes, blocks, stride=1):
37
- downsample = None
38
- if stride != 1 or self.inplanes != planes * block.expansion:
39
- downsample = nn.Sequential(
40
- nn.Conv2d(self.inplanes, planes * block.expansion,
41
- kernel_size=1, stride=stride, bias=False),
42
- nn.BatchNorm2d(planes * block.expansion),
43
- )
44
-
45
- layers = []
46
- layers.append(block(self.inplanes, planes, stride, downsample))
47
- self.inplanes = planes * block.expansion
48
- for i in range(1, blocks):
49
- layers.append(block(self.inplanes, planes))
50
-
51
- return nn.Sequential(*layers)
52
-
53
- def forward(self, x):
54
- x = self.conv1(x)
55
- x = self.bn1(x)
56
- x = self.relu(x)
57
- x = self.maxpool(x)
58
-
59
- x = self.layer1(x)
60
- x = self.layer2(x)
61
- x = self.layer3(x)
62
- x = self.layer4(x)
63
- x = self.avgpool(x)
64
- x = x.view(x.size(0), -1)
65
-
66
-
67
- # gaze
68
- pre_yaw_gaze = self.fc_yaw_gaze(x)
69
- pre_pitch_gaze = self.fc_pitch_gaze(x)
70
- return pre_yaw_gaze, pre_pitch_gaze
71
-
72
-
73
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/pipeline.py DELETED
@@ -1,133 +0,0 @@
1
- import pathlib
2
- from typing import Union
3
-
4
- import cv2
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- from dataclasses import dataclass
9
- from face_detection import RetinaFace
10
-
11
- from .utils import prep_input_numpy, getArch
12
- from .results import GazeResultContainer
13
-
14
-
15
- class Pipeline:
16
-
17
- def __init__(
18
- self,
19
- weights: pathlib.Path,
20
- arch: str,
21
- device: str = 'cpu',
22
- include_detector:bool = True,
23
- confidence_threshold:float = 0.5
24
- ):
25
-
26
- # Save input parameters
27
- self.weights = weights
28
- self.include_detector = include_detector
29
- self.device = device
30
- self.confidence_threshold = confidence_threshold
31
-
32
- # Create L2CS model
33
- self.model = getArch(arch, 90)
34
- self.model.load_state_dict(torch.load(self.weights, map_location=device))
35
- self.model.to(self.device)
36
- self.model.eval()
37
-
38
- # Create RetinaFace if requested
39
- if self.include_detector:
40
-
41
- if device.type == 'cpu':
42
- self.detector = RetinaFace()
43
- else:
44
- self.detector = RetinaFace(gpu_id=device.index)
45
-
46
- self.softmax = nn.Softmax(dim=1)
47
- self.idx_tensor = [idx for idx in range(90)]
48
- self.idx_tensor = torch.FloatTensor(self.idx_tensor).to(self.device)
49
-
50
- def step(self, frame: np.ndarray) -> GazeResultContainer:
51
-
52
- # Creating containers
53
- face_imgs = []
54
- bboxes = []
55
- landmarks = []
56
- scores = []
57
-
58
- if self.include_detector:
59
- faces = self.detector(frame)
60
-
61
- if faces is not None:
62
- for box, landmark, score in faces:
63
-
64
- # Apply threshold
65
- if score < self.confidence_threshold:
66
- continue
67
-
68
- # Extract safe min and max of x,y
69
- x_min=int(box[0])
70
- if x_min < 0:
71
- x_min = 0
72
- y_min=int(box[1])
73
- if y_min < 0:
74
- y_min = 0
75
- x_max=int(box[2])
76
- y_max=int(box[3])
77
-
78
- # Crop image
79
- img = frame[y_min:y_max, x_min:x_max]
80
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
81
- img = cv2.resize(img, (224, 224))
82
- face_imgs.append(img)
83
-
84
- # Save data
85
- bboxes.append(box)
86
- landmarks.append(landmark)
87
- scores.append(score)
88
-
89
- # Predict gaze
90
- pitch, yaw = self.predict_gaze(np.stack(face_imgs))
91
-
92
- else:
93
-
94
- pitch = np.empty((0,1))
95
- yaw = np.empty((0,1))
96
-
97
- else:
98
- pitch, yaw = self.predict_gaze(frame)
99
-
100
- # Save data
101
- results = GazeResultContainer(
102
- pitch=pitch,
103
- yaw=yaw,
104
- bboxes=np.stack(bboxes),
105
- landmarks=np.stack(landmarks),
106
- scores=np.stack(scores)
107
- )
108
-
109
- return results
110
-
111
- def predict_gaze(self, frame: Union[np.ndarray, torch.Tensor]):
112
-
113
- # Prepare input
114
- if isinstance(frame, np.ndarray):
115
- img = prep_input_numpy(frame, self.device)
116
- elif isinstance(frame, torch.Tensor):
117
- img = frame
118
- else:
119
- raise RuntimeError("Invalid dtype for input")
120
-
121
- # Predict
122
- gaze_pitch, gaze_yaw = self.model(img)
123
- pitch_predicted = self.softmax(gaze_pitch)
124
- yaw_predicted = self.softmax(gaze_yaw)
125
-
126
- # Get continuous predictions in degrees.
127
- pitch_predicted = torch.sum(pitch_predicted.data * self.idx_tensor, dim=1) * 4 - 180
128
- yaw_predicted = torch.sum(yaw_predicted.data * self.idx_tensor, dim=1) * 4 - 180
129
-
130
- pitch_predicted= pitch_predicted.cpu().detach().numpy()* np.pi/180.0
131
- yaw_predicted= yaw_predicted.cpu().detach().numpy()* np.pi/180.0
132
-
133
- return pitch_predicted, yaw_predicted
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/results.py DELETED
@@ -1,11 +0,0 @@
1
- from dataclasses import dataclass
2
- import numpy as np
3
-
4
- @dataclass
5
- class GazeResultContainer:
6
-
7
- pitch: np.ndarray
8
- yaw: np.ndarray
9
- bboxes: np.ndarray
10
- landmarks: np.ndarray
11
- scores: np.ndarray
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/utils.py DELETED
@@ -1,145 +0,0 @@
1
- import sys
2
- import os
3
- import math
4
- from math import cos, sin
5
- from pathlib import Path
6
- import subprocess
7
- import re
8
-
9
- import numpy as np
10
- import torch
11
- import torch.nn as nn
12
- import scipy.io as sio
13
- import cv2
14
- import torchvision
15
- from torchvision import transforms
16
-
17
- from .model import L2CS
18
-
19
- transformations = transforms.Compose([
20
- transforms.ToPILImage(),
21
- transforms.Resize(448),
22
- transforms.ToTensor(),
23
- transforms.Normalize(
24
- mean=[0.485, 0.456, 0.406],
25
- std=[0.229, 0.224, 0.225]
26
- )
27
- ])
28
-
29
- def atoi(text):
30
- return int(text) if text.isdigit() else text
31
-
32
- def natural_keys(text):
33
- '''
34
- alist.sort(key=natural_keys) sorts in human order
35
- http://nedbatchelder.com/blog/200712/human_sorting.html
36
- (See Toothy's implementation in the comments)
37
- '''
38
- return [ atoi(c) for c in re.split(r'(\d+)', text) ]
39
-
40
- def prep_input_numpy(img:np.ndarray, device:str):
41
- """Preparing a Numpy Array as input to L2CS-Net."""
42
-
43
- if len(img.shape) == 4:
44
- imgs = []
45
- for im in img:
46
- imgs.append(transformations(im))
47
- img = torch.stack(imgs)
48
- else:
49
- img = transformations(img)
50
-
51
- img = img.to(device)
52
-
53
- if len(img.shape) == 3:
54
- img = img.unsqueeze(0)
55
-
56
- return img
57
-
58
- def gazeto3d(gaze):
59
- gaze_gt = np.zeros([3])
60
- gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
61
- gaze_gt[1] = -np.sin(gaze[1])
62
- gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
63
- return gaze_gt
64
-
65
- def angular(gaze, label):
66
- total = np.sum(gaze * label)
67
- return np.arccos(min(total/(np.linalg.norm(gaze)* np.linalg.norm(label)), 0.9999999))*180/np.pi
68
-
69
- def select_device(device='', batch_size=None):
70
- # device = 'cpu' or '0' or '0,1,2,3'
71
- s = f'YOLOv3 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
72
- cpu = device.lower() == 'cpu'
73
- if cpu:
74
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
75
- elif device: # non-cpu device requested
76
- os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable
77
- # assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability
78
-
79
- cuda = not cpu and torch.cuda.is_available()
80
- if cuda:
81
- devices = device.split(',') if device else range(torch.cuda.device_count()) # i.e. 0,1,6,7
82
- n = len(devices) # device count
83
- if n > 1 and batch_size: # check batch_size is divisible by device_count
84
- assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
85
- space = ' ' * len(s)
86
- for i, d in enumerate(devices):
87
- p = torch.cuda.get_device_properties(i)
88
- s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB
89
- else:
90
- s += 'CPU\n'
91
-
92
- return torch.device('cuda:0' if cuda else 'cpu')
93
-
94
- def spherical2cartesial(x):
95
-
96
- output = torch.zeros(x.size(0),3)
97
- output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
98
- output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
99
- output[:,1] = torch.sin(x[:,1])
100
-
101
- return output
102
-
103
- def compute_angular_error(input,target):
104
-
105
- input = spherical2cartesial(input)
106
- target = spherical2cartesial(target)
107
-
108
- input = input.view(-1,3,1)
109
- target = target.view(-1,1,3)
110
- output_dot = torch.bmm(target,input)
111
- output_dot = output_dot.view(-1)
112
- output_dot = torch.acos(output_dot)
113
- output_dot = output_dot.data
114
- output_dot = 180*torch.mean(output_dot)/math.pi
115
- return output_dot
116
-
117
- def softmax_temperature(tensor, temperature):
118
- result = torch.exp(tensor / temperature)
119
- result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))
120
- return result
121
-
122
- def git_describe(path=Path(__file__).parent): # path must be a directory
123
- # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
124
- s = f'git -C {path} describe --tags --long --always'
125
- try:
126
- return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
127
- except subprocess.CalledProcessError as e:
128
- return '' # not a git repository
129
-
130
- def getArch(arch,bins):
131
- # Base network structure
132
- if arch == 'ResNet18':
133
- model = L2CS( torchvision.models.resnet.BasicBlock,[2, 2, 2, 2], bins)
134
- elif arch == 'ResNet34':
135
- model = L2CS( torchvision.models.resnet.BasicBlock,[3, 4, 6, 3], bins)
136
- elif arch == 'ResNet101':
137
- model = L2CS( torchvision.models.resnet.Bottleneck,[3, 4, 23, 3], bins)
138
- elif arch == 'ResNet152':
139
- model = L2CS( torchvision.models.resnet.Bottleneck,[3, 8, 36, 3], bins)
140
- else:
141
- if arch != 'ResNet50':
142
- print('Invalid value for architecture is passed! '
143
- 'The default value of ResNet50 will be used instead!')
144
- model = L2CS( torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
145
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/l2cs/vis.py DELETED
@@ -1,64 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- from .results import GazeResultContainer
4
-
5
- def draw_gaze(a,b,c,d,image_in, pitchyaw, thickness=2, color=(255, 255, 0),sclae=2.0):
6
- """Draw gaze angle on given image with a given eye positions."""
7
- image_out = image_in
8
- (h, w) = image_in.shape[:2]
9
- length = c
10
- pos = (int(a+c / 2.0), int(b+d / 2.0))
11
- if len(image_out.shape) == 2 or image_out.shape[2] == 1:
12
- image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR)
13
- dx = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
14
- dy = -length * np.sin(pitchyaw[1])
15
- cv2.arrowedLine(image_out, tuple(np.round(pos).astype(np.int32)),
16
- tuple(np.round([pos[0] + dx, pos[1] + dy]).astype(int)), color,
17
- thickness, cv2.LINE_AA, tipLength=0.18)
18
- return image_out
19
-
20
- def draw_bbox(frame: np.ndarray, bbox: np.ndarray):
21
-
22
- x_min=int(bbox[0])
23
- if x_min < 0:
24
- x_min = 0
25
- y_min=int(bbox[1])
26
- if y_min < 0:
27
- y_min = 0
28
- x_max=int(bbox[2])
29
- y_max=int(bbox[3])
30
-
31
- cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0,255,0), 1)
32
-
33
- return frame
34
-
35
- def render(frame: np.ndarray, results: GazeResultContainer):
36
-
37
- # Draw bounding boxes
38
- for bbox in results.bboxes:
39
- frame = draw_bbox(frame, bbox)
40
-
41
- # Draw Gaze
42
- for i in range(results.pitch.shape[0]):
43
-
44
- bbox = results.bboxes[i]
45
- pitch = results.pitch[i]
46
- yaw = results.yaw[i]
47
-
48
- # Extract safe min and max of x,y
49
- x_min=int(bbox[0])
50
- if x_min < 0:
51
- x_min = 0
52
- y_min=int(bbox[1])
53
- if y_min < 0:
54
- y_min = 0
55
- x_max=int(bbox[2])
56
- y_max=int(bbox[3])
57
-
58
- # Compute sizes
59
- bbox_width = x_max - x_min
60
- bbox_height = y_max - y_min
61
-
62
- draw_gaze(x_min,y_min,bbox_width, bbox_height,frame,(pitch,yaw),color=(0,0,255))
63
-
64
- return frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/leave_one_out_eval.py DELETED
@@ -1,54 +0,0 @@
1
- import os
2
- import argparse
3
-
4
-
5
-
6
- def parse_args():
7
- """Parse input arguments."""
8
- parser = argparse.ArgumentParser(
9
- description='gaze estimation using binned loss function.')
10
- parser.add_argument(
11
- '--evalpath', dest='evalpath', help='path for evaluating gaze test.',
12
- default="evaluation\L2CS-gaze360-_standard-10", type=str)
13
- parser.add_argument(
14
- '--respath', dest='respath', help='path for saving result.',
15
- default="evaluation\L2CS-gaze360-_standard-10", type=str)
16
-
17
- if __name__ == '__main__':
18
-
19
- args = parse_args()
20
- evalpath =args.evalpath
21
- respath=args.respath
22
- if not os.path.exist(respath):
23
- os.makedirs(respath)
24
- with open(os.path.join(respath,"avg.log"), 'w') as outfile:
25
- outfile.write("Average equal\n")
26
-
27
- min=10.0
28
- dirlist = os.listdir(evalpath)
29
- dirlist.sort()
30
- l=0.0
31
- for j in range(50):
32
- j=20
33
- avg=0.0
34
- h=j+3
35
- for i in dirlist:
36
- with open(evalpath+"/"+i+"/mpiigaze_binned.log") as myfile:
37
-
38
- x=list(myfile)[h]
39
- str1 = ""
40
-
41
- # traverse in the string
42
- for ele in x:
43
- str1 += ele
44
- split_string = str1.split("MAE:",1)[1]
45
- avg+=float(split_string)
46
-
47
- avg=avg/15.0
48
- if avg<min:
49
- min=avg
50
- l=j+1
51
- outfile.write("epoch"+str(j+1)+"= "+str(avg)+"\n")
52
-
53
- outfile.write("min angular error equal= "+str(min)+"at epoch= "+str(l)+"\n")
54
- print(min)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/L2CS-Net/models/L2CSNet_gaze360.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
3
- size 95849977
 
 
 
 
models/L2CS-Net/models/README.md DELETED
@@ -1 +0,0 @@
1
- # Path to pre-trained models
 
 
models/L2CS-Net/pyproject.toml DELETED
@@ -1,44 +0,0 @@
1
- [project]
2
- name = "l2cs"
3
- version = "0.0.1"
4
- description = "The official PyTorch implementation of L2CS-Net for gaze estimation and tracking"
5
- authors = [
6
- {name = "Ahmed Abderlrahman"},
7
- {name = "Thorsten Hempel"}
8
- ]
9
- license = {file = "LICENSE.txt"}
10
- readme = "README.md"
11
- requires-python = ">3.6"
12
-
13
- keywords = ["gaze", "estimation", "eye-tracking", "deep-learning", "pytorch"]
14
-
15
- classifiers = [
16
- "Programming Language :: Python :: 3"
17
- ]
18
-
19
- dependencies = [
20
- 'matplotlib>=3.3.4',
21
- 'numpy>=1.19.5',
22
- 'opencv-python>=4.5.5',
23
- 'pandas>=1.1.5',
24
- 'Pillow>=8.4.0',
25
- 'scipy>=1.5.4',
26
- 'torch>=1.10.1',
27
- 'torchvision>=0.11.2',
28
- 'face_detection@git+https://github.com/elliottzheng/face-detection'
29
- ]
30
-
31
- [project.urls]
32
- homepath = "https://github.com/Ahmednull/L2CS-Net"
33
- repository = "https://github.com/Ahmednull/L2CS-Net"
34
-
35
- [build-system]
36
- requires = ["setuptools", "wheel"]
37
- build-backend = "setuptools.build_meta"
38
-
39
- # https://setuptools.pypa.io/en/stable/userguide/datafiles.html
40
- [tool.setuptools]
41
- include-package-data = true
42
-
43
- [tool.setuptools.packages.find]
44
- where = ["."]