Abdelrahman Almatrooshi committed on
Commit · bb2a2db · 1 Parent(s): 6d9eb2d
Integrate L2CS-Net gaze estimation

- Add L2CS-Net in-tree (models/L2CS-Net/) with Gaze360 weights via Git LFS
- L2CSPipeline: ResNet50 gaze + MediaPipe head pose, roll de-rotation, cosine scoring
- 9-point polynomial gaze calibration with bias correction and IQR outlier filtering
- Gaze-eye fusion: calibrated screen coords + EAR for focus detection
- L2CS Boost mode: runs gaze alongside any base model (35/65 weight, veto at 0.38)
- Calibration UI: fullscreen overlay, auto-advance, progress ring
- Frontend: GAZE toggle, Calibrate button, gaze pointer dot on canvas
- Bumped capture resolution to 640x480 @ JPEG 0.75
- Dockerfile: added git, CPU-only torch for HF Space deployment
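
The 35/65 weighting and the 0.38 veto mentioned above can be sketched as a standalone function. The constants and the 0.8 veto scaling / 0.52 decision threshold come from this commit's `_process_frame_with_l2cs_boost` in main.py; the function name here is just for illustration.

```python
BASE_W = 0.35   # base model weight
L2CS_W = 0.65   # L2CS gaze weight
VETO = 0.38     # gaze score below this forces not-focused

def fuse_scores(base_score: float, l2cs_score: float):
    """Return (fused_score, is_focused) for one frame."""
    if l2cs_score < VETO:
        # Veto: gaze clearly off-screen overrides the base model.
        return l2cs_score * 0.8, False
    fused = BASE_W * base_score + L2CS_W * l2cs_score
    return fused, fused >= 0.52
```

Because the gaze term carries 65% of the weight, even a confident base model cannot keep a frame "focused" once the gaze score drops below the veto line.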
- Dockerfile +8 -1
- README.md +46 -7
- checkpoints/L2CSNet_gaze360.pkl +3 -0
- download_l2cs_weights.py +37 -0
- main.py +338 -37
- models/L2CS-Net/.gitignore +140 -0
- models/L2CS-Net/LICENSE +21 -0
- models/L2CS-Net/README.md +148 -0
- models/L2CS-Net/demo.py +87 -0
- models/L2CS-Net/l2cs/__init__.py +21 -0
- models/L2CS-Net/l2cs/datasets.py +157 -0
- models/L2CS-Net/l2cs/model.py +73 -0
- models/L2CS-Net/l2cs/pipeline.py +133 -0
- models/L2CS-Net/l2cs/results.py +11 -0
- models/L2CS-Net/l2cs/utils.py +145 -0
- models/L2CS-Net/l2cs/vis.py +64 -0
- models/L2CS-Net/leave_one_out_eval.py +54 -0
- models/L2CS-Net/models/L2CSNet_gaze360.pkl +3 -0
- models/L2CS-Net/models/README.md +1 -0
- models/L2CS-Net/pyproject.toml +44 -0
- models/L2CS-Net/test.py +284 -0
- models/L2CS-Net/train.py +384 -0
- models/gaze_calibration.py +146 -0
- models/gaze_eye_fusion.py +66 -0
- requirements.txt +2 -0
- src/components/CalibrationOverlay.jsx +146 -0
- src/components/FocusPageLocal.jsx +140 -2
- src/utils/VideoManagerLocal.js +97 -3
- ui/pipeline.py +149 -5
Dockerfile
CHANGED

@@ -7,7 +7,14 @@ ENV PYTHONUNBUFFERED=1
 
 WORKDIR /app
 
-RUN apt-get update && apt-get install -y --no-install-recommends
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 \
+    ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev \
+    libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev \
+    build-essential nodejs npm git \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
 COPY requirements.txt ./
 RUN pip install --no-cache-dir -r requirements.txt
README.md
CHANGED

@@ -1,6 +1,6 @@
 # FocusGuard
 
-Webcam-based focus detection: MediaPipe face mesh
+Webcam-based focus detection: MediaPipe face mesh -> 17 features (EAR, gaze, head pose, PERCLOS, etc.) -> MLP or XGBoost for focused/unfocused. React + FastAPI app with WebSocket video.
 
 ## Project layout
 
@@ -9,10 +9,18 @@ Webcam-based focus detection: MediaPipe face mesh → 17 features (EAR, gaze, he
 ├── data_preparation/        loaders, split, scale
 ├── notebooks/               MLP/XGB training + LOPO
 ├── models/                  face_mesh, head_pose, eye_scorer, train scripts
+│   ├── gaze_calibration.py  9-point polynomial gaze calibration
+│   ├── gaze_eye_fusion.py   Fuses calibrated gaze with eye openness
+│   └── L2CS-Net/            In-tree L2CS-Net repo with Gaze360 weights
 ├── checkpoints/             mlp_best.pt, xgboost_*_best.json, scalers
 ├── evaluation/              logs, plots, justify_thresholds
 ├── ui/                      pipeline.py, live_demo.py
 ├── src/                     React frontend
+│   ├── components/
+│   │   ├── FocusPageLocal.jsx      Main focus page (camera, controls, model selector)
+│   │   └── CalibrationOverlay.jsx  Fullscreen calibration UI
+│   └── utils/
+│       └── VideoManagerLocal.js    WebSocket client, frame capture, canvas rendering
 ├── static/                  built frontend (after npm run build)
 ├── main.py, app.py          FastAPI backend
 ├── requirements.txt
@@ -70,19 +78,50 @@ python -m models.xgboost.train
 
 9 participants, 144,793 samples, 10 features, binary labels. Collect with `python -m models.collect_features --name <name>`. Data lives in `data/collected_<name>/`.
 
+## Models
+
+| Model | What it uses | Best for |
+|-------|-------------|----------|
+| **Geometric** | Head pose angles + eye aspect ratio (EAR) | Fast, no ML needed |
+| **XGBoost** | Trained classifier on head/eye features (600 trees, depth 8) | Balanced accuracy/speed |
+| **MLP** | Neural network on same features (64->32) | Higher accuracy |
+| **Hybrid** | Weighted MLP + Geometric ensemble | Best head-pose accuracy |
+| **L2CS** | Deep gaze estimation (ResNet50, Gaze360 weights) | Detects eye-only gaze shifts |
+
 ## Model numbers (15% test split)
 
 | Model | Accuracy | F1 | ROC-AUC |
 |-------|----------|-----|---------|
 | XGBoost (600 trees, depth 8) | 95.87% | 0.959 | 0.991 |
-| MLP (64
+| MLP (64->32) | 92.92% | 0.929 | 0.971 |
+
+## L2CS Gaze Tracking
+
+L2CS-Net predicts where your eyes are looking, not just where your head is pointed. This catches the scenario where your head faces the screen but your eyes wander.
+
+### Standalone mode
+Select **L2CS** as the model - it handles everything.
+
+### Boost mode
+Select any other model, then click the **GAZE** toggle. L2CS runs alongside the base model:
+- Base model handles head pose and eye openness (35% weight)
+- L2CS handles gaze direction (65% weight)
+- If L2CS detects gaze is clearly off-screen, it **vetoes** the base model regardless of score
+
+### Calibration
+After enabling L2CS or Gaze Boost, click **Calibrate** while a session is running:
+1. A fullscreen overlay shows 9 target dots (3x3 grid)
+2. Look at each dot as the progress ring fills
+3. The first dot (centre) sets your baseline gaze offset
+4. After all 9 points, a polynomial model maps your gaze angles to screen coordinates
+5. A cyan tracking dot appears on the video showing where you're looking
 
 ## Pipeline
 
 1. Face mesh (MediaPipe 478 pts)
-2. Head pose
-3. Eye scorer
-4. Temporal
-5. 10-d vector
+2. Head pose -> yaw, pitch, roll, scores, gaze offset
+3. Eye scorer -> EAR, gaze ratio, MAR
+4. Temporal -> PERCLOS, blink rate, yawn
+5. 10-d vector -> MLP or XGBoost -> focused / unfocused
 
-**Stack:** FastAPI, aiosqlite, React/Vite, PyTorch, XGBoost, MediaPipe, OpenCV.
+**Stack:** FastAPI, aiosqlite, React/Vite, PyTorch, XGBoost, MediaPipe, OpenCV, L2CS-Net.
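
A minimal sketch of what the 9-point calibration step computes, assuming quadratic polynomial features fit per axis by least squares. The real `models/gaze_calibration.py` also does bias correction and IQR outlier filtering; the function names here are illustrative, not the repo's API.

```python
import numpy as np

def _feats(yaw, pitch):
    # Quadratic design row(s) in (yaw, pitch): [1, y, p, y*p, y^2, p^2]
    y = np.asarray(yaw, dtype=float)
    p = np.asarray(pitch, dtype=float)
    return np.stack([np.ones_like(y), y, p, y * p, y ** 2, p ** 2], axis=-1)

def fit_gaze_map(yaws, pitches, xs, ys):
    """Least-squares fit of (yaw, pitch) -> normalized screen (x, y)."""
    A = _feats(yaws, pitches)                                   # (n_points, 6)
    wx, *_ = np.linalg.lstsq(A, np.asarray(xs, dtype=float), rcond=None)
    wy, *_ = np.linalg.lstsq(A, np.asarray(ys, dtype=float), rcond=None)
    return wx, wy

def predict_screen(wx, wy, yaw, pitch):
    f = _feats(yaw, pitch)
    return float(f @ wx), float(f @ wy)
```

Nine samples against six coefficients per axis is enough for a unique fit on a 3x3 grid, which is why the overlay uses exactly nine targets.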
checkpoints/L2CSNet_gaze360.pkl
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
+size 95849977
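
The checkpoint above is committed as a Git LFS pointer rather than the ~95 MB binary. A pointer is just `key value` lines; this small parser (illustrative, not part of the repo) extracts the oid and size for a sanity check before downloading the real weights.

```python
def parse_lfs_pointer(text: str) -> dict:
    # Each pointer line is "key value"; split on the first space only.
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }
```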
download_l2cs_weights.py
ADDED

@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+# Downloads L2CS-Net Gaze360 weights into checkpoints/
+
+import os
+import sys
+
+CHECKPOINTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
+DEST = os.path.join(CHECKPOINTS_DIR, "L2CSNet_gaze360.pkl")
+GDRIVE_ID = "1dL2Jokb19_SBSHAhKHOxJsmYs5-GoyLo"
+
+
+def main():
+    if os.path.isfile(DEST):
+        print(f"[OK] Weights already at {DEST}")
+        return
+
+    try:
+        import gdown
+    except ImportError:
+        print("gdown not installed. Run: pip install gdown")
+        sys.exit(1)
+
+    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
+    print(f"Downloading L2CS-Net weights to {DEST} ...")
+    gdown.download(f"https://drive.google.com/uc?id={GDRIVE_ID}", DEST, quiet=False)
+
+    if os.path.isfile(DEST):
+        print(f"[OK] Downloaded ({os.path.getsize(DEST) / 1024 / 1024:.1f} MB)")
+    else:
+        print("[ERR] Download failed. Manual download:")
+        print("  https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd")
+        print(f"  Place L2CSNet_gaze360.pkl in {CHECKPOINTS_DIR}/")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
main.py
CHANGED

@@ -25,7 +25,10 @@ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
 from av import VideoFrame
 
 from mediapipe.tasks.python.vision import FaceLandmarksConnections
-from ui.pipeline import
+from ui.pipeline import (
+    FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline,
+    L2CSPipeline, is_l2cs_weights_available,
+)
 from models.face_mesh import FaceMeshDetector
 
 # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
@@ -212,17 +215,7 @@ app.add_middleware(
 db_path = "focus_guard.db"
 pcs = set()
 _cached_model_name = "mlp"
-pipelines = {
-    "geometric": None,
-    "mlp": None,
-    "hybrid": None,
-    "xgboost": None,
-}
-_inference_executor = concurrent.futures.ThreadPoolExecutor(
-    max_workers=4,
-    thread_name_prefix="inference",
-)
-_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
+_l2cs_boost_enabled = False
 
 async def _wait_for_ice_gathering(pc: RTCPeerConnection):
     if pc.iceGatheringState == "complete":
@@ -302,6 +295,7 @@ class SettingsUpdate(BaseModel):
     notification_threshold: Optional[int] = None
     frame_rate: Optional[int] = None
     model_name: Optional[str] = None
+    l2cs_boost: Optional[bool] = None
 
 class VideoTransformTrack(VideoStreamTrack):
     def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
@@ -329,6 +323,8 @@ class VideoTransformTrack(VideoStreamTrack):
             self.last_inference_time = now
 
             model_name = _cached_model_name
+            if model_name == "l2cs" and pipelines.get("l2cs") is None:
+                _ensure_l2cs()
             if model_name not in pipelines or pipelines.get(model_name) is None:
                 model_name = 'mlp'
             active_pipeline = pipelines.get(model_name)
@@ -513,10 +509,56 @@ class _EventBuffer:
         except Exception as e:
             print(f"[DB] Flush error: {e}")
 
-def _process_frame_safe(pipeline, frame, model_name):
+# ================ STARTUP/SHUTDOWN ================
+
+pipelines = {
+    "geometric": None,
+    "mlp": None,
+    "hybrid": None,
+    "xgboost": None,
+    "l2cs": None,
+}
+
+# Thread pool for CPU-bound inference so the event loop stays responsive.
+_inference_executor = concurrent.futures.ThreadPoolExecutor(
+    max_workers=4,
+    thread_name_prefix="inference",
+)
+# One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
+# multiple frames are processed in parallel by the thread pool.
+_pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")}
+
+_l2cs_load_lock = threading.Lock()
+_l2cs_error: str | None = None
+
+
+def _ensure_l2cs():
+    # lazy-load L2CS on first use, double-checked locking
+    global _l2cs_error
+    if pipelines["l2cs"] is not None:
+        return True
+    with _l2cs_load_lock:
+        if pipelines["l2cs"] is not None:
+            return True
+        if not is_l2cs_weights_available():
+            _l2cs_error = "Weights not found"
+            return False
+        try:
+            pipelines["l2cs"] = L2CSPipeline()
+            _l2cs_error = None
+            print("[OK] L2CSPipeline lazy-loaded")
+            return True
+        except Exception as e:
+            _l2cs_error = str(e)
+            print(f"[ERR] L2CS lazy-load failed: {e}")
+            return False
+
+
+def _process_frame_safe(pipeline, frame, model_name):
     with _pipeline_locks[model_name]:
         return pipeline.process_frame(frame)
 
 def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
     if preferred and preferred in pipelines and pipelines.get(preferred) is not None:
         return preferred
@@ -525,6 +567,96 @@ def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
         return name
     return None
 
+
+_BOOST_BASE_W = 0.35
+_BOOST_L2CS_W = 0.65
+_BOOST_VETO = 0.38  # L2CS below this -> forced not-focused
+
+
+def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
+    # run base model
+    with _pipeline_locks[base_model_name]:
+        base_out = base_pipeline.process_frame(frame)
+
+    l2cs_pipe = pipelines.get("l2cs")
+    if l2cs_pipe is None:
+        base_out["boost_active"] = False
+        return base_out
+
+    # run L2CS
+    with _pipeline_locks["l2cs"]:
+        l2cs_out = l2cs_pipe.process_frame(frame)
+
+    base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
+    l2cs_score = l2cs_out.get("raw_score", 0.0)
+
+    # veto: gaze clearly off-screen overrides base model
+    if l2cs_score < _BOOST_VETO:
+        fused_score = l2cs_score * 0.8
+        is_focused = False
+    else:
+        fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
+        is_focused = fused_score >= 0.52
+
+    base_out["raw_score"] = fused_score
+    base_out["is_focused"] = is_focused
+    base_out["boost_active"] = True
+    base_out["base_score"] = round(base_score, 3)
+    base_out["l2cs_score"] = round(l2cs_score, 3)
+
+    if l2cs_out.get("gaze_yaw") is not None:
+        base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
+        base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
+
+    return base_out
+
+
+@app.on_event("startup")
+async def startup_event():
+    global pipelines, _cached_model_name
+    print(" Starting Focus Guard API...")
+    await init_database()
+    # Load cached model name from DB
+    async with aiosqlite.connect(db_path) as db:
+        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
+        row = await cursor.fetchone()
+        if row:
+            _cached_model_name = row[0]
+    print("[OK] Database initialized")
+
+    try:
+        pipelines["geometric"] = FaceMeshPipeline()
+        print("[OK] FaceMeshPipeline (geometric) loaded")
+    except Exception as e:
+        print(f"[WARN] FaceMeshPipeline unavailable: {e}")
+
+    try:
+        pipelines["mlp"] = MLPPipeline()
+        print("[OK] MLPPipeline loaded")
+    except Exception as e:
+        print(f"[ERR] Failed to load MLPPipeline: {e}")
+
+    try:
+        pipelines["hybrid"] = HybridFocusPipeline()
+        print("[OK] HybridFocusPipeline loaded")
+    except Exception as e:
+        print(f"[WARN] HybridFocusPipeline unavailable: {e}")
+
+    try:
+        pipelines["xgboost"] = XGBoostPipeline()
+        print("[OK] XGBoostPipeline loaded")
+    except Exception as e:
+        print(f"[ERR] Failed to load XGBoostPipeline: {e}")
+
+    if is_l2cs_weights_available():
+        print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
+    else:
+        print("[WARN] L2CS weights not found — l2cs model unavailable")
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    _inference_executor.shutdown(wait=False)
+    print(" Shutting down Focus Guard API...")
+
 # ================ WEBRTC SIGNALING ================
 
 @app.post("/api/webrtc/offer")
@@ -590,14 +722,19 @@ async def webrtc_offer(offer: dict):
 
 @app.websocket("/ws/video")
 async def websocket_endpoint(websocket: WebSocket):
+    from models.gaze_calibration import GazeCalibration
+    from models.gaze_eye_fusion import GazeEyeFusion
+
     await websocket.accept()
     session_id = None
    frame_count = 0
     running = True
     event_buffer = _EventBuffer(flush_interval=2.0)
 
-    #
-
+    # Calibration state (per-connection)
+    _cal: dict = {"cal": None, "collecting": False, "fusion": None}
+
+    # Latest frame slot — only the most recent frame is kept, older ones are dropped.
     _slot = {"frame": None}
     _frame_ready = asyncio.Event()
 
@@ -628,7 +765,6 @@ async def websocket_endpoint(websocket: WebSocket):
                 data = json.loads(text)
 
                 if data["type"] == "frame":
-                    # Legacy base64 path (fallback)
                     _slot["frame"] = base64.b64decode(data["image"])
                     _frame_ready.set()
 
@@ -647,6 +783,47 @@ async def websocket_endpoint(websocket: WebSocket):
                     if summary:
                         await websocket.send_json({"type": "session_ended", "summary": summary})
                     session_id = None
+
+                # ---- Calibration commands ----
+                elif data["type"] == "calibration_start":
+                    loop = asyncio.get_event_loop()
+                    await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                    _cal["cal"] = GazeCalibration()
+                    _cal["collecting"] = True
+                    _cal["fusion"] = None
+                    cal = _cal["cal"]
+                    await websocket.send_json({
+                        "type": "calibration_started",
+                        "num_points": cal.num_points,
+                        "target": list(cal.current_target),
+                        "index": cal.current_index,
+                    })
+
+                elif data["type"] == "calibration_next":
+                    cal = _cal.get("cal")
+                    if cal is not None:
+                        more = cal.advance()
+                        if more:
+                            await websocket.send_json({
+                                "type": "calibration_point",
+                                "target": list(cal.current_target),
+                                "index": cal.current_index,
+                            })
+                        else:
+                            _cal["collecting"] = False
+                            ok = cal.fit()
+                            if ok:
+                                _cal["fusion"] = GazeEyeFusion(cal)
+                                await websocket.send_json({"type": "calibration_done", "success": True})
+                            else:
+                                await websocket.send_json({"type": "calibration_done", "success": False, "error": "Not enough samples"})
+
+                elif data["type"] == "calibration_cancel":
+                    _cal["cal"] = None
+                    _cal["collecting"] = False
+                    _cal["fusion"] = None
+                    await websocket.send_json({"type": "calibration_cancelled"})
+
         except WebSocketDisconnect:
             running = False
             _frame_ready.set()
@@ -665,7 +842,6 @@ async def websocket_endpoint(websocket: WebSocket):
             if not running:
                 return
 
-            # Grab latest frame and clear slot
             raw = _slot["frame"]
             _slot["frame"] = None
             if raw is None:
@@ -678,36 +854,87 @@ async def websocket_endpoint(websocket: WebSocket):
                 continue
             frame = cv2.resize(frame, (640, 480))
 
-
-
+            # During calibration collection, always use L2CS
+            collecting = _cal.get("collecting", False)
+            if collecting:
+                if pipelines.get("l2cs") is None:
+                    await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+                use_model = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name
+            else:
+                use_model = _cached_model_name
+
+            model_name = use_model
+            if model_name == "l2cs" and pipelines.get("l2cs") is None:
+                await loop.run_in_executor(_inference_executor, _ensure_l2cs)
+            if model_name not in pipelines or pipelines.get(model_name) is None:
+                model_name = "mlp"
+            active_pipeline = pipelines.get(model_name)
+
+            # L2CS boost: run L2CS alongside base model
+            use_boost = (
+                _l2cs_boost_enabled
+                and model_name != "l2cs"
+                and pipelines.get("l2cs") is not None
+                and not collecting
+            )
 
             landmarks_list = None
+            out = None
             if active_pipeline is not None:
-                out = await loop.run_in_executor(
-                    _inference_executor,
-                    _process_frame_safe,
-                    active_pipeline,
-                    frame,
-                    model_name,
-                )
+                if use_boost:
+                    out = await loop.run_in_executor(
+                        _inference_executor,
+                        _process_frame_with_l2cs_boost,
+                        active_pipeline,
+                        frame,
+                        model_name,
+                    )
+                else:
+                    out = await loop.run_in_executor(
+                        _inference_executor,
+                        _process_frame_safe,
+                        active_pipeline,
+                        frame,
+                        model_name,
+                    )
                 is_focused = out["is_focused"]
                 confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
 
                 lm = out.get("landmarks")
                 if lm is not None:
-                    # Send all 478 landmarks as flat array for tessellation drawing
                     landmarks_list = [
                         [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                        for i in range(lm.shape[0])
                    ]
 
+                # Calibration sample collection (L2CS gaze angles)
+                if collecting and _cal.get("cal") is not None:
+                    pipe_yaw = out.get("gaze_yaw")
+                    pipe_pitch = out.get("gaze_pitch")
+                    if pipe_yaw is not None and pipe_pitch is not None:
+                        _cal["cal"].collect_sample(pipe_yaw, pipe_pitch)
+
+                # Gaze fusion (when L2CS active + calibration fitted)
+                fusion = _cal.get("fusion")
+                if (
+                    fusion is not None
+                    and model_name == "l2cs"
+                    and out.get("gaze_yaw") is not None
+                ):
+                    fuse = fusion.update(
+                        out["gaze_yaw"], out["gaze_pitch"], lm
+                    )
+                    is_focused = fuse["focused"]
+                    confidence = fuse["focus_score"]
+
                 if session_id:
-                    event_buffer.add(session_id, is_focused, confidence, {
+                    metadata = {
                         "s_face": out.get("s_face", 0.0),
                         "s_eye": out.get("s_eye", 0.0),
                         "mar": out.get("mar", 0.0),
                         "model": model_name,
-                    })
+                    }
+                    event_buffer.add(session_id, is_focused, confidence, metadata)
             else:
                 is_focused = False
                 confidence = 0.0
@@ -721,8 +948,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 "fc": frame_count,
                 "frame_count": frame_count,
             }
-            if
-            # Send detailed metrics for HUD
+            if out is not None:
                 if out.get("yaw") is not None:
                     resp["yaw"] = round(out["yaw"], 1)
                     resp["pitch"] = round(out["pitch"], 1)
@@ -731,6 +957,24 @@ async def websocket_endpoint(websocket: WebSocket):
                     resp["mar"] = round(out["mar"], 3)
                 resp["sf"] = round(out.get("s_face", 0), 3)
                 resp["se"] = round(out.get("s_eye", 0), 3)
+
+                # Gaze fusion fields (L2CS standalone or boost mode)
+                fusion = _cal.get("fusion")
+                has_gaze = out.get("gaze_yaw") is not None
+                if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost):
+                    fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], out.get("landmarks"))
+                    resp["gaze_x"] = fuse["gaze_x"]
             if landmarks_list is not None:
                 resp["lm"] = landmarks_list
             await websocket.send_json(resp)
@@ -863,8 +1107,9 @@ async def get_settings():
         db.row_factory = aiosqlite.Row
         cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
         row = await cursor.fetchone()
-
-
 
 @app.put("/api/settings")
 async def update_settings(settings: SettingsUpdate):
@@ -889,12 +1134,28 @@ async def update_settings(settings: SettingsUpdate):
         if settings.frame_rate is not None:
             updates.append("frame_rate = ?")
             params.append(max(5, min(60, settings.frame_rate)))
-        if settings.model_name is not None and settings.model_name in pipelines
             updates.append("model_name = ?")
             params.append(settings.model_name)
             global _cached_model_name
             _cached_model_name = settings.model_name
 
         if updates:
             query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
             await db.execute(query, params)
@@ -946,15 +1207,55 @@ async def get_stats_summary():
 
 @app.get("/api/models")
 async def get_available_models():
-    """Return
-
     async with aiosqlite.connect(db_path) as db:
         cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
         row = await cursor.fetchone()
         current = row[0] if row else "mlp"
         if current not in available and available:
             current = available[0]
-
 
 @app.get("/api/mesh-topology")
 async def get_mesh_topology():
|
| 967 |
+
resp["gaze_y"] = fuse["gaze_y"]
|
| 968 |
+
resp["on_screen"] = fuse["on_screen"]
|
| 969 |
+
if model_name == "l2cs":
|
| 970 |
+
resp["focused"] = fuse["focused"]
|
| 971 |
+
resp["confidence"] = round(fuse["focus_score"], 3)
|
| 972 |
+
|
| 973 |
+
if out.get("boost_active"):
|
| 974 |
+
resp["boost"] = True
|
| 975 |
+
resp["base_score"] = out.get("base_score", 0)
|
| 976 |
+
resp["l2cs_score"] = out.get("l2cs_score", 0)
|
| 977 |
+
|
| 978 |
if landmarks_list is not None:
|
| 979 |
resp["lm"] = landmarks_list
|
| 980 |
await websocket.send_json(resp)
|
|
|
|
| 1107 |
db.row_factory = aiosqlite.Row
|
| 1108 |
cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
|
| 1109 |
row = await cursor.fetchone()
|
| 1110 |
+
result = dict(row) if row else {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
|
| 1111 |
+
result['l2cs_boost'] = _l2cs_boost_enabled
|
| 1112 |
+
return result
|
| 1113 |
|
| 1114 |
@app.put("/api/settings")
|
| 1115 |
async def update_settings(settings: SettingsUpdate):
|
|
|
|
| 1134 |
if settings.frame_rate is not None:
|
| 1135 |
updates.append("frame_rate = ?")
|
| 1136 |
params.append(max(5, min(60, settings.frame_rate)))
|
| 1137 |
+
if settings.model_name is not None and settings.model_name in pipelines:
|
| 1138 |
+
if settings.model_name == "l2cs":
|
| 1139 |
+
loop = asyncio.get_event_loop()
|
| 1140 |
+
loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
|
| 1141 |
+
if not loaded:
|
| 1142 |
+
raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
|
| 1143 |
+
elif pipelines[settings.model_name] is None:
|
| 1144 |
+
raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
|
| 1145 |
updates.append("model_name = ?")
|
| 1146 |
params.append(settings.model_name)
|
| 1147 |
global _cached_model_name
|
| 1148 |
_cached_model_name = settings.model_name
|
| 1149 |
|
| 1150 |
+
if settings.l2cs_boost is not None:
|
| 1151 |
+
global _l2cs_boost_enabled
|
| 1152 |
+
if settings.l2cs_boost:
|
| 1153 |
+
loop = asyncio.get_event_loop()
|
| 1154 |
+
loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
|
| 1155 |
+
if not loaded:
|
| 1156 |
+
raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")
|
| 1157 |
+
_l2cs_boost_enabled = settings.l2cs_boost
|
| 1158 |
+
|
| 1159 |
if updates:
|
| 1160 |
query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
|
| 1161 |
await db.execute(query, params)
|
|
|
|
| 1207 |
|
| 1208 |
@app.get("/api/models")
|
| 1209 |
async def get_available_models():
|
| 1210 |
+
"""Return model names, statuses, and which is currently active."""
|
| 1211 |
+
statuses = {}
|
| 1212 |
+
errors = {}
|
| 1213 |
+
available = []
|
| 1214 |
+
for name, p in pipelines.items():
|
| 1215 |
+
if name == "l2cs":
|
| 1216 |
+
if p is not None:
|
| 1217 |
+
statuses[name] = "ready"
|
| 1218 |
+
available.append(name)
|
| 1219 |
+
elif is_l2cs_weights_available():
|
| 1220 |
+
statuses[name] = "lazy"
|
| 1221 |
+
available.append(name)
|
| 1222 |
+
elif _l2cs_error:
|
| 1223 |
+
statuses[name] = "error"
|
| 1224 |
+
errors[name] = _l2cs_error
|
| 1225 |
+
else:
|
| 1226 |
+
statuses[name] = "unavailable"
|
| 1227 |
+
elif p is not None:
|
| 1228 |
+
statuses[name] = "ready"
|
| 1229 |
+
available.append(name)
|
| 1230 |
+
else:
|
| 1231 |
+
statuses[name] = "unavailable"
|
| 1232 |
async with aiosqlite.connect(db_path) as db:
|
| 1233 |
cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
|
| 1234 |
row = await cursor.fetchone()
|
| 1235 |
current = row[0] if row else "mlp"
|
| 1236 |
if current not in available and available:
|
| 1237 |
current = available[0]
|
| 1238 |
+
l2cs_boost_available = (
|
| 1239 |
+
statuses.get("l2cs") in ("ready", "lazy") and current != "l2cs"
|
| 1240 |
+
)
|
| 1241 |
+
return {
|
| 1242 |
+
"available": available,
|
| 1243 |
+
"current": current,
|
| 1244 |
+
"statuses": statuses,
|
| 1245 |
+
"errors": errors,
|
| 1246 |
+
"l2cs_boost": _l2cs_boost_enabled,
|
| 1247 |
+
"l2cs_boost_available": l2cs_boost_available,
|
| 1248 |
+
}
|
| 1249 |
+
|
| 1250 |
+
@app.get("/api/l2cs/status")
|
| 1251 |
+
async def l2cs_status():
|
| 1252 |
+
"""L2CS-specific status: weights available, loaded, and calibration info."""
|
| 1253 |
+
loaded = pipelines.get("l2cs") is not None
|
| 1254 |
+
return {
|
| 1255 |
+
"weights_available": is_l2cs_weights_available(),
|
| 1256 |
+
"loaded": loaded,
|
| 1257 |
+
"error": _l2cs_error,
|
| 1258 |
+
}
|
| 1259 |
|
| 1260 |
@app.get("/api/mesh-topology")
|
| 1261 |
async def get_mesh_topology():
|
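The commit message describes boost mode as blending the base model's score with the L2CS gaze score at 35/65 weights, with a hard veto when the gaze score falls below 0.38. A minimal standalone sketch of that combining rule is below. The function name `fuse_boost_scores` and the 0.5 focus cutoff are illustrative assumptions; only the 35/65 weights and the 0.38 veto come from the commit message, and this is not the actual `_process_frame_with_l2cs_boost` implementation.

```python
def fuse_boost_scores(base_score: float, l2cs_score: float,
                      base_weight: float = 0.35, l2cs_weight: float = 0.65,
                      veto_threshold: float = 0.38,
                      focus_threshold: float = 0.5) -> dict:
    """Sketch of boost-mode fusion: weighted blend plus a gaze veto.

    Weights (0.35/0.65) and the 0.38 veto are from the commit message;
    the 0.5 focus cutoff is a hypothetical default.
    """
    combined = base_weight * base_score + l2cs_weight * l2cs_score
    # Veto: if the gaze score says clearly off-screen, never report focused,
    # regardless of how confident the base model is.
    vetoed = l2cs_score < veto_threshold
    return {
        "is_focused": (combined >= focus_threshold) and not vetoed,
        "combined_score": combined,
        "boost_active": True,
        "base_score": base_score,
        "l2cs_score": l2cs_score,
    }
```

The veto is what lets a gaze model overrule a landmark-based model that cannot tell where the eyes are pointing, which matches the "veto at 0.38" wording above.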
models/L2CS-Net/.gitignore
ADDED
@@ -0,0 +1,140 @@

# Ignore the test data - sensitive
datasets/
evaluation/
output/

# Ignore debugging configurations
/.vscode

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Ignore other files
my.secrets
models/L2CS-Net/LICENSE
ADDED
@@ -0,0 +1,21 @@

MIT License

Copyright (c) 2022 Ahmed Abdelrahman

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
models/L2CS-Net/README.md
ADDED
@@ -0,0 +1,148 @@

<p align="center">
  <img src="https://github.com/Ahmednull/Storage/blob/main/gaze.gif" alt="animated" />
</p>

___

# L2CS-Net

The official PyTorch implementation of L2CS-Net for gaze estimation and tracking.

## Installation
<img src="https://img.shields.io/badge/python%20-%2314354C.svg?&style=for-the-badge&logo=python&logoColor=white"/> <img src="https://img.shields.io/badge/PyTorch%20-%23EE4C2C.svg?&style=for-the-badge&logo=PyTorch&logoColor=white" />

Install package with the following:

```
pip install git+https://github.com/Ahmednull/L2CS-Net.git@main
```

Or, you can git clone the repo and install with the following:

```
pip install [-e] .
```

Now you should be able to import the package with the following command:

```
$ python
>>> import l2cs
```

## Usage

Detect face and predict gaze from webcam

```python
from l2cs import Pipeline, render
import cv2

gaze_pipeline = Pipeline(
    weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
    arch='ResNet50',
    device=torch.device('cpu') # or 'gpu'
)

cap = cv2.VideoCapture(cam)
_, frame = cap.read()

# Process frame and visualize
results = gaze_pipeline.step(frame)
frame = render(frame, results)
```

## Demo
* Download the pre-trained models from [here](https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd?usp=sharing) and Store it to *models/*.
* Run:
```
python demo.py \
 --snapshot models/L2CSNet_gaze360.pkl \
 --gpu 0 \
 --cam 0 \
```
This means the demo will run using *L2CSNet_gaze360.pkl* pretrained model

## Community Contributions

- [Gaze Detection and Eye Tracking: A How-To Guide](https://blog.roboflow.com/gaze-direction-position/): Use L2CS-Net through a HTTP interface with the open source Roboflow Inference project.

## MPIIGaze
We provide the code for train and test MPIIGaze dataset with leave-one-person-out evaluation.

### Prepare datasets
* Download **MPIIFaceGaze dataset** from [here](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation).
* Apply data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
* Store the dataset to *datasets/MPIIFaceGaze*.

### Train
```
python train.py \
 --dataset mpiigaze \
 --snapshot output/snapshots \
 --gpu 0 \
 --num_epochs 50 \
 --batch_size 16 \
 --lr 0.00001 \
 --alpha 1 \

```
This means the code will perform leave-one-person-out training automatically and store the models to *output/snapshots*.

### Test
```
python test.py \
 --dataset mpiigaze \
 --snapshot output/snapshots/snapshot_folder \
 --evalpath evaluation/L2CS-mpiigaze \
 --gpu 0 \
```
This means the code will perform leave-one-person-out testing automatically and store the results to *evaluation/L2CS-mpiigaze*.

To get the average leave-one-person-out accuracy use:
```
python leave_one_out_eval.py \
 --evalpath evaluation/L2CS-mpiigaze \
 --respath evaluation/L2CS-mpiigaze \
```
This means the code will take the evaluation path and outputs the leave-one-out gaze accuracy to the *evaluation/L2CS-mpiigaze*.

## Gaze360
We provide the code for train and test Gaze360 dataset with train-val-test evaluation.

### Prepare datasets
* Download **Gaze360 dataset** from [here](http://gaze360.csail.mit.edu/download.php).

* Apply data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).

* Store the dataset to *datasets/Gaze360*.


### Train
```
python train.py \
 --dataset gaze360 \
 --snapshot output/snapshots \
 --gpu 0 \
 --num_epochs 50 \
 --batch_size 16 \
 --lr 0.00001 \
 --alpha 1 \

```
This means the code will perform training and store the models to *output/snapshots*.

### Test
```
python test.py \
 --dataset gaze360 \
 --snapshot output/snapshots/snapshot_folder \
 --evalpath evaluation/L2CS-gaze360 \
 --gpu 0 \
```
This means the code will perform testing on snapshot_folder and store the results to *evaluation/L2CS-gaze360*.
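The `Pipeline.step` usage above returns gaze as pitch/yaw angles. A common next step is converting those angles into a unit 3D gaze direction; the package exports a `gazeto3d` helper from `l2cs.utils` for this purpose. The sketch below is an illustrative reimplementation under the usual convention that the camera looks along +z and inputs are in radians; it is not guaranteed to match the package's exact sign conventions.

```python
import numpy as np

def gaze_to_3d(pitch: float, yaw: float) -> np.ndarray:
    """Convert (pitch, yaw) in radians into a unit 3D gaze direction.

    Assumes the common convention where looking straight at the camera
    yields (0, 0, -1). Illustrative only; the repo's `gazeto3d` helper
    is the canonical version.
    """
    x = -np.cos(pitch) * np.sin(yaw)   # horizontal component
    y = -np.sin(pitch)                 # vertical component
    z = -np.cos(pitch) * np.cos(yaw)   # depth component (toward camera)
    return np.array([x, y, z])
```

The resulting vector is always unit length, since x² + y² + z² = cos²(pitch) + sin²(pitch) = 1, which makes it directly usable for angular-error metrics such as the `angular` utility exported alongside it.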
models/L2CS-Net/demo.py
ADDED
@@ -0,0 +1,87 @@

import argparse
import pathlib
import numpy as np
import cv2
import time

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
import torch.backends.cudnn as cudnn
import torchvision

from PIL import Image
from PIL import Image, ImageOps

from face_detection import RetinaFace

from l2cs import select_device, draw_gaze, getArch, Pipeline, render

CWD = pathlib.Path.cwd()

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(
        description='Gaze evalution using model pretrained with L2CS-Net on Gaze360.')
    parser.add_argument(
        '--device', dest='device', help='Device to run model: cpu or gpu:0',
        default="cpu", type=str)
    parser.add_argument(
        '--snapshot', dest='snapshot', help='Path of model snapshot.',
        default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl', type=str)
    parser.add_argument(
        '--cam', dest='cam_id', help='Camera device id to use [0]',
        default=0, type=int)
    parser.add_argument(
        '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152',
        default='ResNet50', type=str)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    arch = args.arch
    cam = args.cam_id
    # snapshot_path = args.snapshot

    gaze_pipeline = Pipeline(
        weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
        arch='ResNet50',
        device=select_device(args.device, batch_size=1)
    )

    cap = cv2.VideoCapture(cam)

    # Check if the webcam is opened correctly
    if not cap.isOpened():
        raise IOError("Cannot open webcam")

    with torch.no_grad():
        while True:

            # Get frame
            success, frame = cap.read()
            start_fps = time.time()

            if not success:
                print("Failed to obtain frame")
                time.sleep(0.1)

            # Process frame
            results = gaze_pipeline.step(frame)

            # Visualize output
            frame = render(frame, results)

            myFPS = 1.0 / (time.time() - start_fps)
            cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1, cv2.LINE_AA)

            cv2.imshow("Demo", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            success, frame = cap.read()
models/L2CS-Net/l2cs/__init__.py
ADDED
@@ -0,0 +1,21 @@

from .utils import select_device, natural_keys, gazeto3d, angular, getArch
from .vis import draw_gaze, render
from .model import L2CS
from .pipeline import Pipeline
from .datasets import Gaze360, Mpiigaze

__all__ = [
    # Classes
    'L2CS',
    'Pipeline',
    'Gaze360',
    'Mpiigaze',
    # Utils
    'render',
    'select_device',
    'draw_gaze',
    'natural_keys',
    'gazeto3d',
    'angular',
    'getArch'
]
models/L2CS-Net/l2cs/datasets.py
ADDED
@@ -0,0 +1,157 @@

import os
import numpy as np
import cv2


import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image, ImageFilter


class Gaze360(Dataset):
    def __init__(self, path, root, transform, angle, binwidth, train=True):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.angle = angle
        if train==False:
            angle=90
        self.binwidth=binwidth
        self.lines = []
        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    print("here")
                    line = f.readlines()
                    line.pop(0)
                    self.lines.extend(line)
        else:
            with open(path) as f:
                lines = f.readlines()
                lines.pop(0)
                self.orig_list_len = len(lines)
                for line in lines:
                    gaze2d = line.strip().split(" ")[5]
                    label = np.array(gaze2d.split(",")).astype("float")
                    if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
                        self.lines.append(line)

        print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        line = line.strip().split(" ")

        face = line[0]
        lefteye = line[1]
        righteye = line[2]
        name = line[3]
        gaze2d = line[5]
        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)

        pitch = label[0]* 180 / np.pi
        yaw = label[1]* 180 / np.pi

        img = Image.open(os.path.join(self.root, face))

        # fimg = cv2.imread(os.path.join(self.root, face))
        # fimg = cv2.resize(fimg, (448, 448))/255.0
        # fimg = fimg.transpose(2, 0, 1)
        # img=torch.from_numpy(fimg).type(torch.FloatTensor)

        if self.transform:
            img = self.transform(img)

        # Bin values
        bins = np.array(range(-1*self.angle, self.angle, self.binwidth))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])


        return img, labels, cont_labels, name

class Mpiigaze(Dataset):
    def __init__(self, pathorg, root, transform, train, angle, fold=0):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.lines = []
        path=pathorg.copy()
        if train==True:
            path.pop(fold)
        else:
            path=path[fold]
        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    lines = f.readlines()
                    lines.pop(0)
                    self.orig_list_len += len(lines)
                    for line in lines:
                        gaze2d = line.strip().split(" ")[7]
                        label = np.array(gaze2d.split(",")).astype("float")
                        if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
                            self.lines.append(line)
        else:
            with open(path) as f:
                lines = f.readlines()
                lines.pop(0)
                self.orig_list_len += len(lines)
                for line in lines:
                    gaze2d = line.strip().split(" ")[7]
                    label = np.array(gaze2d.split(",")).astype("float")
                    if abs((label[0]*180/np.pi)) <= 42 and abs((label[1]*180/np.pi)) <= 42:
                        self.lines.append(line)

        print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines),angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        line = line.strip().split(" ")

        name = line[3]
        gaze2d = line[7]
        head2d = line[8]
        lefteye = line[1]
        righteye = line[2]
        face = line[0]

        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)


        pitch = label[0]* 180 / np.pi
        yaw = label[1]* 180 / np.pi

        img = Image.open(os.path.join(self.root, face))

        # fimg = cv2.imread(os.path.join(self.root, face))
        # fimg = cv2.resize(fimg, (448, 448))/255.0
        # fimg = fimg.transpose(2, 0, 1)
        # img=torch.from_numpy(fimg).type(torch.FloatTensor)

        if self.transform:
            img = self.transform(img)

        # Bin values
        bins = np.array(range(-42, 42, 3))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])


        return img, labels, cont_labels, name
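Both loaders above turn the continuous pitch/yaw labels into classification targets with `np.digitize`. For the MPIIGaze branch the bin edges are 3° apart over [-42°, 42°), giving 28 classes per angle head. The snippet below reproduces that binning step in isolation so the mapping is easy to check:

```python
import numpy as np

# Same bin edges as Mpiigaze.__getitem__: 3-degree bins from -42 to +39.
bins = np.array(range(-42, 42, 3))        # 28 edges -> 28 classes after the -1 shift

# np.digitize returns the index of the first edge greater than the value,
# so subtracting 1 maps each angle (in degrees) to a 0-based bin index.
binned = np.digitize([-41.0, 0.0, 41.0], bins) - 1
```

An angle of -41° lands in bin 0, 0° in bin 14 (the middle of the range), and 41° in the last bin, 27; these indices are what the model's yaw/pitch classification heads are trained against.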
models/L2CS-Net/l2cs/model.py
ADDED
|
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import math
+import torch.nn.functional as F
+
+
+class L2CS(nn.Module):
+    def __init__(self, block, layers, num_bins):
+        self.inplanes = 64
+        super(L2CS, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins)
+        self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins)
+
+        # Vestigial layer from previous experiments
+        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+
+        # gaze
+        pre_yaw_gaze = self.fc_yaw_gaze(x)
+        pre_pitch_gaze = self.fc_pitch_gaze(x)
+        return pre_yaw_gaze, pre_pitch_gaze
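For the ResNet-50 variant used here (Bottleneck blocks, `expansion = 4`), the channel bookkeeping in `_make_layer` means the two gaze heads each see `512 * block.expansion = 2048` features. A tiny standalone check of that arithmetic (not using the repo's code):

```python
def stage_channels(planes_seq=(64, 128, 256, 512), expansion=4):
    # Mirrors _make_layer: after each stage, inplanes becomes planes * block.expansion
    return [p * expansion for p in planes_seq]

channels = stage_channels()
print(channels)      # → [256, 512, 1024, 2048]
print(channels[-1])  # 2048 features feed fc_yaw_gaze / fc_pitch_gaze
```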
models/L2CS-Net/l2cs/pipeline.py
ADDED
|
@@ -0,0 +1,133 @@
+import pathlib
+from typing import Union
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from dataclasses import dataclass
+from face_detection import RetinaFace
+
+from .utils import prep_input_numpy, getArch
+from .results import GazeResultContainer
+
+
+class Pipeline:
+
+    def __init__(
+        self,
+        weights: pathlib.Path,
+        arch: str,
+        device: str = 'cpu',
+        include_detector: bool = True,
+        confidence_threshold: float = 0.5
+    ):
+
+        # Save input parameters
+        self.weights = weights
+        self.include_detector = include_detector
+        self.device = device
+        self.confidence_threshold = confidence_threshold
+
+        # Create L2CS model
+        self.model = getArch(arch, 90)
+        self.model.load_state_dict(torch.load(self.weights, map_location=device))
+        self.model.to(self.device)
+        self.model.eval()
+
+        # Create RetinaFace if requested
+        if self.include_detector:
+
+            if device.type == 'cpu':
+                self.detector = RetinaFace()
+            else:
+                self.detector = RetinaFace(gpu_id=device.index)
+
+            self.softmax = nn.Softmax(dim=1)
+            self.idx_tensor = [idx for idx in range(90)]
+            self.idx_tensor = torch.FloatTensor(self.idx_tensor).to(self.device)
+
+    def step(self, frame: np.ndarray) -> GazeResultContainer:
+
+        # Creating containers
+        face_imgs = []
+        bboxes = []
+        landmarks = []
+        scores = []
+
+        if self.include_detector:
+            faces = self.detector(frame)
+
+            if faces is not None:
+                for box, landmark, score in faces:
+
+                    # Apply threshold
+                    if score < self.confidence_threshold:
+                        continue
+
+                    # Extract safe min and max of x,y
+                    x_min = int(box[0])
+                    if x_min < 0:
+                        x_min = 0
+                    y_min = int(box[1])
+                    if y_min < 0:
+                        y_min = 0
+                    x_max = int(box[2])
+                    y_max = int(box[3])
+
+                    # Crop image
+                    img = frame[y_min:y_max, x_min:x_max]
+                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                    img = cv2.resize(img, (224, 224))
+                    face_imgs.append(img)
+
+                    # Save data
+                    bboxes.append(box)
+                    landmarks.append(landmark)
+                    scores.append(score)
+
+                # Predict gaze
+                pitch, yaw = self.predict_gaze(np.stack(face_imgs))
+
+            else:
+                pitch = np.empty((0, 1))
+                yaw = np.empty((0, 1))
+
+        else:
+            pitch, yaw = self.predict_gaze(frame)
+
+        # Save data
+        results = GazeResultContainer(
+            pitch=pitch,
+            yaw=yaw,
+            bboxes=np.stack(bboxes),
+            landmarks=np.stack(landmarks),
+            scores=np.stack(scores)
+        )
+
+        return results
+
+    def predict_gaze(self, frame: Union[np.ndarray, torch.Tensor]):
+
+        # Prepare input
+        if isinstance(frame, np.ndarray):
+            img = prep_input_numpy(frame, self.device)
+        elif isinstance(frame, torch.Tensor):
+            img = frame
+        else:
+            raise RuntimeError("Invalid dtype for input")
+
+        # Predict
+        gaze_pitch, gaze_yaw = self.model(img)
+        pitch_predicted = self.softmax(gaze_pitch)
+        yaw_predicted = self.softmax(gaze_yaw)
+
+        # Get continuous predictions in degrees.
+        pitch_predicted = torch.sum(pitch_predicted.data * self.idx_tensor, dim=1) * 4 - 180
+        yaw_predicted = torch.sum(yaw_predicted.data * self.idx_tensor, dim=1) * 4 - 180
+
+        pitch_predicted = pitch_predicted.cpu().detach().numpy() * np.pi / 180.0
+        yaw_predicted = yaw_predicted.cpu().detach().numpy() * np.pi / 180.0
+
+        return pitch_predicted, yaw_predicted
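The bin-to-angle decoding in `predict_gaze` is a softmax expectation over 90 bins of width 4°, offset by −180°. A minimal NumPy sketch of that decoding step (standalone, not using the repo's code):

```python
import numpy as np

def decode_bins(logits, num_bins=90, bin_width=4.0, offset=-180.0):
    """Softmax-expectation decoding: E[bin index] * width + offset, in degrees."""
    e = np.exp(logits - logits.max())  # numerically stable softmax
    probs = e / e.sum()
    idx = np.arange(num_bins)
    return float(np.sum(probs * idx) * bin_width + offset)

# A distribution peaked hard on bin 45 decodes to 45 * 4 - 180 = 0 degrees
logits = np.full(90, -1e9)
logits[45] = 0.0
print(decode_bins(logits))  # → 0.0
```

A uniform distribution decodes to the mean bin index (44.5), i.e. −2°, which is why the expectation is smoother than an argmax over bins.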
models/L2CS-Net/l2cs/results.py
ADDED
|
@@ -0,0 +1,11 @@
+from dataclasses import dataclass
+import numpy as np
+
+@dataclass
+class GazeResultContainer:
+
+    pitch: np.ndarray
+    yaw: np.ndarray
+    bboxes: np.ndarray
+    landmarks: np.ndarray
+    scores: np.ndarray
models/L2CS-Net/l2cs/utils.py
ADDED
|
@@ -0,0 +1,145 @@
+import sys
+import os
+import math
+from math import cos, sin
+from pathlib import Path
+import subprocess
+import re
+
+import numpy as np
+import torch
+import torch.nn as nn
+import scipy.io as sio
+import cv2
+import torchvision
+from torchvision import transforms
+
+from .model import L2CS
+
+transformations = transforms.Compose([
+    transforms.ToPILImage(),
+    transforms.Resize(448),
+    transforms.ToTensor(),
+    transforms.Normalize(
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225]
+    )
+])
+
+def atoi(text):
+    return int(text) if text.isdigit() else text
+
+def natural_keys(text):
+    '''
+    alist.sort(key=natural_keys) sorts in human order
+    http://nedbatchelder.com/blog/200712/human_sorting.html
+    (See Toothy's implementation in the comments)
+    '''
+    return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+def prep_input_numpy(img: np.ndarray, device: str):
+    """Preparing a Numpy Array as input to L2CS-Net."""
+
+    if len(img.shape) == 4:
+        imgs = []
+        for im in img:
+            imgs.append(transformations(im))
+        img = torch.stack(imgs)
+    else:
+        img = transformations(img)
+
+    img = img.to(device)
+
+    if len(img.shape) == 3:
+        img = img.unsqueeze(0)
+
+    return img
+
+def gazeto3d(gaze):
+    gaze_gt = np.zeros([3])
+    gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
+    gaze_gt[1] = -np.sin(gaze[1])
+    gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
+    return gaze_gt
+
+def angular(gaze, label):
+    total = np.sum(gaze * label)
+    return np.arccos(min(total / (np.linalg.norm(gaze) * np.linalg.norm(label)), 0.9999999)) * 180 / np.pi
+
+def select_device(device='', batch_size=None):
+    # device = 'cpu' or '0' or '0,1,2,3'
+    s = f'YOLOv3 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
+    cpu = device.lower() == 'cpu'
+    if cpu:
+        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
+    elif device:  # non-cpu device requested
+        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
+        # assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability
+
+    cuda = not cpu and torch.cuda.is_available()
+    if cuda:
+        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+        n = len(devices)  # device count
+        if n > 1 and batch_size:  # check batch_size is divisible by device_count
+            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+        space = ' ' * len(s)
+        for i, d in enumerate(devices):
+            p = torch.cuda.get_device_properties(i)
+            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
+    else:
+        s += 'CPU\n'
+
+    return torch.device('cuda:0' if cuda else 'cpu')
+
+def spherical2cartesial(x):
+
+    output = torch.zeros(x.size(0), 3)
+    output[:, 2] = -torch.cos(x[:, 1]) * torch.cos(x[:, 0])
+    output[:, 0] = torch.cos(x[:, 1]) * torch.sin(x[:, 0])
+    output[:, 1] = torch.sin(x[:, 1])
+
+    return output
+
+def compute_angular_error(input, target):
+
+    input = spherical2cartesial(input)
+    target = spherical2cartesial(target)
+
+    input = input.view(-1, 3, 1)
+    target = target.view(-1, 1, 3)
+    output_dot = torch.bmm(target, input)
+    output_dot = output_dot.view(-1)
+    output_dot = torch.acos(output_dot)
+    output_dot = output_dot.data
+    output_dot = 180 * torch.mean(output_dot) / math.pi
+    return output_dot
+
+def softmax_temperature(tensor, temperature):
+    result = torch.exp(tensor / temperature)
+    result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))
+    return result
+
+def git_describe(path=Path(__file__).parent):  # path must be a directory
+    # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+    s = f'git -C {path} describe --tags --long --always'
+    try:
+        return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
+    except subprocess.CalledProcessError as e:
+        return ''  # not a git repository
+
+def getArch(arch, bins):
+    # Base network structure
+    if arch == 'ResNet18':
+        model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+    elif arch == 'ResNet34':
+        model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+    elif arch == 'ResNet101':
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+    elif arch == 'ResNet152':
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+    else:
+        if arch != 'ResNet50':
+            print('Invalid value for architecture is passed! '
+                  'The default value of ResNet50 will be used instead!')
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+    return model
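The `gazeto3d` / `angular` pair above is how the evaluation scripts score predictions: map each (angle-pair) gaze to a unit vector, then take the arccos of the normalized dot product. A standalone NumPy sketch mirroring those two functions:

```python
import numpy as np

def gaze_to_3d(gaze):
    # Mirrors gazeto3d above: spherical gaze angles to a 3D direction vector
    return np.array([
        -np.cos(gaze[1]) * np.sin(gaze[0]),
        -np.sin(gaze[1]),
        -np.cos(gaze[1]) * np.cos(gaze[0]),
    ])

def angular_error_deg(g, l):
    # Mirrors angular above: arccos of the clipped, normalized dot product, in degrees
    cos_sim = np.sum(g * l) / (np.linalg.norm(g) * np.linalg.norm(l))
    return np.degrees(np.arccos(min(cos_sim, 0.9999999)))

a = gaze_to_3d([0.0, 0.0])               # looking straight ahead
b = gaze_to_3d([np.radians(10.0), 0.0])  # 10 degrees off in the first angle
print(round(angular_error_deg(a, b), 3))  # → 10.0
```

Note the `min(cos_sim, 0.9999999)` clip, which guards `arccos` against floating-point dot products slightly above 1; it also means two identical gazes report a tiny nonzero error rather than exactly 0.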
models/L2CS-Net/l2cs/vis.py
ADDED
|
@@ -0,0 +1,64 @@
+import cv2
+import numpy as np
+from .results import GazeResultContainer
+
+def draw_gaze(a, b, c, d, image_in, pitchyaw, thickness=2, color=(255, 255, 0), scale=2.0):
+    """Draw gaze angle on given image with a given eye positions."""
+    image_out = image_in
+    (h, w) = image_in.shape[:2]
+    length = c
+    pos = (int(a + c / 2.0), int(b + d / 2.0))
+    if len(image_out.shape) == 2 or image_out.shape[2] == 1:
+        image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR)
+    dx = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
+    dy = -length * np.sin(pitchyaw[1])
+    cv2.arrowedLine(image_out, tuple(np.round(pos).astype(np.int32)),
+                    tuple(np.round([pos[0] + dx, pos[1] + dy]).astype(int)), color,
+                    thickness, cv2.LINE_AA, tipLength=0.18)
+    return image_out
+
+def draw_bbox(frame: np.ndarray, bbox: np.ndarray):
+
+    x_min = int(bbox[0])
+    if x_min < 0:
+        x_min = 0
+    y_min = int(bbox[1])
+    if y_min < 0:
+        y_min = 0
+    x_max = int(bbox[2])
+    y_max = int(bbox[3])
+
+    cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 1)
+
+    return frame
+
+def render(frame: np.ndarray, results: GazeResultContainer):
+
+    # Draw bounding boxes
+    for bbox in results.bboxes:
+        frame = draw_bbox(frame, bbox)
+
+    # Draw Gaze
+    for i in range(results.pitch.shape[0]):
+
+        bbox = results.bboxes[i]
+        pitch = results.pitch[i]
+        yaw = results.yaw[i]
+
+        # Extract safe min and max of x,y
+        x_min = int(bbox[0])
+        if x_min < 0:
+            x_min = 0
+        y_min = int(bbox[1])
+        if y_min < 0:
+            y_min = 0
+        x_max = int(bbox[2])
+        y_max = int(bbox[3])
+
+        # Compute sizes
+        bbox_width = x_max - x_min
+        bbox_height = y_max - y_min
+
+        draw_gaze(x_min, y_min, bbox_width, bbox_height, frame, (pitch, yaw), color=(0, 0, 255))
+
+    return frame
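The arrow endpoint math in `draw_gaze` projects the two gaze angles onto the image plane, with the bounding-box width as the arrow length. A small standalone check of that projection:

```python
import numpy as np

def gaze_arrow(length, pitchyaw):
    # Mirrors draw_gaze above: image-plane displacement for the gaze arrow
    dx = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
    dy = -length * np.sin(pitchyaw[1])
    return dx, dy

# Looking straight at the camera: no displacement
dx0, dy0 = gaze_arrow(100.0, (0.0, 0.0))

# 30 degrees in the first angle, arrow length 100: dx = -100 * sin(30°) = -50
dx, dy = gaze_arrow(100.0, (np.radians(30.0), 0.0))
```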
models/L2CS-Net/leave_one_out_eval.py
ADDED
|
@@ -0,0 +1,54 @@
+import os
+import argparse
+
+
+def parse_args():
+    """Parse input arguments."""
+    parser = argparse.ArgumentParser(
+        description='gaze estimation using binned loss function.')
+    parser.add_argument(
+        '--evalpath', dest='evalpath', help='path for evaluating gaze test.',
+        default="evaluation\L2CS-gaze360-_standard-10", type=str)
+    parser.add_argument(
+        '--respath', dest='respath', help='path for saving result.',
+        default="evaluation\L2CS-gaze360-_standard-10", type=str)
+    return parser.parse_args()
+
+if __name__ == '__main__':
+
+    args = parse_args()
+    evalpath = args.evalpath
+    respath = args.respath
+    if not os.path.exists(respath):
+        os.makedirs(respath)
+    with open(os.path.join(respath, "avg.log"), 'w') as outfile:
+        outfile.write("Average equal\n")
+
+        min = 10.0
+        dirlist = os.listdir(evalpath)
+        dirlist.sort()
+        l = 0.0
+        for j in range(50):
+            j = 20
+            avg = 0.0
+            h = j + 3
+            for i in dirlist:
+                with open(evalpath + "/" + i + "/mpiigaze_binned.log") as myfile:
+                    x = list(myfile)[h]
+                    str1 = ""
+
+                    # traverse in the string
+                    for ele in x:
+                        str1 += ele
+                    split_string = str1.split("MAE:", 1)[1]
+                    avg += float(split_string)
+
+            avg = avg / 15.0
+            if avg < min:
+                min = avg
+                l = j + 1
+            outfile.write("epoch" + str(j + 1) + "= " + str(avg) + "\n")
+
+        outfile.write("min angular error equal= " + str(min) + "at epoch= " + str(l) + "\n")
+        print(min)
models/L2CS-Net/models/L2CSNet_gaze360.pkl
ADDED
|
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
+size 95849977
models/L2CS-Net/models/README.md
ADDED
|
@@ -0,0 +1 @@
+# Path to pre-trained models
models/L2CS-Net/pyproject.toml
ADDED
|
@@ -0,0 +1,44 @@
+[project]
+name = "l2cs"
+version = "0.0.1"
+description = "The official PyTorch implementation of L2CS-Net for gaze estimation and tracking"
+authors = [
+    {name = "Ahmed Abderlrahman"},
+    {name = "Thorsten Hempel"}
+]
+license = {file = "LICENSE.txt"}
+readme = "README.md"
+requires-python = ">3.6"
+
+keywords = ["gaze", "estimation", "eye-tracking", "deep-learning", "pytorch"]
+
+classifiers = [
+    "Programming Language :: Python :: 3"
+]
+
+dependencies = [
+    'matplotlib>=3.3.4',
+    'numpy>=1.19.5',
+    'opencv-python>=4.5.5',
+    'pandas>=1.1.5',
+    'Pillow>=8.4.0',
+    'scipy>=1.5.4',
+    'torch>=1.10.1',
+    'torchvision>=0.11.2',
+    'face_detection@git+https://github.com/elliottzheng/face-detection'
+]
+
+[project.urls]
+homepage = "https://github.com/Ahmednull/L2CS-Net"
+repository = "https://github.com/Ahmednull/L2CS-Net"
+
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+# https://setuptools.pypa.io/en/stable/userguide/datafiles.html
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+where = ["."]
models/L2CS-Net/test.py
ADDED
|
@@ -0,0 +1,284 @@
+import os, argparse
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torch.backends.cudnn as cudnn
+import torchvision
+
+from l2cs import select_device, natural_keys, gazeto3d, angular, getArch, L2CS, Gaze360, Mpiigaze
+
+
+def parse_args():
+    """Parse input arguments."""
+    parser = argparse.ArgumentParser(
+        description='Gaze estimation using L2CSNet .')
+    # Gaze360
+    parser.add_argument(
+        '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
+        default='datasets/Gaze360/Image', type=str)
+    parser.add_argument(
+        '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
+        default='datasets/Gaze360/Label/test.label', type=str)
+    # mpiigaze
+    parser.add_argument(
+        '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
+        default='datasets/MPIIFaceGaze/Image', type=str)
+    parser.add_argument(
+        '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
+        default='datasets/MPIIFaceGaze/Label', type=str)
+    # Important args -------------------------------------------------------------------------------------------------------
+    # ----------------------------------------------------------------------------------------------------------------------
+    parser.add_argument(
+        '--dataset', dest='dataset', help='gaze360, mpiigaze',
+        default="gaze360", type=str)
+    parser.add_argument(
+        '--snapshot', dest='snapshot', help='Path to the folder contains models.',
+        default='output/snapshots/L2CS-gaze360-_loader-180-4-lr', type=str)
+    parser.add_argument(
+        '--evalpath', dest='evalpath', help='path for the output evaluating gaze test.',
+        default="evaluation/L2CS-gaze360-_loader-180-4-lr", type=str)
+    parser.add_argument(
+        '--gpu', dest='gpu_id', help='GPU device id to use [0]',
+        default="0", type=str)
+    parser.add_argument(
+        '--batch_size', dest='batch_size', help='Batch size.',
+        default=100, type=int)
+    parser.add_argument(
+        '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], '
+        'ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
+        default='ResNet50', type=str)
+    # ---------------------------------------------------------------------------------------------------------------------
+    # Important args ------------------------------------------------------------------------------------------------------
+    args = parser.parse_args()
+    return args
+
+
+def getArch(arch, bins):
+    # Base network structure
+    if arch == 'ResNet18':
+        model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+    elif arch == 'ResNet34':
+        model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+    elif arch == 'ResNet101':
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+    elif arch == 'ResNet152':
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+    else:
+        if arch != 'ResNet50':
+            print('Invalid value for architecture is passed! '
+                  'The default value of ResNet50 will be used instead!')
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+    return model
+
+if __name__ == '__main__':
+    args = parse_args()
+    cudnn.enabled = True
+    gpu = select_device(args.gpu_id, batch_size=args.batch_size)
+    batch_size = args.batch_size
+    arch = args.arch
+    data_set = args.dataset
+    evalpath = args.evalpath
+    snapshot_path = args.snapshot
+    bins = args.bins
+    angle = args.angle
+    bin_width = args.bin_width
+
+    transformations = transforms.Compose([
+        transforms.Resize(448),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225]
+        )
+    ])
+
+    if data_set == "gaze360":
+
+        gaze_dataset = Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4, train=False)
+        test_loader = torch.utils.data.DataLoader(
+            dataset=gaze_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=4,
+            pin_memory=True)
+
+        if not os.path.exists(evalpath):
+            os.makedirs(evalpath)
+
+        # list all epochs for testing
+        folder = os.listdir(snapshot_path)
+        folder.sort(key=natural_keys)
+        softmax = nn.Softmax(dim=1)
+        with open(os.path.join(evalpath, data_set + ".log"), 'w') as outfile:
+            configuration = f"\ntest configuration = gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}----------------------------------------\n"
+            print(configuration)
+            outfile.write(configuration)
+            epoch_list = []
+            avg_yaw = []
+            avg_pitch = []
+            avg_MAE = []
+            for epochs in folder:
+                # Base network structure
+                model = getArch(arch, 90)
+                saved_state_dict = torch.load(os.path.join(snapshot_path, epochs))
+                model.load_state_dict(saved_state_dict)
+                model.cuda(gpu)
+                model.eval()
+                total = 0
+                idx_tensor = [idx for idx in range(90)]
+                idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+                avg_error = .0
+
+                with torch.no_grad():
+                    for j, (images, labels, cont_labels, name) in enumerate(test_loader):
+                        images = Variable(images).cuda(gpu)
+                        total += cont_labels.size(0)
+
+                        label_pitch = cont_labels[:, 0].float() * np.pi / 180
+                        label_yaw = cont_labels[:, 1].float() * np.pi / 180
+
+                        gaze_pitch, gaze_yaw = model(images)
+
+                        # Binned predictions
+                        _, pitch_bpred = torch.max(gaze_pitch.data, 1)
+                        _, yaw_bpred = torch.max(gaze_yaw.data, 1)
+
+                        # Continuous predictions
+                        pitch_predicted = softmax(gaze_pitch)
+                        yaw_predicted = softmax(gaze_yaw)
+
+                        # mapping from binned (0 to 90) to angles (-180 to 180)
+                        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 4 - 180
+                        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 4 - 180
+
+                        pitch_predicted = pitch_predicted * np.pi / 180
+                        yaw_predicted = yaw_predicted * np.pi / 180
+
+                        for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
+                            avg_error += angular(gazeto3d([p, y]), gazeto3d([pl, yl]))
+
+                x = ''.join(filter(lambda i: i.isdigit(), epochs))
+                epoch_list.append(x)
+                avg_MAE.append(avg_error / total)
+                loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total}\n"
+                outfile.write(loger)
+                print(loger)
+
+        fig = plt.figure(figsize=(14, 8))
+        plt.xlabel('epoch')
+        plt.ylabel('avg')
+        plt.title('Gaze angular error')
+        plt.legend()
+        plt.plot(epoch_list, avg_MAE, color='k', label='mae')
+        fig.savefig(os.path.join(evalpath, data_set + ".png"), format='png')
+        plt.show()
+
+    elif data_set == "mpiigaze":
+        model_used = getArch(arch, bins)
+
+        for fold in range(15):
+            folder = os.listdir(args.gazeMpiilabel_dir)
+            folder.sort()
+            testlabelpathombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
+            gaze_dataset = Mpiigaze(testlabelpathombined, args.gazeMpiimage_dir, transformations, False, angle, fold)
+
+            test_loader = torch.utils.data.DataLoader(
+                dataset=gaze_dataset,
+                batch_size=batch_size,
+                shuffle=True,
+                num_workers=4,
+                pin_memory=True)
+
+            if not os.path.exists(os.path.join(evalpath, f"fold" + str(fold))):
+                os.makedirs(os.path.join(evalpath, f"fold" + str(fold)))
+
+            # list all epochs for testing
+            folder = os.listdir(os.path.join(snapshot_path, "fold" + str(fold)))
+            folder.sort(key=natural_keys)
+
+            softmax = nn.Softmax(dim=1)
+            with open(os.path.join(evalpath, os.path.join("fold" + str(fold), data_set + ".log")), 'w') as outfile:
+                configuration = f"\ntest configuration equal gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}, fold={fold}---------------------------------------\n"
+                print(configuration)
+                outfile.write(configuration)
+                epoch_list = []
+                avg_MAE = []
+                for epochs in folder:
+                    model = model_used
+                    saved_state_dict = torch.load(os.path.join(snapshot_path + "/fold" + str(fold), epochs))
+                    model = nn.DataParallel(model, device_ids=[0])
+                    model.load_state_dict(saved_state_dict)
+                    model.cuda(gpu)
+                    model.eval()
+                    total = 0
+                    idx_tensor = [idx for idx in range(28)]
+                    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
+                    avg_error = .0
+                    with torch.no_grad():
+                        for j, (images, labels, cont_labels, name) in enumerate(test_loader):
+                            images = Variable(images).cuda(gpu)
+                            total += cont_labels.size(0)
+
+                            label_pitch = cont_labels[:, 0].float() * np.pi / 180
+                            label_yaw = cont_labels[:, 1].float() * np.pi / 180
+
+                            gaze_pitch, gaze_yaw = model(images)
+
+                            # Binned predictions
+                            _, pitch_bpred = torch.max(gaze_pitch.data, 1)
+                            _, yaw_bpred = torch.max(gaze_yaw.data, 1)
+
+                            # Continuous predictions
+                            pitch_predicted = softmax(gaze_pitch)
+                            yaw_predicted = softmax(gaze_yaw)
+
+                            # mapping from binned (0 to 28) to angles (-42 to 42)
+                            pitch_predicted = \
+                                torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 42
+                            yaw_predicted = \
+                                torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 42
+
|
| 260 |
+
pitch_predicted = pitch_predicted*np.pi/180
|
| 261 |
+
yaw_predicted = yaw_predicted*np.pi/180
|
| 262 |
+
|
| 263 |
+
for p,y,pl,yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
|
| 264 |
+
avg_error += angular(gazeto3d([p,y]), gazeto3d([pl,yl]))
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
x = ''.join(filter(lambda i: i.isdigit(), epochs))
|
| 268 |
+
epoch_list.append(x)
|
| 269 |
+
avg_MAE.append(avg_error/ total)
|
| 270 |
+
loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total} \n"
|
| 271 |
+
outfile.write(loger)
|
| 272 |
+
print(loger)
|
| 273 |
+
|
| 274 |
+
fig = plt.figure(figsize=(14, 8))
|
| 275 |
+
plt.xlabel('epoch')
|
| 276 |
+
plt.ylabel('avg')
|
| 277 |
+
plt.title('Gaze angular error')
|
| 278 |
+
plt.legend()
|
| 279 |
+
plt.plot(epoch_list, avg_MAE, color='k', label='mae')
|
| 280 |
+
fig.savefig(os.path.join(evalpath, os.path.join("fold"+str(fold), data_set+".png")), format='png')
|
| 281 |
+
# plt.show()
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
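Both evaluation branches above decode the network's bin logits the same way: softmax over the bins, expectation against the index tensor, then an affine map to degrees (×3 − 42 for MPIIGaze's 28 bins, ×4 − 180 for Gaze360's 90). A minimal NumPy sketch of that soft-argmax decoding (the function name here is illustrative, not from the repo):

```python
import numpy as np

def decode_bins(logits, bin_width=3.0, offset=-42.0):
    """Soft-argmax: softmax over bins, expectation over indices, map to degrees."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    idx = np.arange(len(logits))
    return float(np.sum(probs * idx) * bin_width + offset)

# A sharp peak at bin 14 decodes to 14 * 3 - 42 = 0 degrees (straight ahead).
logits = np.full(28, -10.0)
logits[14] = 10.0
print(round(decode_bins(logits), 3))
```

Unlike a hard argmax, the expectation gives sub-bin resolution, which is why the continuous MSE term in training can supervise it directly.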
models/L2CS-Net/train.py ADDED
@@ -0,0 +1,384 @@
+import os
+import argparse
+import time
+
+import torch.utils.model_zoo as model_zoo
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torch.backends.cudnn as cudnn
+import torchvision
+
+from l2cs import L2CS, select_device, Gaze360, Mpiigaze
+
+
+def parse_args():
+    """Parse input arguments."""
+    parser = argparse.ArgumentParser(description='Gaze estimation using L2CSNet.')
+    # Gaze360
+    parser.add_argument(
+        '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
+        default='datasets/Gaze360/Image', type=str)
+    parser.add_argument(
+        '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
+        default='datasets/Gaze360/Label/train.label', type=str)
+    # mpiigaze
+    parser.add_argument(
+        '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
+        default='datasets/MPIIFaceGaze/Image', type=str)
+    parser.add_argument(
+        '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
+        default='datasets/MPIIFaceGaze/Label', type=str)
+
+    # Important args -------------------------------------------------------------------------------------------------------
+    parser.add_argument(
+        '--dataset', dest='dataset', help='mpiigaze, rtgene, gaze360, ethgaze',
+        default="gaze360", type=str)
+    parser.add_argument(
+        '--output', dest='output', help='Path of output models.',
+        default='output/snapshots/', type=str)
+    parser.add_argument(
+        '--snapshot', dest='snapshot', help='Path of model snapshot.',
+        default='', type=str)
+    parser.add_argument(
+        '--gpu', dest='gpu_id', help='GPU device id to use [0] or multiple 0,1,2,3',
+        default='0', type=str)
+    parser.add_argument(
+        '--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
+        default=60, type=int)
+    parser.add_argument(
+        '--batch_size', dest='batch_size', help='Batch size.',
+        default=1, type=int)
+    parser.add_argument(
+        '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, '
+        '[ResNet50], ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
+        default='ResNet50', type=str)
+    parser.add_argument(
+        '--alpha', dest='alpha', help='Regression loss coefficient.',
+        default=1, type=float)
+    parser.add_argument(
+        '--lr', dest='lr', help='Base learning rate.',
+        default=0.00001, type=float)
+    # ----------------------------------------------------------------------------------------------------------------------
+    args = parser.parse_args()
+    return args
+
+
+def get_ignored_params(model):
+    # Generator function that yields ignored params.
+    b = [model.conv1, model.bn1, model.fc_finetune]
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+
+def get_non_ignored_params(model):
+    # Generator function that yields params that will be optimized.
+    b = [model.layer1, model.layer2, model.layer3, model.layer4]
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+
+def get_fc_params(model):
+    # Generator function that yields fc layer params.
+    b = [model.fc_yaw_gaze, model.fc_pitch_gaze]
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            for name, param in module.named_parameters():
+                yield param
+
+
+def load_filtered_state_dict(model, snapshot):
+    # By user apaszke from discuss.pytorch.org
+    model_dict = model.state_dict()
+    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
+    model_dict.update(snapshot)
+    model.load_state_dict(model_dict)
+
+
+def getArch_weights(arch, bins):
+    if arch == 'ResNet18':
+        model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+        pre_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+    elif arch == 'ResNet34':
+        model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+        pre_url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
+    elif arch == 'ResNet101':
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+        pre_url = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
+    elif arch == 'ResNet152':
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+        pre_url = 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
+    else:
+        if arch != 'ResNet50':
+            print('Invalid value for architecture is passed! '
+                  'The default value of ResNet50 will be used instead!')
+        model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+        pre_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
+
+    return model, pre_url
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    cudnn.enabled = True
+    num_epochs = args.num_epochs
+    batch_size = args.batch_size
+    gpu = select_device(args.gpu_id, batch_size=args.batch_size)
+    data_set = args.dataset
+    alpha = args.alpha
+    output = args.output
+
+    transformations = transforms.Compose([
+        transforms.Resize(448),
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225]
+        )
+    ])
+
+    if data_set == "gaze360":
+        model, pre_url = getArch_weights(args.arch, 90)
+        if args.snapshot == '':
+            load_filtered_state_dict(model, model_zoo.load_url(pre_url))
+        else:
+            saved_state_dict = torch.load(args.snapshot)
+            model.load_state_dict(saved_state_dict)
+
+        model.cuda(gpu)
+        dataset = Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4)
+        print('Loading data.')
+        train_loader_gaze = DataLoader(
+            dataset=dataset,
+            batch_size=int(batch_size),
+            shuffle=True,
+            num_workers=0,
+            pin_memory=True)
+        torch.backends.cudnn.benchmark = True
+
+        summary_name = '{}_{}'.format('L2CS-gaze360-', int(time.time()))
+        output = os.path.join(output, summary_name)
+        if not os.path.exists(output):
+            os.makedirs(output)
+
+        criterion = nn.CrossEntropyLoss().cuda(gpu)
+        reg_criterion = nn.MSELoss().cuda(gpu)
+        softmax = nn.Softmax(dim=1).cuda(gpu)
+        idx_tensor = [idx for idx in range(90)]
+        idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
+
+        # Optimizer gaze
+        optimizer_gaze = torch.optim.Adam([
+            {'params': get_ignored_params(model), 'lr': 0},
+            {'params': get_non_ignored_params(model), 'lr': args.lr},
+            {'params': get_fc_params(model), 'lr': args.lr}
+        ], args.lr)
+
+        configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\nStart training dataset={data_set}, loader={len(train_loader_gaze)}------------------------- \n"
+        print(configuration)
+        for epoch in range(num_epochs):
+            sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0
+
+            for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
+                images_gaze = Variable(images_gaze).cuda(gpu)
+
+                # Binned labels
+                label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
+                label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)
+
+                # Continuous labels
+                label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
+                label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)
+
+                pitch, yaw = model(images_gaze)
+
+                # Cross entropy loss
+                loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
+                loss_yaw_gaze = criterion(yaw, label_yaw_gaze)
+
+                # MSE loss
+                pitch_predicted = softmax(pitch)
+                yaw_predicted = softmax(yaw)
+
+                pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 4 - 180
+                yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 4 - 180
+
+                loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont_gaze)
+                loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont_gaze)
+
+                # Total loss
+                loss_pitch_gaze += alpha * loss_reg_pitch
+                loss_yaw_gaze += alpha * loss_reg_yaw
+
+                sum_loss_pitch_gaze += loss_pitch_gaze
+                sum_loss_yaw_gaze += loss_yaw_gaze
+
+                loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
+                grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
+                optimizer_gaze.zero_grad(set_to_none=True)
+                torch.autograd.backward(loss_seq, grad_seq)
+                optimizer_gaze.step()
+                # scheduler.step()
+
+                iter_gaze += 1
+
+                if (i + 1) % 100 == 0:
+                    print('Epoch [%d/%d], Iter [%d/%d] Losses: '
+                          'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
+                              epoch + 1,
+                              num_epochs,
+                              i + 1,
+                              len(dataset) // batch_size,
+                              sum_loss_pitch_gaze / iter_gaze,
+                              sum_loss_yaw_gaze / iter_gaze
+                          )
+                          )
+
+            if epoch % 1 == 0 and epoch < num_epochs:
+                print('Taking snapshot...',
+                      torch.save(model.state_dict(),
+                                 output + '/' +
+                                 '_epoch_' + str(epoch + 1) + '.pkl')
+                      )
+
+    elif data_set == "mpiigaze":
+        folder = os.listdir(args.gazeMpiilabel_dir)
+        folder.sort()
+        testlabelpathombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
+        for fold in range(15):
+            model, pre_url = getArch_weights(args.arch, 28)
+            load_filtered_state_dict(model, model_zoo.load_url(pre_url))
+            model = nn.DataParallel(model)
+            model.to(gpu)
+            print('Loading data.')
+            dataset = Mpiigaze(testlabelpathombined, args.gazeMpiimage_dir, transformations, True, fold)
+            train_loader_gaze = DataLoader(
+                dataset=dataset,
+                batch_size=int(batch_size),
+                shuffle=True,
+                num_workers=4,
+                pin_memory=True)
+            torch.backends.cudnn.benchmark = True
+
+            summary_name = '{}_{}'.format('L2CS-mpiigaze', int(time.time()))
+
+            if not os.path.exists(os.path.join(output + '/{}'.format(summary_name), 'fold' + str(fold))):
+                os.makedirs(os.path.join(output + '/{}'.format(summary_name), 'fold' + str(fold)))
+
+            criterion = nn.CrossEntropyLoss().cuda(gpu)
+            reg_criterion = nn.MSELoss().cuda(gpu)
+            softmax = nn.Softmax(dim=1).cuda(gpu)
+            idx_tensor = [idx for idx in range(28)]
+            idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
+
+            # Optimizer gaze (model is wrapped in DataParallel, so reach layers via .module)
+            optimizer_gaze = torch.optim.Adam([
+                {'params': get_ignored_params(model.module), 'lr': 0},
+                {'params': get_non_ignored_params(model.module), 'lr': args.lr},
+                {'params': get_fc_params(model.module), 'lr': args.lr}
+            ], args.lr)
+
+            configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\n Start training dataset={data_set}, loader={len(train_loader_gaze)}, fold={fold}--------------\n"
+            print(configuration)
+            for epoch in range(num_epochs):
+                sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0
+
+                for i, (images_gaze, labels_gaze, cont_labels_gaze, name) in enumerate(train_loader_gaze):
+                    images_gaze = Variable(images_gaze).cuda(gpu)
+
+                    # Binned labels
+                    label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
+                    label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)
+
+                    # Continuous labels
+                    label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
+                    label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)
+
+                    pitch, yaw = model(images_gaze)
+
+                    # Cross entropy loss
+                    loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
+                    loss_yaw_gaze = criterion(yaw, label_yaw_gaze)
+
+                    # MSE loss
+                    pitch_predicted = softmax(pitch)
+                    yaw_predicted = softmax(yaw)
+
+                    pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 42
+                    yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 42
+
+                    loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont_gaze)
+                    loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont_gaze)
+
+                    # Total loss
+                    loss_pitch_gaze += alpha * loss_reg_pitch
+                    loss_yaw_gaze += alpha * loss_reg_yaw
+
+                    sum_loss_pitch_gaze += loss_pitch_gaze
+                    sum_loss_yaw_gaze += loss_yaw_gaze
+
+                    loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
+                    grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
+
+                    optimizer_gaze.zero_grad(set_to_none=True)
+                    torch.autograd.backward(loss_seq, grad_seq)
+                    optimizer_gaze.step()
+
+                    iter_gaze += 1
+
+                    if (i + 1) % 100 == 0:
+                        print('Epoch [%d/%d], Iter [%d/%d] Losses: '
+                              'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
+                                  epoch + 1,
+                                  num_epochs,
+                                  i + 1,
+                                  len(dataset) // batch_size,
+                                  sum_loss_pitch_gaze / iter_gaze,
+                                  sum_loss_yaw_gaze / iter_gaze
+                              )
+                              )
+
+                # Save models at numbered epochs (into the summary dir created above).
+                if epoch % 1 == 0 and epoch < num_epochs:
+                    print('Taking snapshot...',
+                          torch.save(model.state_dict(),
+                                     os.path.join(output, summary_name, 'fold' + str(fold),
+                                                  '_epoch_' + str(epoch + 1) + '.pkl'))
+                          )
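The training loss above combines cross-entropy on the binned labels with an alpha-weighted MSE on the soft-argmax angle. A NumPy sketch of that combined objective for the Gaze360 configuration (90 bins, ×4 − 180); the helper name is illustrative:

```python
import numpy as np

def l2cs_loss(logits, bin_label, cont_label, alpha=1.0, bin_width=4.0, offset=-180.0):
    # Softmax over the bin dimension.
    p = np.exp(logits - logits.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)
    # Cross-entropy on the binned labels.
    ce = -np.mean(np.log(p[np.arange(len(bin_label)), bin_label]))
    # Soft-argmax: expectation over bin indices, mapped to degrees.
    angle = (p @ np.arange(logits.shape[1])) * bin_width + offset
    mse = np.mean((angle - cont_label) ** 2)
    return ce + alpha * mse

# Uniform logits decode to the mid-angle (44.5 * 4 - 180 = -2 degrees),
# so a -2 degree continuous label contributes zero regression loss here.
loss = l2cs_loss(np.zeros((2, 90)), np.array([45, 45]), np.array([-2.0, -2.0]))
print(round(loss, 4))  # ln(90) ≈ 4.4998
```

The classification term keeps the bin distribution peaked while the regression term penalises the decoded angle directly, which is the core L2CS idea.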
models/gaze_calibration.py ADDED
@@ -0,0 +1,146 @@
+# 9-point gaze calibration for L2CS-Net
+# Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
+# Centre point is the bias reference (subtracted from all readings).
+
+import numpy as np
+from dataclasses import dataclass, field
+
+# 3x3 grid, centre first (bias ref), then row by row
+DEFAULT_TARGETS = [
+    (0.5, 0.5),
+    (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
+    (0.15, 0.50), (0.85, 0.50),
+    (0.15, 0.85), (0.50, 0.85), (0.85, 0.85),
+]
+
+
+@dataclass
+class _PointSamples:
+    target_x: float
+    target_y: float
+    yaws: list = field(default_factory=list)
+    pitches: list = field(default_factory=list)
+
+
+def _iqr_filter(values):
+    if len(values) < 4:
+        return values
+    arr = np.array(values)
+    q1, q3 = np.percentile(arr, [25, 75])
+    iqr = q3 - q1
+    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
+    return arr[(arr >= lo) & (arr <= hi)].tolist()
+
+
+class GazeCalibration:
+
+    def __init__(self, targets=None):
+        self._targets = targets or list(DEFAULT_TARGETS)
+        self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
+        self._current_idx = 0
+        self._fitted = False
+        self._W = None  # (6, 2) polynomial weights
+        self._yaw_bias = 0.0
+        self._pitch_bias = 0.0
+
+    @property
+    def num_points(self):
+        return len(self._targets)
+
+    @property
+    def current_index(self):
+        return self._current_idx
+
+    @property
+    def current_target(self):
+        if self._current_idx < len(self._targets):
+            return self._targets[self._current_idx]
+        return self._targets[-1]
+
+    @property
+    def is_complete(self):
+        return self._current_idx >= len(self._targets)
+
+    @property
+    def is_fitted(self):
+        return self._fitted
+
+    def collect_sample(self, yaw_rad, pitch_rad):
+        if self._current_idx >= len(self._points):
+            return
+        pt = self._points[self._current_idx]
+        pt.yaws.append(float(yaw_rad))
+        pt.pitches.append(float(pitch_rad))
+
+    def advance(self):
+        self._current_idx += 1
+        return self._current_idx < len(self._targets)
+
+    @staticmethod
+    def _poly_features(yaw, pitch):
+        # [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]
+        return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
+                        dtype=np.float64)
+
+    def fit(self):
+        # bias from centre point (index 0)
+        center = self._points[0]
+        center_yaws = _iqr_filter(center.yaws)
+        center_pitches = _iqr_filter(center.pitches)
+        if len(center_yaws) < 2 or len(center_pitches) < 2:
+            return False
+        self._yaw_bias = float(np.median(center_yaws))
+        self._pitch_bias = float(np.median(center_pitches))
+
+        rows_A, rows_B = [], []
+        for pt in self._points:
+            clean_yaws = _iqr_filter(pt.yaws)
+            clean_pitches = _iqr_filter(pt.pitches)
+            if len(clean_yaws) < 2 or len(clean_pitches) < 2:
+                continue
+            med_yaw = float(np.median(clean_yaws)) - self._yaw_bias
+            med_pitch = float(np.median(clean_pitches)) - self._pitch_bias
+            rows_A.append(self._poly_features(med_yaw, med_pitch))
+            rows_B.append([pt.target_x, pt.target_y])
+
+        if len(rows_A) < 5:
+            return False
+
+        A = np.array(rows_A, dtype=np.float64)
+        B = np.array(rows_B, dtype=np.float64)
+        try:
+            W, _, _, _ = np.linalg.lstsq(A, B, rcond=None)
+            self._W = W
+            self._fitted = True
+            return True
+        except np.linalg.LinAlgError:
+            return False
+
+    def predict(self, yaw_rad, pitch_rad):
+        if not self._fitted or self._W is None:
+            return 0.5, 0.5
+        feat = self._poly_features(yaw_rad - self._yaw_bias, pitch_rad - self._pitch_bias)
+        xy = feat @ self._W
+        return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))
+
+    def to_dict(self):
+        return {
+            "targets": self._targets,
+            "fitted": self._fitted,
+            "current_index": self._current_idx,
+            "W": self._W.tolist() if self._W is not None else None,
+            "yaw_bias": self._yaw_bias,
+            "pitch_bias": self._pitch_bias,
+        }
+
+    @classmethod
+    def from_dict(cls, d):
+        cal = cls(targets=d.get("targets", DEFAULT_TARGETS))
+        cal._fitted = d.get("fitted", False)
+        cal._current_idx = d.get("current_index", 0)
+        cal._yaw_bias = d.get("yaw_bias", 0.0)
+        cal._pitch_bias = d.get("pitch_bias", 0.0)
+        w = d.get("W")
+        if w is not None:
+            cal._W = np.array(w, dtype=np.float64)
+        return cal
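The fit/predict path above reduces to one small least-squares solve. A standalone sketch using a synthetic subject whose gaze is a linear function of the screen target plus a fixed bias (all constants here are illustrative): with a full 3x3 grid the 6-term basis is full rank, so the linear map is recovered exactly.

```python
import numpy as np

TARGETS = [(0.5, 0.5), (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
           (0.15, 0.50), (0.85, 0.50), (0.15, 0.85), (0.50, 0.85), (0.85, 0.85)]

def feats(yaw, pitch):
    # Same basis as GazeCalibration._poly_features above.
    return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0])

# Synthetic subject: yaw/pitch depend linearly on the target, plus a fixed bias.
samples = [(0.1 + (tx - 0.5) * 0.6, -0.05 + (ty - 0.5) * 0.4) for tx, ty in TARGETS]

yaw_bias, pitch_bias = samples[0]          # centre point is the bias reference
A = np.array([feats(y - yaw_bias, p - pitch_bias) for y, p in samples])
B = np.array(TARGETS, dtype=float)
W, *_ = np.linalg.lstsq(A, B, rcond=None)  # (6, 2) weights, as in fit()

# Predict for the bias-subtracted gaze that corresponds to screen point (0.75, 0.75).
xy = feats(0.25 * 0.6, 0.25 * 0.4) @ W
print(np.round(xy, 3))  # ≈ [0.75 0.75]
```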
models/gaze_eye_fusion.py ADDED
@@ -0,0 +1,66 @@
+# Fuses calibrated gaze position with eye openness (EAR) for focus detection.
+# Takes L2CS gaze angles + MediaPipe landmarks, outputs screen coords + focus decision.
+
+import math
+import numpy as np
+
+from .gaze_calibration import GazeCalibration
+from .eye_scorer import compute_avg_ear
+
+_EAR_BLINK = 0.18
+_ON_SCREEN_MARGIN = 0.08
+
+
+class GazeEyeFusion:
+
+    def __init__(self, calibration, ear_weight=0.3, gaze_weight=0.7, focus_threshold=0.52):
+        if not calibration.is_fitted:
+            raise ValueError("Calibration must be fitted first")
+        self._cal = calibration
+        self._ear_w = ear_weight
+        self._gaze_w = gaze_weight
+        self._threshold = focus_threshold
+        self._smooth_x = 0.5
+        self._smooth_y = 0.5
+        self._alpha = 0.5
+
+    def update(self, yaw_rad, pitch_rad, landmarks):
+        gx, gy = self._cal.predict(yaw_rad, pitch_rad)
+
+        # EMA smooth the gaze position
+        self._smooth_x += self._alpha * (gx - self._smooth_x)
+        self._smooth_y += self._alpha * (gy - self._smooth_y)
+        gx, gy = self._smooth_x, self._smooth_y
+
+        on_screen = (
+            -_ON_SCREEN_MARGIN <= gx <= 1.0 + _ON_SCREEN_MARGIN and
+            -_ON_SCREEN_MARGIN <= gy <= 1.0 + _ON_SCREEN_MARGIN
+        )
+
+        ear = None
+        ear_score = 1.0
+        if landmarks is not None:
+            ear = compute_avg_ear(landmarks)
+            ear_score = 0.0 if ear < _EAR_BLINK else min(ear / 0.30, 1.0)
+
+        # penalise gaze near screen edges
+        gaze_score = 1.0 if on_screen else 0.0
+        if on_screen:
+            dx = max(0.0, abs(gx - 0.5) - 0.3)
+            dy = max(0.0, abs(gy - 0.5) - 0.3)
+            gaze_score = max(0.0, 1.0 - math.sqrt(dx**2 + dy**2) * 5.0)
+
+        score = float(np.clip(self._gaze_w * gaze_score + self._ear_w * ear_score, 0, 1))
+
+        return {
+            "gaze_x": round(float(gx), 4),
+            "gaze_y": round(float(gy), 4),
+            "on_screen": on_screen,
+            "ear": round(ear, 4) if ear is not None else None,
+            "focus_score": round(score, 4),
+            "focused": score >= self._threshold,
+        }
+
+    def reset(self):
+        self._smooth_x = 0.5
+        self._smooth_y = 0.5
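The focus decision above is a weighted blend of an edge-penalised gaze term and the EAR term. A standalone sketch of just that arithmetic, with the same constants (0.7/0.3 weights, 0.08 margin, 0.3 central band, 0.52 threshold); the function names are illustrative:

```python
import math

def gaze_term(gx, gy, margin=0.08):
    # 1.0 inside the central band, linear falloff past |coord - 0.5| > 0.3,
    # 0.0 when the point lies outside the screen plus margin.
    if not (-margin <= gx <= 1 + margin and -margin <= gy <= 1 + margin):
        return 0.0
    dx = max(0.0, abs(gx - 0.5) - 0.3)
    dy = max(0.0, abs(gy - 0.5) - 0.3)
    return max(0.0, 1.0 - math.sqrt(dx * dx + dy * dy) * 5.0)

def focus_score(gx, gy, ear_score, gaze_w=0.7, ear_w=0.3):
    return min(max(gaze_w * gaze_term(gx, gy) + ear_w * ear_score, 0.0), 1.0)

print(focus_score(0.5, 0.5, 1.0))  # centre, eyes open -> 1.0
print(focus_score(1.2, 0.5, 1.0))  # gaze off-screen  -> 0.3, below the 0.52 threshold
```

Note the asymmetry this creates: fully open eyes alone (0.3) can never reach the 0.52 threshold, so on-screen gaze is effectively required for a "focused" verdict.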
requirements.txt CHANGED
@@ -20,3 +20,5 @@ xgboost>=2.0.0
 clearml>=2.0.2
 pytest>=9.0.0
 pytest-cov>=5.0.0
+face_detection @ git+https://github.com/elliottzheng/face-detection
+gdown>=5.0.0
src/components/CalibrationOverlay.jsx ADDED
@@ -0,0 +1,146 @@
| 1 |
+
import React, { useState, useEffect, useRef, useCallback } from 'react';
|
| 2 |
+
|
| 3 |
+
const COLLECT_MS = 2000;
|
| 4 |
+
const CENTER_MS = 3000; // centre point gets extra time (bias reference)
|
| 5 |
+
|
| 6 |
+
function CalibrationOverlay({ calibration, videoManager }) {
|
| 7 |
+
const [progress, setProgress] = useState(0);
|
| 8 |
+
const timerRef = useRef(null);
|
| 9 |
+
const startRef = useRef(null);
|
| 10 |
+
const overlayRef = useRef(null);
|
| 11 |
+
|
| 12 |
+
const enterFullscreen = useCallback(() => {
|
| 13 |
+
const el = overlayRef.current;
|
| 14 |
+
if (!el) return;
|
| 15 |
+
const req = el.requestFullscreen || el.webkitRequestFullscreen || el.msRequestFullscreen;
|
| 16 |
+
if (req) req.call(el).catch(() => {});
|
| 17 |
+
}, []);
|
| 18 |
+
|
| 19 |
+
const exitFullscreen = useCallback(() => {
|
| 20 |
+
if (document.fullscreenElement || document.webkitFullscreenElement) {
|
| 21 |
+
const exit = document.exitFullscreen || document.webkitExitFullscreen || document.msExitFullscreen;
|
| 22 |
+
if (exit) exit.call(document).catch(() => {});
|
| 23 |
+
}
|
| 24 |
+
}, []);
|
| 25 |
+
|
| 26 |
+
useEffect(() => {
|
| 27 |
+
if (calibration && calibration.active && !calibration.done) {
|
| 28 |
+
const t = setTimeout(enterFullscreen, 100);
|
| 29 |
+
return () => clearTimeout(t);
|
| 30 |
+
}
|
| 31 |
+
}, [calibration?.active]);
|
| 32 |
+
|
| 33 |
+
useEffect(() => {
|
| 34 |
+
if (!calibration || !calibration.active) exitFullscreen();
|
| 35 |
+
}, [calibration?.active]);
|
| 36 |
+
|
| 37 |
+
useEffect(() => {
|
| 38 |
+
if (!calibration || !calibration.collecting || calibration.done) {
|
| 39 |
+
setProgress(0);
|
| 40 |
+
if (timerRef.current) cancelAnimationFrame(timerRef.current);
|
| 41 |
+
return;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
startRef.current = performance.now();
|
| 45 |
+
const duration = calibration.index === 0 ? CENTER_MS : COLLECT_MS;
|
| 46 |
+
|
| 47 |
+
const tick = () => {
|
| 48 |
+
const pct = Math.min((performance.now() - startRef.current) / duration, 1);
|
| 49 |
+
setProgress(pct);
|
| 50 |
+
if (pct >= 1) {
|
| 51 |
+
if (videoManager) videoManager.nextCalibrationPoint();
|
| 52 |
+
startRef.current = performance.now();
|
| 53 |
+
setProgress(0);
|
| 54 |
+
}
|
| 55 |
+
timerRef.current = requestAnimationFrame(tick);
|
| 56 |
+
};
|
| 57 |
+
timerRef.current = requestAnimationFrame(tick);
|
| 58 |
+
|
| 59 |
+
return () => { if (timerRef.current) cancelAnimationFrame(timerRef.current); };
|
| 60 |
+
}, [calibration?.index, calibration?.collecting, calibration?.done]);
|
| 61 |
+
|
| 62 |
+
const handleCancel = () => {
|
| 63 |
+
if (videoManager) videoManager.cancelCalibration();
|
| 64 |
+
exitFullscreen();
|
| 65 |
+
};
|
| 66 |
+
|
| 67 |
+
if (!calibration || !calibration.active) return null;
|
| 68 |
+
|
| 69 |
+
if (calibration.done) {
|
| 70 |
+
return (
|
| 71 |
+
<div ref={overlayRef} style={overlayStyle}>
|
| 72 |
+
<div style={messageBoxStyle}>
|
| 73 |
+
<h2 style={{ margin: '0 0 10px', color: calibration.success ? '#4ade80' : '#f87171' }}>
|
| 74 |
+
{calibration.success ? 'Calibration Complete' : 'Calibration Failed'}
|
| 75 |
+
</h2>
|
| 76 |
+
<p style={{ color: '#ccc', margin: 0 }}>
|
| 77 |
+
{calibration.success
|
| 78 |
+
? 'Gaze tracking is now active.'
|
| 79 |
+
: 'Not enough samples collected. Try again.'}
|
| 80 |
+
</p>
|
| 81 |
+
</div>
|
| 82 |
+
</div>
|
| 83 |
+
);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
const [tx, ty] = calibration.target || [0.5, 0.5];
|
| 87 |
+
|
| 88 |
+
return (
|
| 89 |
+
<div ref={overlayRef} style={overlayStyle}>
|
| 90 |
+
<div style={{
|
| 91 |
+
position: 'absolute', top: '30px', left: '50%', transform: 'translateX(-50%)',
|
| 92 |
+
color: '#fff', fontSize: '16px', textAlign: 'center',
|
| 93 |
+
textShadow: '0 0 8px rgba(0,0,0,0.8)', pointerEvents: 'none',
|
| 94 |
+
}}>
|
| 95 |
+
<div style={{ fontWeight: 'bold', fontSize: '20px' }}>
|
| 96 |
+
Look at the dot ({calibration.index + 1}/{calibration.numPoints})
|
| 97 |
+
</div>
|
| 98 |
+
<div style={{ fontSize: '14px', color: '#aaa', marginTop: '6px' }}>
|
| 99 |
+
{calibration.index === 0
|
| 100 |
+
? 'Look at the center dot - this sets your baseline'
|
| 101 |
+
: 'Hold your gaze steady on the target'}
|
| 102 |
+
</div>
|
| 103 |
+
</div>
|
| 104 |
+
|
| 105 |
+
<div style={{
|
| 106 |
+
position: 'absolute', left: `${tx * 100}%`, top: `${ty * 100}%`,
|
| 107 |
+
transform: 'translate(-50%, -50%)',
|
| 108 |
+
}}>
|
| 109 |
+
<svg width="60" height="60" style={{ position: 'absolute', left: '-30px', top: '-30px' }}>
|
| 110 |
+
<circle cx="30" cy="30" r="24" fill="none" stroke="rgba(255,255,255,0.15)" strokeWidth="3" />
|
| 111 |
+
<circle cx="30" cy="30" r="24" fill="none" stroke="#4ade80" strokeWidth="3"
|
| 112 |
+
strokeDasharray={`${progress * 150.8} 150.8`} strokeLinecap="round"
|
| 113 |
+
transform="rotate(-90, 30, 30)" />
|
| 114 |
+
</svg>
|
| 115 |
+
<div style={{
|
| 116 |
+
width: '20px', height: '20px', borderRadius: '50%',
|
| 117 |
+
background: 'radial-gradient(circle, #fff 30%, #4ade80 100%)',
|
| 118 |
+
boxShadow: '0 0 20px rgba(74, 222, 128, 0.8)',
|
| 119 |
+
}} />
|
| 120 |
+
</div>
|
| 121 |
+
|
| 122 |
+
<button onClick={handleCancel} style={{
|
| 123 |
+
position: 'absolute', bottom: '40px', left: '50%', transform: 'translateX(-50%)',
|
| 124 |
+
padding: '10px 28px', background: 'rgba(255,255,255,0.1)',
|
| 125 |
+
border: '1px solid rgba(255,255,255,0.3)', color: '#fff',
|
| 126 |
+
borderRadius: '20px', cursor: 'pointer', fontSize: '14px',
|
| 127 |
+
}}>
|
| 128 |
+
Cancel Calibration
|
| 129 |
+
</button>
|
| 130 |
+
</div>
|
| 131 |
+
);
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
const overlayStyle = {
|
| 135 |
+
position: 'fixed', top: 0, left: 0, width: '100vw', height: '100vh',
|
| 136 |
+
background: 'rgba(0, 0, 0, 0.92)', zIndex: 10000,
|
| 137 |
+
display: 'flex', alignItems: 'center', justifyContent: 'center',
|
| 138 |
+
};
|
| 139 |
+
|
| 140 |
+
const messageBoxStyle = {
|
| 141 |
+
textAlign: 'center', padding: '30px 40px',
|
| 142 |
+
background: 'rgba(30, 30, 50, 0.9)', borderRadius: '16px',
|
| 143 |
+
border: '1px solid rgba(255,255,255,0.1)',
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
export default CalibrationOverlay;
|
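The `strokeDasharray` constant `150.8` in the progress ring is the circumference of the r=24 SVG circle, so `progress * 150.8` draws exactly the swept fraction of the arc. A quick check of the number:

```python
import math

# circumference of the SVG progress ring (cx=30, cy=30, r=24)
RING_RADIUS = 24
circumference = 2 * math.pi * RING_RADIUS
print(round(circumference, 1))  # 150.8
```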
src/components/FocusPageLocal.jsx
CHANGED
@@ -1,4 +1,5 @@
 import React, { useState, useEffect, useRef } from 'react';
+import CalibrationOverlay from './CalibrationOverlay';

 const FLOW_STEPS = {
   intro: 'intro',

@@ -48,6 +49,9 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
   const [isStarting, setIsStarting] = useState(false);
   const [focusState, setFocusState] = useState(FOCUS_STATES.pending);
   const [cameraError, setCameraError] = useState('');
+  const [calibration, setCalibration] = useState(null);
+  const [l2csBoost, setL2csBoost] = useState(false);
+  const [l2csBoostAvailable, setL2csBoostAvailable] = useState(false);

   const localVideoRef = useRef(null);
   const displayCanvasRef = useRef(null);

@@ -127,6 +131,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
     setFocusState(FOCUS_STATES.pending);
     setCameraReady(false);
     if (originalOnSessionEnd) originalOnSessionEnd(summary);
+    videoManager.callbacks.onCalibrationUpdate = (cal) => {
+      setCalibration(cal && cal.active ? { ...cal } : null);
     };

     const statsInterval = setInterval(() => {

@@ -136,8 +142,10 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
     }, 1000);

     return () => {
-      videoManager
-
+      if (videoManager) {
+        videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
+        videoManager.callbacks.onCalibrationUpdate = null;
+      }
       clearInterval(statsInterval);
     };
   }, [videoManager]);

@@ -149,6 +157,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       .then((data) => {
         if (data.available) setAvailableModels(data.available);
         if (data.current) setCurrentModel(data.current);
+        if (data.l2cs_boost !== undefined) setL2csBoost(data.l2cs_boost);
+        if (data.l2cs_boost_available !== undefined) setL2csBoostAvailable(data.l2cs_boost_available);
       })
       .catch((err) => console.error('Failed to fetch models:', err));
   }, []);

@@ -204,6 +214,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       const result = await res.json();
       if (result.updated) {
         setCurrentModel(modelName);
+        setL2csBoostAvailable(modelName !== 'l2cs' && availableModels.includes('l2cs'));
+        if (modelName === 'l2cs') setL2csBoost(false);
       }
     } catch (err) {
       console.error('Failed to switch model:', err);

@@ -225,6 +237,21 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       console.error('Camera init error:', err);
     }
   };
+
+  const handleBoostToggle = async () => {
+    const next = !l2csBoost;
+    try {
+      const res = await fetch('/api/settings', {
+        method: 'PUT',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify({ l2cs_boost: next })
+      });
+      if (res.ok) setL2csBoost(next);
+    } catch (err) {
+      console.error('Failed to toggle L2CS boost:', err);
+    }
+  };
+
   const handleStart = async () => {
     try {
       setIsStarting(true);

@@ -697,6 +724,65 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
       }}>
         <span title="Server CPU">CPU: <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong></span>
         <span title="Server memory">RAM: <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong> ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</span>
+        <span style={{ color: '#aaa', fontSize: '13px', marginRight: '4px' }}>Model:</span>
+        {availableModels.map(name => (
+          <button
+            key={name}
+            onClick={() => handleModelChange(name)}
+            style={{
+              padding: '5px 14px',
+              borderRadius: '16px',
+              border: currentModel === name ? '2px solid #007BFF' : '1px solid #555',
+              background: currentModel === name ? '#007BFF' : 'transparent',
+              color: currentModel === name ? '#fff' : '#ccc',
+              fontSize: '12px',
+              fontWeight: currentModel === name ? 'bold' : 'normal',
+              cursor: 'pointer',
+              textTransform: 'uppercase',
+              transition: 'all 0.2s'
+            }}
+          >
+            {name}
+          </button>
+        ))}
+        {l2csBoostAvailable && currentModel !== 'l2cs' && (
+          <button
+            onClick={handleBoostToggle}
+            style={{
+              padding: '5px 14px',
+              borderRadius: '16px',
+              border: l2csBoost ? '2px solid #f59e0b' : '1px solid #555',
+              background: l2csBoost ? 'rgba(245, 158, 11, 0.15)' : 'transparent',
+              color: l2csBoost ? '#f59e0b' : '#888',
+              fontSize: '11px',
+              fontWeight: l2csBoost ? 'bold' : 'normal',
+              cursor: 'pointer',
+              transition: 'all 0.2s',
+              marginLeft: '4px',
+            }}
+          >
+            {l2csBoost ? 'GAZE ON' : 'GAZE'}
+          </button>
+        )}
+        {(currentModel === 'l2cs' || l2csBoost) && stats && stats.isStreaming && (
+          <button
+            onClick={() => videoManager && videoManager.startCalibration()}
+            style={{
+              padding: '5px 14px',
+              borderRadius: '16px',
+              border: '1px solid #4ade80',
+              background: 'transparent',
+              color: '#4ade80',
+              fontSize: '12px',
+              fontWeight: 'bold',
+              cursor: 'pointer',
+              transition: 'all 0.2s',
+              marginLeft: '4px',
+            }}
+          >
+            Calibrate
+          </button>
+        )}
       </section>
     )}

@@ -787,6 +873,58 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
         </section>
       </>
     ) : null}
+    ))}
+      </div>
+      <div id="timeline-line"></div>
+    </section>
+
+    {/* 4. Control Buttons */}
+    <section id="control-panel">
+      <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>
+        Start
+      </button>
+
+      <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>
+        Floating Window
+      </button>
+
+      <button
+        id="btn-preview"
+        className="action-btn"
+        style={{ backgroundColor: '#6c5ce7' }}
+        onClick={handlePreview}
+      >
+        Preview Result
+      </button>
+
+      <button id="btn-cam-stop" className="action-btn red" onClick={handleStop}>
+        Stop
+      </button>
+    </section>
+
+    {/* 5. Frame Control */}
+    <section id="frame-control">
+      <label htmlFor="frame-slider">Frame Rate (FPS)</label>
+      <input
+        type="range"
+        id="frame-slider"
+        min="10"
+        max="30"
+        value={currentFrame}
+        onChange={(e) => handleFrameChange(e.target.value)}
+      />
+      <input
+        type="number"
+        id="frame-input"
+        min="10"
+        max="30"
+        value={currentFrame}
+        onChange={(e) => handleFrameChange(e.target.value)}
+      />
+    </section>
+
+    {/* Calibration overlay (fixed fullscreen, must be outside overflow:hidden containers) */}
+    <CalibrationOverlay calibration={calibration} videoManager={videoManager} />
   </main>
 );
 }
src/utils/VideoManagerLocal.js
CHANGED
@@ -40,6 +40,17 @@ export class VideoManagerLocal {
     this.lastNotificationTime = null;
     this.notificationCooldown = 60000;

+    // Calibration state
+    this.calibration = {
+      active: false,
+      collecting: false,
+      target: null,
+      index: 0,
+      numPoints: 0,
+      done: false,
+      success: false,
+    };
+
     // Performance metrics
     this.stats = {
       framesSent: 0,

@@ -74,8 +85,8 @@ export class VideoManagerLocal {

     // Create a smaller capture canvas for faster encoding and transfer.
     this.canvas = document.createElement('canvas');
-    this.canvas.width =
-    this.canvas.height =
+    this.canvas.width = 640;
+    this.canvas.height = 480;

     console.log('Local camera initialized');
     return true;

@@ -247,7 +258,7 @@ export class VideoManagerLocal {
         this.ws.send(blob);
         this.stats.framesSent++;
       }
-    }, 'image/jpeg', 0.
+    }, 'image/jpeg', 0.75);
   } catch (error) {
     this._sendingBlob = false;
     console.error('Capture error:', error);

@@ -312,6 +323,19 @@ export class VideoManagerLocal {
       ctx.textAlign = 'left';
     }
   }
+  // Gaze pointer (L2CS + calibration)
+  if (data && data.gaze_x !== undefined && data.gaze_y !== undefined) {
+    const gx = data.gaze_x * w;
+    const gy = data.gaze_y * h;
+    ctx.beginPath();
+    ctx.arc(gx, gy, 8, 0, 2 * Math.PI);
+    ctx.fillStyle = data.on_screen ? 'rgba(0, 200, 255, 0.7)' : 'rgba(255, 80, 80, 0.5)';
+    ctx.fill();
+    ctx.strokeStyle = '#FFFFFF';
+    ctx.lineWidth = 2;
+    ctx.stroke();
+  }
+
   // Performance stats
   ctx.fillStyle = 'rgba(0,0,0,0.5)';
   ctx.fillRect(0, h - 25, w, 25);

@@ -380,6 +404,9 @@ export class VideoManagerLocal {
     mar: data.mar,
     sf: data.sf,
     se: data.se,
+    gaze_x: data.gaze_x,
+    gaze_y: data.gaze_y,
+    on_screen: data.on_screen,
   };
   this.drawDetectionResult(detectionData);
   break;

@@ -397,6 +424,51 @@ export class VideoManagerLocal {
   this.sessionStartTime = null;
   break;

+case 'calibration_started':
+  this.calibration = {
+    active: true,
+    collecting: true,
+    target: data.target,
+    index: data.index,
+    numPoints: data.num_points,
+    done: false,
+    success: false,
+  };
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  break;
+
+case 'calibration_point':
+  this.calibration.target = data.target;
+  this.calibration.index = data.index;
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  break;
+
+case 'calibration_done':
+  this.calibration.collecting = false;
+  this.calibration.done = true;
+  this.calibration.success = data.success;
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  setTimeout(() => {
+    this.calibration.active = false;
+    if (this.callbacks.onCalibrationUpdate) {
+      this.callbacks.onCalibrationUpdate({ ...this.calibration });
+    }
+  }, 2000);
+  break;
+
+case 'calibration_cancelled':
+  this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+  break;
+
 case 'error':
   console.error('Server error:', data.message);
   break;

@@ -406,6 +478,28 @@ export class VideoManagerLocal {
   }
 }

+startCalibration() {
+  if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+    this.ws.send(JSON.stringify({ type: 'calibration_start' }));
+  }
+}
+
+nextCalibrationPoint() {
+  if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+    this.ws.send(JSON.stringify({ type: 'calibration_next' }));
+  }
+}
+
+cancelCalibration() {
+  if (this.ws && this.ws.readyState === WebSocket.OPEN) {
+    this.ws.send(JSON.stringify({ type: 'calibration_cancel' }));
+  }
+  this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
+  if (this.callbacks.onCalibrationUpdate) {
+    this.callbacks.onCalibrationUpdate({ ...this.calibration });
+  }
+}
+
 // Face mesh landmark index groups (matches live_demo.py)
 static FACE_OVAL = [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109,10];
 static LEFT_EYE = [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246];
ui/pipeline.py
CHANGED
|
@@ -5,6 +5,7 @@ import glob
|
|
| 5 |
import json
|
| 6 |
import math
|
| 7 |
import os
|
|
|
|
| 8 |
import sys
|
| 9 |
|
| 10 |
import numpy as np
|
|
@@ -54,8 +55,12 @@ def _clip_features(vec):
|
|
| 54 |
|
| 55 |
|
| 56 |
class _OutputSmoother:
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
self._grace = grace_frames
|
| 60 |
self._score = 0.5
|
| 61 |
self._no_face = 0
|
|
@@ -64,14 +69,15 @@ class _OutputSmoother:
|
|
| 64 |
self._score = 0.5
|
| 65 |
self._no_face = 0
|
| 66 |
|
| 67 |
-
def update(self, raw_score
|
| 68 |
if face_detected:
|
| 69 |
self._no_face = 0
|
| 70 |
-
|
|
|
|
| 71 |
else:
|
| 72 |
self._no_face += 1
|
| 73 |
if self._no_face > self._grace:
|
| 74 |
-
self._score *= 0.
|
| 75 |
return self._score
|
| 76 |
|
| 77 |
|
|
@@ -645,3 +651,141 @@ class XGBoostPipeline:
|
|
| 645 |
|
| 646 |
def __exit__(self, *args):
|
| 647 |
self.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import json
|
| 6 |
import math
|
| 7 |
import os
|
| 8 |
+
import pathlib
|
| 9 |
import sys
|
| 10 |
|
| 11 |
import numpy as np
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
class _OutputSmoother:
|
| 58 |
+
# Asymmetric EMA: rises fast (recognise focus), falls slower (avoid flicker).
|
| 59 |
+
# Grace period holds score steady for a few frames when face is lost.
|
| 60 |
+
|
| 61 |
+
def __init__(self, alpha_up=0.55, alpha_down=0.45, grace_frames=10):
|
| 62 |
+
self._alpha_up = alpha_up
|
| 63 |
+
self._alpha_down = alpha_down
|
| 64 |
self._grace = grace_frames
|
| 65 |
self._score = 0.5
|
| 66 |
self._no_face = 0
|
|
|
|
| 69 |
self._score = 0.5
|
| 70 |
self._no_face = 0
|
| 71 |
|
| 72 |
+
def update(self, raw_score, face_detected):
|
| 73 |
if face_detected:
|
| 74 |
self._no_face = 0
|
| 75 |
+
alpha = self._alpha_up if raw_score > self._score else self._alpha_down
|
| 76 |
+
self._score += alpha * (raw_score - self._score)
|
| 77 |
else:
|
| 78 |
self._no_face += 1
|
| 79 |
if self._no_face > self._grace:
|
| 80 |
+
self._score *= 0.80
|
| 81 |
return self._score
|
| 82 |
|
| 83 |
|
|
|
|
| 651 |
|
| 652 |
def __exit__(self, *args):
|
| 653 |
self.close()
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def _resolve_l2cs_weights():
|
| 657 |
+
for p in [
|
| 658 |
+
os.path.join(_PROJECT_ROOT, "models", "L2CS-Net", "models", "L2CSNet_gaze360.pkl"),
|
| 659 |
+
os.path.join(_PROJECT_ROOT, "models", "L2CSNet_gaze360.pkl"),
|
| 660 |
+
os.path.join(_PROJECT_ROOT, "checkpoints", "L2CSNet_gaze360.pkl"),
|
| 661 |
+
]:
|
| 662 |
+
if os.path.isfile(p):
|
| 663 |
+
return p
|
| 664 |
+
return None
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def is_l2cs_weights_available():
|
| 668 |
+
return _resolve_l2cs_weights() is not None
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
class L2CSPipeline:
|
| 672 |
+
# Uses in-tree l2cs.Pipeline (RetinaFace + ResNet50) for gaze estimation
|
| 673 |
+
# and MediaPipe for head pose, EAR, MAR, and roll de-rotation.
|
| 674 |
+
|
| 675 |
+
YAW_THRESHOLD = 22.0
|
| 676 |
+
PITCH_THRESHOLD = 20.0
|
| 677 |
+
|
| 678 |
+
def __init__(self, weights_path=None, arch="ResNet50", device="cpu",
|
| 679 |
+
threshold=0.52, detector=None):
|
| 680 |
+
resolved = weights_path or _resolve_l2cs_weights()
|
| 681 |
+
if resolved is None or not os.path.isfile(resolved):
|
| 682 |
+
raise FileNotFoundError(
|
| 683 |
+
"L2CS weights not found. Place L2CSNet_gaze360.pkl in "
|
| 684 |
+
"models/L2CS-Net/models/ or checkpoints/"
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# add in-tree L2CS-Net to import path
|
| 688 |
+
l2cs_root = os.path.join(_PROJECT_ROOT, "models", "L2CS-Net")
|
| 689 |
+
if l2cs_root not in sys.path:
|
| 690 |
+
sys.path.insert(0, l2cs_root)
|
| 691 |
+
from l2cs import Pipeline as _L2CSPipeline
|
| 692 |
+
|
| 693 |
+
import torch
|
| 694 |
+
# bypass upstream select_device bug by constructing torch.device directly
|
| 695 |
+
self._pipeline = _L2CSPipeline(
|
| 696 |
+
weights=pathlib.Path(resolved), arch=arch, device=torch.device(device),
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
self._detector = detector or FaceMeshDetector()
|
| 700 |
+
self._owns_detector = detector is None
|
| 701 |
+
self._head_pose = HeadPoseEstimator()
|
| 702 |
+
self.head_pose = self._head_pose
|
| 703 |
+
self._eye_scorer = EyeBehaviourScorer()
|
| 704 |
+
self._threshold = threshold
|
| 705 |
+
self._smoother = _OutputSmoother()
|
| 706 |
+
|
| 707 |
+
print(
|
| 708 |
+
f"[L2CS] Loaded {resolved} | arch={arch} device={device} "
|
| 709 |
+
f"yaw_thresh={self.YAW_THRESHOLD} pitch_thresh={self.PITCH_THRESHOLD} "
|
| 710 |
+
f"threshold={threshold}"
|
| 711 |
+
)
|
| 712 |
+
|
| 713 |
+
@staticmethod
|
| 714 |
+
def _derotate_gaze(pitch_rad, yaw_rad, roll_deg):
|
| 715 |
+
# remove head roll so tilted-but-looking-at-screen reads as (0,0)
|
| 716 |
+
roll_rad = -math.radians(roll_deg)
|
| 717 |
+
cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
|
| 718 |
+
return (yaw_rad * sin_r + pitch_rad * cos_r,
|
| 719 |
+
yaw_rad * cos_r - pitch_rad * sin_r)
|
| 720 |
+
|
| 721 |
+    def process_frame(self, bgr_frame):
+        landmarks = self._detector.process(bgr_frame)
+        h, w = bgr_frame.shape[:2]
+
+        out = {
+            "landmarks": landmarks, "is_focused": False, "raw_score": 0.0,
+            "s_face": 0.0, "s_eye": 0.0, "gaze_pitch": None, "gaze_yaw": None,
+            "yaw": None, "pitch": None, "roll": None, "mar": None, "is_yawning": False,
+        }
+
+        # MediaPipe: head pose, eye/mouth scores
+        roll_deg = 0.0
+        if landmarks is not None:
+            angles = self._head_pose.estimate(landmarks, w, h)
+            if angles is not None:
+                out["yaw"], out["pitch"], out["roll"] = angles
+                roll_deg = angles[2]
+            out["s_face"] = self._head_pose.score(landmarks, w, h)
+            out["s_eye"] = self._eye_scorer.score(landmarks)
+            out["mar"] = compute_mar(landmarks)
+            out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
+
+        # L2CS gaze (uses its own RetinaFace detector internally)
+        results = self._pipeline.step(bgr_frame)
+
+        if results is None or results.pitch.shape[0] == 0:
+            smoothed = self._smoother.update(0.0, landmarks is not None)
+            out["raw_score"] = smoothed
+            out["is_focused"] = smoothed >= self._threshold
+            return out
+
+        pitch_rad = float(results.pitch[0])
+        yaw_rad = float(results.yaw[0])
+
+        pitch_rad, yaw_rad = self._derotate_gaze(pitch_rad, yaw_rad, roll_deg)
+        out["gaze_pitch"] = pitch_rad
+        out["gaze_yaw"] = yaw_rad
+
+        yaw_deg = abs(math.degrees(yaw_rad))
+        pitch_deg = abs(math.degrees(pitch_rad))
+
+        # Fall back to L2CS angles if MediaPipe didn't produce head pose.
+        # Explicit None checks: `or` would wrongly discard a legitimate 0.0 angle.
+        if out["yaw"] is None:
+            out["yaw"] = math.degrees(yaw_rad)
+        if out["pitch"] is None:
+            out["pitch"] = math.degrees(pitch_rad)
+
+        # cosine scoring: 1.0 at centre, 0.0 at threshold
+        yaw_t = min(yaw_deg / self.YAW_THRESHOLD, 1.0)
+        pitch_t = min(pitch_deg / self.PITCH_THRESHOLD, 1.0)
+        yaw_score = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
+        pitch_score = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
+        gaze_score = 0.55 * yaw_score + 0.45 * pitch_score
+
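The raised-cosine mapping can be sanity-checked numerically. A small sketch with illustrative threshold values (assumptions only; the real values live on the class as `YAW_THRESHOLD` / `PITCH_THRESHOLD` and are not shown in this hunk):

```python
import math

# Illustrative thresholds -- assumptions, not the class's actual values.
YAW_THRESHOLD = 25.0
PITCH_THRESHOLD = 20.0

def raised_cosine(angle_deg, threshold_deg):
    # 1.0 looking dead centre, 0.0 at or beyond the threshold, with a
    # smooth zero-slope roll-off at both ends of the range.
    t = min(abs(angle_deg) / threshold_deg, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * t))

def gaze_score(yaw_deg, pitch_deg):
    # Yaw is weighted slightly higher than pitch (0.55 vs 0.45),
    # matching the pipeline's weighting.
    return (0.55 * raised_cosine(yaw_deg, YAW_THRESHOLD)
            + 0.45 * raised_cosine(pitch_deg, PITCH_THRESHOLD))

print(gaze_score(0.0, 0.0))    # → 1.0 (dead centre)
print(gaze_score(12.5, 0.0))   # yaw halfway to threshold: ≈ 0.725
print(gaze_score(40.0, 30.0))  # both past threshold → 0.0
```

The zero-slope ends mean the score is insensitive to jitter near the centre and near the cut-off, which keeps the smoothed output stable.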
+        if out["is_yawning"]:
+            gaze_score = 0.0
+
+        out["raw_score"] = self._smoother.update(float(gaze_score), True)
+        out["is_focused"] = out["raw_score"] >= self._threshold
+        return out
+
+    def reset_session(self):
+        self._smoother.reset()
+
+    def close(self):
+        if self._owns_detector:
+            self._detector.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
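`_OutputSmoother` itself is not part of this hunk. A typical smoother behind an `update(score, face_present)` / `reset()` interface like this one is an exponential moving average; the class below is a hypothetical sketch of that pattern, not the committed implementation:

```python
class OutputSmoother:
    """Illustrative EMA smoother -- the real _OutputSmoother may differ."""

    def __init__(self, alpha=0.4):
        self.alpha = alpha   # higher alpha = faster reaction, less smoothing
        self._value = None

    def update(self, score, face_present):
        # When no face is visible, pull the score toward zero rather than
        # freezing the last value.
        if not face_present:
            score = 0.0
        if self._value is None:
            self._value = score  # first sample seeds the average
        else:
            self._value = self.alpha * score + (1.0 - self.alpha) * self._value
        return self._value

    def reset(self):
        self._value = None


s = OutputSmoother(alpha=0.5)
print(s.update(1.0, True))   # → 1.0 (first sample seeds the average)
print(s.update(0.0, True))   # → 0.5
s.reset()
print(s.update(0.8, True))   # → 0.8
```

Resetting on session start (as `reset_session` does above) matters because the EMA otherwise carries the previous session's score into the first frames of the next one.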