Abdelrahman Almatrooshi committed
Commit bb2a2db · 1 Parent(s): 6d9eb2d

Integrate L2CS-Net gaze estimation

- Add L2CS-Net in-tree (models/L2CS-Net/) with Gaze360 weights via Git LFS
- L2CSPipeline: ResNet50 gaze + MediaPipe head pose, roll de-rotation, cosine scoring
- 9-point polynomial gaze calibration with bias correction and IQR outlier filtering
- Gaze-eye fusion: calibrated screen coords + EAR for focus detection
- L2CS Boost mode: runs gaze alongside any base model (35/65 weight, veto at 0.38)
- Calibration UI: fullscreen overlay, auto-advance, progress ring
- Frontend: GAZE toggle, Calibrate button, gaze pointer dot on canvas
- Bumped capture resolution to 640x480 at JPEG quality 0.75
- Dockerfile: added git, CPU-only torch for HF Space deployment
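
The Boost-mode weighting in the notes above (35/65 split, veto at 0.38) reduces to a small scoring rule. A minimal sketch using the constants from `_process_frame_with_l2cs_boost` in this commit's `main.py` (the `fuse_scores` name and the standalone form are illustrative, not the actual function):

```python
# Sketch of the L2CS Boost fusion rule (constants from this commit's main.py:
# _BOOST_BASE_W = 0.35, _BOOST_L2CS_W = 0.65, _BOOST_VETO = 0.38).
def fuse_scores(base_score: float, l2cs_score: float) -> tuple[float, bool]:
    """Return (fused_score, is_focused) for one frame; scores are in [0, 1]."""
    if l2cs_score < 0.38:
        # Veto: gaze clearly off-screen forces not-focused regardless of the base model.
        return l2cs_score * 0.8, False
    fused = 0.35 * base_score + 0.65 * l2cs_score
    return fused, fused >= 0.52

# Head facing the screen (base 0.9) but eyes far away (L2CS 0.3): vetoed.
print(fuse_scores(0.9, 0.3))
# Moderate base score lifted by strong on-screen gaze.
print(fuse_scores(0.6, 0.8))
```

Because L2CS carries 65% of the weight and holds the veto, an off-screen gaze always wins over a confident base model, which is the point of the mode.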

Dockerfile CHANGED
@@ -7,7 +7,14 @@ ENV PYTHONUNBUFFERED=1
 
 WORKDIR /app
 
-RUN apt-get update && apt-get install -y --no-install-recommends libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev build-essential nodejs npm && rm -rf /var/lib/apt/lists/*
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    libglib2.0-0 libsm6 libxrender1 libxext6 libxcb1 libgl1 libgomp1 \
+    ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswscale-dev \
+    libavdevice-dev libopus-dev libvpx-dev libsrtp2-dev \
+    build-essential nodejs npm git \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
 COPY requirements.txt ./
 RUN pip install --no-cache-dir -r requirements.txt
README.md CHANGED
@@ -1,6 +1,6 @@
 # FocusGuard
 
-Webcam-based focus detection: MediaPipe face mesh → 17 features (EAR, gaze, head pose, PERCLOS, etc.) → MLP or XGBoost for focused/unfocused. React + FastAPI app with WebSocket video.
+Webcam-based focus detection: MediaPipe face mesh -> 17 features (EAR, gaze, head pose, PERCLOS, etc.) -> MLP or XGBoost for focused/unfocused. React + FastAPI app with WebSocket video.
 
 ## Project layout
 
@@ -9,10 +9,18 @@ Webcam-based focus detection: MediaPipe face mesh → 17 features (EAR, gaze, he
 ├── data_preparation/  loaders, split, scale
 ├── notebooks/         MLP/XGB training + LOPO
 ├── models/            face_mesh, head_pose, eye_scorer, train scripts
+│   ├── gaze_calibration.py  9-point polynomial gaze calibration
+│   ├── gaze_eye_fusion.py   Fuses calibrated gaze with eye openness
+│   └── L2CS-Net/            In-tree L2CS-Net repo with Gaze360 weights
 ├── checkpoints/       mlp_best.pt, xgboost_*_best.json, scalers
 ├── evaluation/        logs, plots, justify_thresholds
 ├── ui/                pipeline.py, live_demo.py
 ├── src/               React frontend
+│   ├── components/
+│   │   ├── FocusPageLocal.jsx      Main focus page (camera, controls, model selector)
+│   │   └── CalibrationOverlay.jsx  Fullscreen calibration UI
+│   └── utils/
+│       └── VideoManagerLocal.js    WebSocket client, frame capture, canvas rendering
 ├── static/            built frontend (after npm run build)
 ├── main.py, app.py    FastAPI backend
 ├── requirements.txt
@@ -70,19 +78,50 @@ python -m models.xgboost.train
 
 9 participants, 144,793 samples, 10 features, binary labels. Collect with `python -m models.collect_features --name <name>`. Data lives in `data/collected_<name>/`.
 
+## Models
+
+| Model | What it uses | Best for |
+|-------|--------------|----------|
+| **Geometric** | Head pose angles + eye aspect ratio (EAR) | Fast, no ML needed |
+| **XGBoost** | Trained classifier on head/eye features (600 trees, depth 8) | Balanced accuracy/speed |
+| **MLP** | Neural network on same features (64->32) | Higher accuracy |
+| **Hybrid** | Weighted MLP + Geometric ensemble | Best head-pose accuracy |
+| **L2CS** | Deep gaze estimation (ResNet50, Gaze360 weights) | Detects eye-only gaze shifts |
+
 ## Model numbers (15% test split)
 
 | Model | Accuracy | F1 | ROC-AUC |
 |-------|----------|-----|---------|
 | XGBoost (600 trees, depth 8) | 95.87% | 0.959 | 0.991 |
-| MLP (64→32) | 92.92% | 0.929 | 0.971 |
+| MLP (64->32) | 92.92% | 0.929 | 0.971 |
+
+## L2CS Gaze Tracking
+
+L2CS-Net predicts where your eyes are looking, not just where your head is pointed. This catches the scenario where your head faces the screen but your eyes wander.
+
+### Standalone mode
+Select **L2CS** as the model - it handles everything.
+
+### Boost mode
+Select any other model, then click the **GAZE** toggle. L2CS runs alongside the base model:
+- Base model handles head pose and eye openness (35% weight)
+- L2CS handles gaze direction (65% weight)
+- If L2CS detects gaze is clearly off-screen, it **vetoes** the base model regardless of score
+
+### Calibration
+After enabling L2CS or Gaze Boost, click **Calibrate** while a session is running:
+1. A fullscreen overlay shows 9 target dots (3x3 grid)
+2. Look at each dot as the progress ring fills
+3. The first dot (centre) sets your baseline gaze offset
+4. After all 9 points, a polynomial model maps your gaze angles to screen coordinates
+5. A cyan tracking dot appears on the video showing where you're looking
 
 ## Pipeline
 
 1. Face mesh (MediaPipe 478 pts)
-2. Head pose → yaw, pitch, roll, scores, gaze offset
-3. Eye scorer → EAR, gaze ratio, MAR
-4. Temporal → PERCLOS, blink rate, yawn
-5. 10-d vector → MLP or XGBoost → focused / unfocused
+2. Head pose -> yaw, pitch, roll, scores, gaze offset
+3. Eye scorer -> EAR, gaze ratio, MAR
+4. Temporal -> PERCLOS, blink rate, yawn
+5. 10-d vector -> MLP or XGBoost -> focused / unfocused
 
-**Stack:** FastAPI, aiosqlite, React/Vite, PyTorch, XGBoost, MediaPipe, OpenCV.
+**Stack:** FastAPI, aiosqlite, React/Vite, PyTorch, XGBoost, MediaPipe, OpenCV, L2CS-Net.
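
The calibration step described above ("a polynomial model maps your gaze angles to screen coordinates") plus the commit's IQR outlier filtering can be illustrated with a generic least-squares sketch. `models/gaze_calibration.py` is not shown in this diff, so the function names, the quadratic feature set, and the 1.5*IQR fence below are assumptions about the general technique, not the actual implementation:

```python
import numpy as np

def iqr_filter(samples: np.ndarray) -> np.ndarray:
    """Drop samples outside a 1.5*IQR fence on each axis (fence width assumed)."""
    q1, q3 = np.percentile(samples, [25, 75], axis=0)
    iqr = q3 - q1
    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    keep = np.all((samples >= lo) & (samples <= hi), axis=1)
    return samples[keep]

def fit_gaze_map(angles: np.ndarray, screen_xy: np.ndarray) -> np.ndarray:
    """Least-squares quadratic map from (yaw, pitch) to screen (x, y).

    angles:    (N, 2) gaze angles collected while looking at the 9 targets
    screen_xy: (N, 2) normalized screen coordinates of those targets
    """
    yaw, pitch = angles[:, 0], angles[:, 1]
    # Design matrix: [1, yaw, pitch, yaw*pitch, yaw^2, pitch^2]
    A = np.column_stack([np.ones_like(yaw), yaw, pitch, yaw * pitch, yaw**2, pitch**2])
    coeffs, *_ = np.linalg.lstsq(A, screen_xy, rcond=None)
    return coeffs  # shape (6, 2)

def predict(coeffs: np.ndarray, yaw: float, pitch: float) -> np.ndarray:
    feats = np.array([1.0, yaw, pitch, yaw * pitch, yaw**2, pitch**2])
    return feats @ coeffs  # -> [x, y]
```

With a synthetic linear mapping (x = 0.5 + 0.01*yaw, y = 0.5 - 0.01*pitch) over a 3x3 grid of angles, the fit recovers the map exactly, and `predict(coeffs, 5, 5)` lands at (0.55, 0.45). Nine targets is the minimum comfortable sample count for six quadratic coefficients, which is why a 3x3 grid is the standard layout for this kind of calibration.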
checkpoints/L2CSNet_gaze360.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
+size 95849977
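
The three `+` lines above are the entire committed file: a Git LFS pointer (spec v1) standing in for the 95 MB weights, which is just key/value lines. A minimal parser sketch (the `parse_lfs_pointer` helper is illustrative, not part of the repo):

```python
def parse_lfs_pointer(text: str) -> dict:
    """Split a Git LFS pointer file into its key/value fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665\n"
    "size 95849977\n"
)
info = parse_lfs_pointer(pointer)
print(int(info["size"]) / 1024 / 1024)  # the weights are roughly 91 MiB
```

Cloning without `git lfs install` leaves you with this 3-line pointer instead of the checkpoint, which is why the repo also ships `download_l2cs_weights.py` as a fallback.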
download_l2cs_weights.py ADDED
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+# Downloads L2CS-Net Gaze360 weights into checkpoints/
+
+import os
+import sys
+
+CHECKPOINTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints")
+DEST = os.path.join(CHECKPOINTS_DIR, "L2CSNet_gaze360.pkl")
+GDRIVE_ID = "1dL2Jokb19_SBSHAhKHOxJsmYs5-GoyLo"
+
+
+def main():
+    if os.path.isfile(DEST):
+        print(f"[OK] Weights already at {DEST}")
+        return
+
+    try:
+        import gdown
+    except ImportError:
+        print("gdown not installed. Run: pip install gdown")
+        sys.exit(1)
+
+    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
+    print(f"Downloading L2CS-Net weights to {DEST} ...")
+    gdown.download(f"https://drive.google.com/uc?id={GDRIVE_ID}", DEST, quiet=False)
+
+    if os.path.isfile(DEST):
+        print(f"[OK] Downloaded ({os.path.getsize(DEST) / 1024 / 1024:.1f} MB)")
+    else:
+        print("[ERR] Download failed. Manual download:")
+        print("      https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd")
+        print(f"      Place L2CSNet_gaze360.pkl in {CHECKPOINTS_DIR}/")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
main.py CHANGED
@@ -25,7 +25,10 @@ from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack
25
  from av import VideoFrame
26
 
27
  from mediapipe.tasks.python.vision import FaceLandmarksConnections
28
- from ui.pipeline import FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline
 
 
 
29
  from models.face_mesh import FaceMeshDetector
30
 
31
  # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
@@ -212,17 +215,7 @@ app.add_middleware(
212
  db_path = "focus_guard.db"
213
  pcs = set()
214
  _cached_model_name = "mlp"
215
- pipelines = {
216
- "geometric": None,
217
- "mlp": None,
218
- "hybrid": None,
219
- "xgboost": None,
220
- }
221
- _inference_executor = concurrent.futures.ThreadPoolExecutor(
222
- max_workers=4,
223
- thread_name_prefix="inference",
224
- )
225
- _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost")}
226
 
227
  async def _wait_for_ice_gathering(pc: RTCPeerConnection):
228
  if pc.iceGatheringState == "complete":
@@ -302,6 +295,7 @@ class SettingsUpdate(BaseModel):
302
  notification_threshold: Optional[int] = None
303
  frame_rate: Optional[int] = None
304
  model_name: Optional[str] = None
 
305
 
306
  class VideoTransformTrack(VideoStreamTrack):
307
  def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
@@ -329,6 +323,8 @@ class VideoTransformTrack(VideoStreamTrack):
329
  self.last_inference_time = now
330
 
331
  model_name = _cached_model_name
 
 
332
  if model_name not in pipelines or pipelines.get(model_name) is None:
333
  model_name = 'mlp'
334
  active_pipeline = pipelines.get(model_name)
@@ -513,10 +509,56 @@ class _EventBuffer:
513
  except Exception as e:
514
  print(f"[DB] Flush error: {e}")
515
 
516
- def _process_frame_safe(pipeline, frame, model_name: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  with _pipeline_locks[model_name]:
518
  return pipeline.process_frame(frame)
519
 
 
520
  def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
521
  if preferred and preferred in pipelines and pipelines.get(preferred) is not None:
522
  return preferred
@@ -525,6 +567,96 @@ def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
525
  return name
526
  return None
527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  # ================ WEBRTC SIGNALING ================
529
 
530
  @app.post("/api/webrtc/offer")
@@ -590,14 +722,19 @@ async def webrtc_offer(offer: dict):
590
 
591
  @app.websocket("/ws/video")
592
  async def websocket_endpoint(websocket: WebSocket):
 
 
 
593
  await websocket.accept()
594
  session_id = None
595
  frame_count = 0
596
  running = True
597
  event_buffer = _EventBuffer(flush_interval=2.0)
598
 
599
- # Latest frame slot: keep only the newest frame and drop stale ones.
600
- # Using a dict so nested functions can mutate without nonlocal issues.
 
 
601
  _slot = {"frame": None}
602
  _frame_ready = asyncio.Event()
603
 
@@ -628,7 +765,6 @@ async def websocket_endpoint(websocket: WebSocket):
628
  data = json.loads(text)
629
 
630
  if data["type"] == "frame":
631
- # Legacy base64 path (fallback)
632
  _slot["frame"] = base64.b64decode(data["image"])
633
  _frame_ready.set()
634
 
@@ -647,6 +783,47 @@ async def websocket_endpoint(websocket: WebSocket):
647
  if summary:
648
  await websocket.send_json({"type": "session_ended", "summary": summary})
649
  session_id = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  except WebSocketDisconnect:
651
  running = False
652
  _frame_ready.set()
@@ -665,7 +842,6 @@ async def websocket_endpoint(websocket: WebSocket):
665
  if not running:
666
  return
667
 
668
- # Grab latest frame and clear slot
669
  raw = _slot["frame"]
670
  _slot["frame"] = None
671
  if raw is None:
@@ -678,36 +854,87 @@ async def websocket_endpoint(websocket: WebSocket):
678
  continue
679
  frame = cv2.resize(frame, (640, 480))
680
 
681
- model_name = _first_available_pipeline_name(_cached_model_name)
682
- active_pipeline = pipelines.get(model_name) if model_name is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
 
684
  landmarks_list = None
 
685
  if active_pipeline is not None:
686
- out = await loop.run_in_executor(
687
- _inference_executor,
688
- _process_frame_safe,
689
- active_pipeline,
690
- frame,
691
- model_name,
692
- )
 
 
 
 
 
 
 
 
 
693
  is_focused = out["is_focused"]
694
  confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
695
 
696
  lm = out.get("landmarks")
697
  if lm is not None:
698
- # Send all 478 landmarks as flat array for tessellation drawing
699
  landmarks_list = [
700
  [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
701
  for i in range(lm.shape[0])
702
  ]
703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
704
  if session_id:
705
- event_buffer.add(session_id, is_focused, confidence, {
706
  "s_face": out.get("s_face", 0.0),
707
  "s_eye": out.get("s_eye", 0.0),
708
  "mar": out.get("mar", 0.0),
709
  "model": model_name,
710
- })
 
711
  else:
712
  is_focused = False
713
  confidence = 0.0
@@ -721,8 +948,7 @@ async def websocket_endpoint(websocket: WebSocket):
721
  "fc": frame_count,
722
  "frame_count": frame_count,
723
  }
724
- if active_pipeline is not None:
725
- # Send detailed metrics for HUD
726
  if out.get("yaw") is not None:
727
  resp["yaw"] = round(out["yaw"], 1)
728
  resp["pitch"] = round(out["pitch"], 1)
@@ -731,6 +957,24 @@ async def websocket_endpoint(websocket: WebSocket):
731
  resp["mar"] = round(out["mar"], 3)
732
  resp["sf"] = round(out.get("s_face", 0), 3)
733
  resp["se"] = round(out.get("s_eye", 0), 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
734
  if landmarks_list is not None:
735
  resp["lm"] = landmarks_list
736
  await websocket.send_json(resp)
@@ -863,8 +1107,9 @@ async def get_settings():
863
  db.row_factory = aiosqlite.Row
864
  cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
865
  row = await cursor.fetchone()
866
- if row: return dict(row)
867
- else: return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
 
868
 
869
  @app.put("/api/settings")
870
  async def update_settings(settings: SettingsUpdate):
@@ -889,12 +1134,28 @@ async def update_settings(settings: SettingsUpdate):
889
  if settings.frame_rate is not None:
890
  updates.append("frame_rate = ?")
891
  params.append(max(5, min(60, settings.frame_rate)))
892
- if settings.model_name is not None and settings.model_name in pipelines and pipelines[settings.model_name] is not None:
 
 
 
 
 
 
 
893
  updates.append("model_name = ?")
894
  params.append(settings.model_name)
895
  global _cached_model_name
896
  _cached_model_name = settings.model_name
897
 
 
 
 
 
 
 
 
 
 
898
  if updates:
899
  query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
900
  await db.execute(query, params)
@@ -946,15 +1207,55 @@ async def get_stats_summary():
946
 
947
  @app.get("/api/models")
948
  async def get_available_models():
949
- """Return list of loaded model names and which is currently active."""
950
- available = [name for name, p in pipelines.items() if p is not None]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
  async with aiosqlite.connect(db_path) as db:
952
  cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
953
  row = await cursor.fetchone()
954
  current = row[0] if row else "mlp"
955
  if current not in available and available:
956
  current = available[0]
957
- return {"available": available, "current": current}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
958
 
959
  @app.get("/api/mesh-topology")
960
  async def get_mesh_topology():
 
25
  from av import VideoFrame
26
 
27
  from mediapipe.tasks.python.vision import FaceLandmarksConnections
28
+ from ui.pipeline import (
29
+ FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline,
30
+ L2CSPipeline, is_l2cs_weights_available,
31
+ )
32
  from models.face_mesh import FaceMeshDetector
33
 
34
  # ================ FACE MESH DRAWING (server-side, for WebRTC) ================
 
215
  db_path = "focus_guard.db"
216
  pcs = set()
217
  _cached_model_name = "mlp"
218
+ _l2cs_boost_enabled = False
 
 
 
 
 
 
 
 
 
 
219
 
220
  async def _wait_for_ice_gathering(pc: RTCPeerConnection):
221
  if pc.iceGatheringState == "complete":
 
295
  notification_threshold: Optional[int] = None
296
  frame_rate: Optional[int] = None
297
  model_name: Optional[str] = None
298
+ l2cs_boost: Optional[bool] = None
299
 
300
  class VideoTransformTrack(VideoStreamTrack):
301
  def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
 
323
  self.last_inference_time = now
324
 
325
  model_name = _cached_model_name
326
+ if model_name == "l2cs" and pipelines.get("l2cs") is None:
327
+ _ensure_l2cs()
328
  if model_name not in pipelines or pipelines.get(model_name) is None:
329
  model_name = 'mlp'
330
  active_pipeline = pipelines.get(model_name)
 
509
  except Exception as e:
510
  print(f"[DB] Flush error: {e}")
511
 
512
+ # ================ STARTUP/SHUTDOWN ================
513
+
514
+ pipelines = {
515
+ "geometric": None,
516
+ "mlp": None,
517
+ "hybrid": None,
518
+ "xgboost": None,
519
+ "l2cs": None,
520
+ }
521
+
522
+ # Thread pool for CPU-bound inference so the event loop stays responsive.
523
+ _inference_executor = concurrent.futures.ThreadPoolExecutor(
524
+ max_workers=4,
525
+ thread_name_prefix="inference",
526
+ )
527
+ # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when
528
+ # multiple frames are processed in parallel by the thread pool.
529
+ _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")}
530
+
531
+ _l2cs_load_lock = threading.Lock()
532
+ _l2cs_error: str | None = None
533
+
534
+
535
+ def _ensure_l2cs():
536
+ # lazy-load L2CS on first use, double-checked locking
537
+ global _l2cs_error
538
+ if pipelines["l2cs"] is not None:
539
+ return True
540
+ with _l2cs_load_lock:
541
+ if pipelines["l2cs"] is not None:
542
+ return True
543
+ if not is_l2cs_weights_available():
544
+ _l2cs_error = "Weights not found"
545
+ return False
546
+ try:
547
+ pipelines["l2cs"] = L2CSPipeline()
548
+ _l2cs_error = None
549
+ print("[OK] L2CSPipeline lazy-loaded")
550
+ return True
551
+ except Exception as e:
552
+ _l2cs_error = str(e)
553
+ print(f"[ERR] L2CS lazy-load failed: {e}")
554
+ return False
555
+
556
+
557
+ def _process_frame_safe(pipeline, frame, model_name):
558
  with _pipeline_locks[model_name]:
559
  return pipeline.process_frame(frame)
560
 
561
+
562
  def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
563
  if preferred and preferred in pipelines and pipelines.get(preferred) is not None:
564
  return preferred
 
567
  return name
568
  return None
569
 
570
+
571
+ _BOOST_BASE_W = 0.35
572
+ _BOOST_L2CS_W = 0.65
573
+ _BOOST_VETO = 0.38 # L2CS below this -> forced not-focused
574
+
575
+
576
+ def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
577
+ # run base model
578
+ with _pipeline_locks[base_model_name]:
579
+ base_out = base_pipeline.process_frame(frame)
580
+
581
+ l2cs_pipe = pipelines.get("l2cs")
582
+ if l2cs_pipe is None:
583
+ base_out["boost_active"] = False
584
+ return base_out
585
+
586
+ # run L2CS
587
+ with _pipeline_locks["l2cs"]:
588
+ l2cs_out = l2cs_pipe.process_frame(frame)
589
+
590
+ base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
591
+ l2cs_score = l2cs_out.get("raw_score", 0.0)
592
+
593
+ # veto: gaze clearly off-screen overrides base model
594
+ if l2cs_score < _BOOST_VETO:
595
+ fused_score = l2cs_score * 0.8
596
+ is_focused = False
597
+ else:
598
+ fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
599
+ is_focused = fused_score >= 0.52
600
+
601
+ base_out["raw_score"] = fused_score
602
+ base_out["is_focused"] = is_focused
603
+ base_out["boost_active"] = True
604
+ base_out["base_score"] = round(base_score, 3)
605
+ base_out["l2cs_score"] = round(l2cs_score, 3)
606
+
607
+ if l2cs_out.get("gaze_yaw") is not None:
608
+ base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
609
+ base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
610
+
611
+ return base_out
612
+
613
+ @app.on_event("startup")
614
+ async def startup_event():
615
+ global pipelines, _cached_model_name
616
+ print(" Starting Focus Guard API...")
617
+ await init_database()
618
+ # Load cached model name from DB
619
+ async with aiosqlite.connect(db_path) as db:
620
+ cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
621
+ row = await cursor.fetchone()
622
+ if row:
623
+ _cached_model_name = row[0]
624
+ print("[OK] Database initialized")
625
+
626
+ try:
627
+ pipelines["geometric"] = FaceMeshPipeline()
628
+ print("[OK] FaceMeshPipeline (geometric) loaded")
629
+ except Exception as e:
630
+ print(f"[WARN] FaceMeshPipeline unavailable: {e}")
631
+
632
+ try:
633
+ pipelines["mlp"] = MLPPipeline()
634
+ print("[OK] MLPPipeline loaded")
635
+ except Exception as e:
636
+ print(f"[ERR] Failed to load MLPPipeline: {e}")
637
+
638
+ try:
639
+ pipelines["hybrid"] = HybridFocusPipeline()
640
+ print("[OK] HybridFocusPipeline loaded")
641
+ except Exception as e:
642
+ print(f"[WARN] HybridFocusPipeline unavailable: {e}")
643
+
644
+ try:
645
+ pipelines["xgboost"] = XGBoostPipeline()
646
+ print("[OK] XGBoostPipeline loaded")
647
+ except Exception as e:
648
+ print(f"[ERR] Failed to load XGBoostPipeline: {e}")
649
+
650
+ if is_l2cs_weights_available():
651
+ print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
652
+ else:
653
+ print("[WARN] L2CS weights not found — l2cs model unavailable")
654
+
655
+ @app.on_event("shutdown")
656
+ async def shutdown_event():
657
+ _inference_executor.shutdown(wait=False)
658
+ print(" Shutting down Focus Guard API...")
659
+
660
  # ================ WEBRTC SIGNALING ================
661
 
662
  @app.post("/api/webrtc/offer")
 
722
 
723
  @app.websocket("/ws/video")
724
  async def websocket_endpoint(websocket: WebSocket):
725
+ from models.gaze_calibration import GazeCalibration
726
+ from models.gaze_eye_fusion import GazeEyeFusion
727
+
728
  await websocket.accept()
729
  session_id = None
730
  frame_count = 0
731
  running = True
732
  event_buffer = _EventBuffer(flush_interval=2.0)
733
 
734
+ # Calibration state (per-connection)
735
+ _cal: dict = {"cal": None, "collecting": False, "fusion": None}
736
+
737
+ # Latest frame slot — only the most recent frame is kept, older ones are dropped.
738
  _slot = {"frame": None}
739
  _frame_ready = asyncio.Event()
740
 
 
765
  data = json.loads(text)
766
 
767
  if data["type"] == "frame":
 
768
  _slot["frame"] = base64.b64decode(data["image"])
769
  _frame_ready.set()
770
 
 
783
  if summary:
784
  await websocket.send_json({"type": "session_ended", "summary": summary})
785
  session_id = None
786
+
787
+ # ---- Calibration commands ----
788
+ elif data["type"] == "calibration_start":
789
+ loop = asyncio.get_event_loop()
790
+ await loop.run_in_executor(_inference_executor, _ensure_l2cs)
791
+ _cal["cal"] = GazeCalibration()
792
+ _cal["collecting"] = True
793
+ _cal["fusion"] = None
794
+ cal = _cal["cal"]
795
+ await websocket.send_json({
796
+ "type": "calibration_started",
797
+ "num_points": cal.num_points,
798
+ "target": list(cal.current_target),
799
+ "index": cal.current_index,
800
+ })
801
+
802
+ elif data["type"] == "calibration_next":
803
+ cal = _cal.get("cal")
804
+ if cal is not None:
805
+ more = cal.advance()
806
+ if more:
807
+ await websocket.send_json({
808
+ "type": "calibration_point",
809
+ "target": list(cal.current_target),
810
+ "index": cal.current_index,
811
+ })
812
+ else:
813
+ _cal["collecting"] = False
814
+ ok = cal.fit()
815
+ if ok:
816
+ _cal["fusion"] = GazeEyeFusion(cal)
817
+ await websocket.send_json({"type": "calibration_done", "success": True})
818
+ else:
819
+ await websocket.send_json({"type": "calibration_done", "success": False, "error": "Not enough samples"})
820
+
821
+ elif data["type"] == "calibration_cancel":
822
+ _cal["cal"] = None
823
+ _cal["collecting"] = False
824
+ _cal["fusion"] = None
825
+ await websocket.send_json({"type": "calibration_cancelled"})
826
+
827
  except WebSocketDisconnect:
828
  running = False
829
  _frame_ready.set()
 
842
  if not running:
843
  return
844
 
 
845
  raw = _slot["frame"]
846
  _slot["frame"] = None
847
  if raw is None:
 
854
  continue
855
  frame = cv2.resize(frame, (640, 480))
856
 
857
+ # During calibration collection, always use L2CS
858
+ collecting = _cal.get("collecting", False)
859
+ if collecting:
860
+ if pipelines.get("l2cs") is None:
861
+ await loop.run_in_executor(_inference_executor, _ensure_l2cs)
862
+ use_model = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name
863
+ else:
864
+ use_model = _cached_model_name
865
+
866
+ model_name = use_model
867
+ if model_name == "l2cs" and pipelines.get("l2cs") is None:
868
+ await loop.run_in_executor(_inference_executor, _ensure_l2cs)
869
+ if model_name not in pipelines or pipelines.get(model_name) is None:
870
+ model_name = "mlp"
871
+ active_pipeline = pipelines.get(model_name)
872
+
873
+ # L2CS boost: run L2CS alongside base model
874
+ use_boost = (
875
+ _l2cs_boost_enabled
876
+ and model_name != "l2cs"
877
+ and pipelines.get("l2cs") is not None
878
+ and not collecting
879
+ )
880
 
881
  landmarks_list = None
882
+ out = None
883
  if active_pipeline is not None:
884
+ if use_boost:
885
+ out = await loop.run_in_executor(
886
+ _inference_executor,
887
+ _process_frame_with_l2cs_boost,
888
+ active_pipeline,
889
+ frame,
890
+ model_name,
891
+ )
892
+ else:
893
+ out = await loop.run_in_executor(
894
+ _inference_executor,
895
+ _process_frame_safe,
896
+ active_pipeline,
897
+ frame,
898
+ model_name,
899
+ )
900
  is_focused = out["is_focused"]
901
  confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
902
 
903
  lm = out.get("landmarks")
904
  if lm is not None:
 
905
  landmarks_list = [
906
  [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
907
  for i in range(lm.shape[0])
908
  ]
909
 
910
+ # Calibration sample collection (L2CS gaze angles)
911
+ if collecting and _cal.get("cal") is not None:
912
+ pipe_yaw = out.get("gaze_yaw")
913
+ pipe_pitch = out.get("gaze_pitch")
914
+ if pipe_yaw is not None and pipe_pitch is not None:
915
+ _cal["cal"].collect_sample(pipe_yaw, pipe_pitch)
916
+
917
+ # Gaze fusion (when L2CS active + calibration fitted)
918
+ fusion = _cal.get("fusion")
919
+ if (
920
+ fusion is not None
921
+ and model_name == "l2cs"
922
+ and out.get("gaze_yaw") is not None
923
+ ):
924
+ fuse = fusion.update(
925
+ out["gaze_yaw"], out["gaze_pitch"], lm
926
+ )
927
+ is_focused = fuse["focused"]
928
+ confidence = fuse["focus_score"]
929
+
930
  if session_id:
931
+ metadata = {
932
  "s_face": out.get("s_face", 0.0),
933
  "s_eye": out.get("s_eye", 0.0),
934
  "mar": out.get("mar", 0.0),
935
  "model": model_name,
936
+ }
937
+ event_buffer.add(session_id, is_focused, confidence, metadata)
938
  else:
939
  is_focused = False
940
  confidence = 0.0
 
948
  "fc": frame_count,
949
  "frame_count": frame_count,
950
  }
951
+ if out is not None:
 
952
  if out.get("yaw") is not None:
953
  resp["yaw"] = round(out["yaw"], 1)
954
  resp["pitch"] = round(out["pitch"], 1)
 
957
  resp["mar"] = round(out["mar"], 3)
958
  resp["sf"] = round(out.get("s_face", 0), 3)
959
  resp["se"] = round(out.get("s_eye", 0), 3)
960
+
961
+ # Gaze fusion fields (L2CS standalone or boost mode)
962
+ fusion = _cal.get("fusion")
963
+ has_gaze = out.get("gaze_yaw") is not None
964
+ if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost):
965
+ fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], out.get("landmarks"))
966
+ resp["gaze_x"] = fuse["gaze_x"]
967
+ resp["gaze_y"] = fuse["gaze_y"]
968
+ resp["on_screen"] = fuse["on_screen"]
969
+ if model_name == "l2cs":
970
+ resp["focused"] = fuse["focused"]
971
+ resp["confidence"] = round(fuse["focus_score"], 3)
972
+
973
+ if out.get("boost_active"):
974
+ resp["boost"] = True
975
+ resp["base_score"] = out.get("base_score", 0)
976
+ resp["l2cs_score"] = out.get("l2cs_score", 0)
977
+
978
  if landmarks_list is not None:
979
  resp["lm"] = landmarks_list
980
  await websocket.send_json(resp)
 
1107
  db.row_factory = aiosqlite.Row
1108
  cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
1109
  row = await cursor.fetchone()
1110
+ result = dict(row) if row else {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'mlp'}
1111
+ result['l2cs_boost'] = _l2cs_boost_enabled
1112
+ return result
1113
 
1114
  @app.put("/api/settings")
1115
  async def update_settings(settings: SettingsUpdate):
 
1134
  if settings.frame_rate is not None:
1135
  updates.append("frame_rate = ?")
1136
  params.append(max(5, min(60, settings.frame_rate)))
1137
+ if settings.model_name is not None and settings.model_name in pipelines:
1138
+ if settings.model_name == "l2cs":
1139
+ loop = asyncio.get_event_loop()
1140
+ loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
1141
+ if not loaded:
1142
+ raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
1143
+ elif pipelines[settings.model_name] is None:
1144
+ raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
1145
  updates.append("model_name = ?")
1146
  params.append(settings.model_name)
1147
  global _cached_model_name
1148
  _cached_model_name = settings.model_name
1149
 
1150
+ if settings.l2cs_boost is not None:
1151
+ global _l2cs_boost_enabled
1152
+ if settings.l2cs_boost:
1153
+ loop = asyncio.get_event_loop()
1154
+ loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
1155
+ if not loaded:
1156
+ raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")
1157
+ _l2cs_boost_enabled = settings.l2cs_boost
1158
+
1159
  if updates:
1160
  query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
1161
  await db.execute(query, params)
 
 
  @app.get("/api/models")
  async def get_available_models():
+     """Return model names, statuses, and which is currently active."""
+     statuses = {}
+     errors = {}
+     available = []
+     for name, p in pipelines.items():
+         if name == "l2cs":
+             if p is not None:
+                 statuses[name] = "ready"
+                 available.append(name)
+             elif is_l2cs_weights_available():
+                 statuses[name] = "lazy"
+                 available.append(name)
+             elif _l2cs_error:
+                 statuses[name] = "error"
+                 errors[name] = _l2cs_error
+             else:
+                 statuses[name] = "unavailable"
+         elif p is not None:
+             statuses[name] = "ready"
+             available.append(name)
+         else:
+             statuses[name] = "unavailable"
      async with aiosqlite.connect(db_path) as db:
          cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
          row = await cursor.fetchone()
          current = row[0] if row else "mlp"
          if current not in available and available:
              current = available[0]
+     l2cs_boost_available = (
+         statuses.get("l2cs") in ("ready", "lazy") and current != "l2cs"
+     )
+     return {
+         "available": available,
+         "current": current,
+         "statuses": statuses,
+         "errors": errors,
+         "l2cs_boost": _l2cs_boost_enabled,
+         "l2cs_boost_available": l2cs_boost_available,
+     }
+
+ @app.get("/api/l2cs/status")
+ async def l2cs_status():
+     """L2CS-specific status: weights available, loaded, and calibration info."""
+     loaded = pipelines.get("l2cs") is not None
+     return {
+         "weights_available": is_l2cs_weights_available(),
+         "loaded": loaded,
+         "error": _l2cs_error,
+     }
 
  @app.get("/api/mesh-topology")
  async def get_mesh_topology():
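The status resolution in `/api/models` is pure logic over the `pipelines` map and can be exercised standalone. A minimal sketch follows; the inputs are hypothetical stand-ins for the real app state (`pipelines`, `is_l2cs_weights_available`, `_l2cs_error`), not the actual globals:

```python
# Sketch of the /api/models status resolution. The function mirrors the
# endpoint's branching; the inputs here are illustrative stand-ins.
def resolve_statuses(pipelines, l2cs_weights_available, l2cs_error):
    statuses, errors, available = {}, {}, []
    for name, p in pipelines.items():
        if name == "l2cs":
            if p is not None:
                statuses[name] = "ready"
                available.append(name)
            elif l2cs_weights_available:
                statuses[name] = "lazy"  # weights on disk, loads on first use
                available.append(name)
            elif l2cs_error:
                statuses[name] = "error"
                errors[name] = l2cs_error
            else:
                statuses[name] = "unavailable"
        elif p is not None:
            statuses[name] = "ready"
            available.append(name)
        else:
            statuses[name] = "unavailable"
    return statuses, errors, available

# L2CS not yet loaded but weights present -> advertised as "lazy"
statuses, errors, available = resolve_statuses(
    {"mlp": object(), "xgb": None, "l2cs": None},
    l2cs_weights_available=True, l2cs_error=None)
print(statuses)  # {'mlp': 'ready', 'xgb': 'unavailable', 'l2cs': 'lazy'}
```

This is why `l2cs_boost_available` checks for either `"ready"` or `"lazy"`: boost can trigger the same lazy load that selecting the model does.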
models/L2CS-Net/.gitignore ADDED
@@ -0,0 +1,140 @@
+ # Ignore the test data - sensitive
+ datasets/
+ evaluation/
+ output/
+
+ # Ignore debugging configurations
+ /.vscode
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # Ignore other files
+ my.secrets
models/L2CS-Net/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Ahmed Abdelrahman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
models/L2CS-Net/README.md ADDED
@@ -0,0 +1,148 @@
+
+
+
+ <p align="center">
+   <img src="https://github.com/Ahmednull/Storage/blob/main/gaze.gif" alt="animated" />
+ </p>
+
+
+ ___
+
+ # L2CS-Net
+
+ The official PyTorch implementation of L2CS-Net for gaze estimation and tracking.
+
+ ## Installation
+ <img src="https://img.shields.io/badge/python%20-%2314354C.svg?&style=for-the-badge&logo=python&logoColor=white"/> <img src="https://img.shields.io/badge/PyTorch%20-%23EE4C2C.svg?&style=for-the-badge&logo=PyTorch&logoColor=white" />
+
+ Install package with the following:
+
+ ```
+ pip install git+https://github.com/Ahmednull/L2CS-Net.git@main
+ ```
+
+ Or, you can git clone the repo and install with the following:
+
+ ```
+ pip install [-e] .
+ ```
+
+ Now you should be able to import the package with the following command:
+
+ ```
+ $ python
+ >>> import l2cs
+ ```
+
+ ## Usage
+
+ Detect face and predict gaze from webcam
+
+ ```python
+ from l2cs import Pipeline, render
+ import cv2
+
+ gaze_pipeline = Pipeline(
+     weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
+     arch='ResNet50',
+     device=torch.device('cpu')  # or 'gpu'
+ )
+
+ cap = cv2.VideoCapture(cam)
+ _, frame = cap.read()
+
+ # Process frame and visualize
+ results = gaze_pipeline.step(frame)
+ frame = render(frame, results)
+ ```
+
+ ## Demo
+ * Download the pre-trained models from [here](https://drive.google.com/drive/folders/17p6ORr-JQJcw-eYtG2WGNiuS_qVKwdWd?usp=sharing) and Store it to *models/*.
+ * Run:
+ ```
+ python demo.py \
+  --snapshot models/L2CSNet_gaze360.pkl \
+  --gpu 0 \
+  --cam 0 \
+ ```
+ This means the demo will run using *L2CSNet_gaze360.pkl* pretrained model
+
+ ## Community Contributions
+
+ - [Gaze Detection and Eye Tracking: A How-To Guide](https://blog.roboflow.com/gaze-direction-position/): Use L2CS-Net through a HTTP interface with the open source Roboflow Inference project.
+
+ ## MPIIGaze
+ We provide the code for train and test MPIIGaze dataset with leave-one-person-out evaluation.
+
+ ### Prepare datasets
+ * Download **MPIIFaceGaze dataset** from [here](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation).
+ * Apply data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
+ * Store the dataset to *datasets/MPIIFaceGaze*.
+
+ ### Train
+ ```
+ python train.py \
+  --dataset mpiigaze \
+  --snapshot output/snapshots \
+  --gpu 0 \
+  --num_epochs 50 \
+  --batch_size 16 \
+  --lr 0.00001 \
+  --alpha 1 \
+
+ ```
+ This means the code will perform leave-one-person-out training automatically and store the models to *output/snapshots*.
+
+ ### Test
+ ```
+ python test.py \
+  --dataset mpiigaze \
+  --snapshot output/snapshots/snapshot_folder \
+  --evalpath evaluation/L2CS-mpiigaze \
+  --gpu 0 \
+ ```
+ This means the code will perform leave-one-person-out testing automatically and store the results to *evaluation/L2CS-mpiigaze*.
+
+ To get the average leave-one-person-out accuracy use:
+ ```
+ python leave_one_out_eval.py \
+  --evalpath evaluation/L2CS-mpiigaze \
+  --respath evaluation/L2CS-mpiigaze \
+ ```
+ This means the code will take the evaluation path and outputs the leave-one-out gaze accuracy to the *evaluation/L2CS-mpiigaze*.
+
+ ## Gaze360
+ We provide the code for train and test Gaze360 dataset with train-val-test evaluation.
+
+ ### Prepare datasets
+ * Download **Gaze360 dataset** from [here](http://gaze360.csail.mit.edu/download.php).
+
+ * Apply data preprocessing from [here](http://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/).
+
+ * Store the dataset to *datasets/Gaze360*.
+
+
+ ### Train
+ ```
+ python train.py \
+  --dataset gaze360 \
+  --snapshot output/snapshots \
+  --gpu 0 \
+  --num_epochs 50 \
+  --batch_size 16 \
+  --lr 0.00001 \
+  --alpha 1 \
+
+ ```
+ This means the code will perform training and store the models to *output/snapshots*.
+
+ ### Test
+ ```
+ python test.py \
+  --dataset gaze360 \
+  --snapshot output/snapshots/snapshot_folder \
+  --evalpath evaluation/L2CS-gaze360 \
+  --gpu 0 \
+ ```
+ This means the code will perform testing on snapshot_folder and store the results to *evaluation/L2CS-gaze360*.
+
models/L2CS-Net/demo.py ADDED
@@ -0,0 +1,87 @@
+ import argparse
+ import pathlib
+ import numpy as np
+ import cv2
+ import time
+
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from torchvision import transforms
+ import torch.backends.cudnn as cudnn
+ import torchvision
+
+ from PIL import Image
+ from PIL import Image, ImageOps
+
+ from face_detection import RetinaFace
+
+ from l2cs import select_device, draw_gaze, getArch, Pipeline, render
+
+ CWD = pathlib.Path.cwd()
+
+ def parse_args():
+     """Parse input arguments."""
+     parser = argparse.ArgumentParser(
+         description='Gaze evalution using model pretrained with L2CS-Net on Gaze360.')
+     parser.add_argument(
+         '--device', dest='device', help='Device to run model: cpu or gpu:0',
+         default="cpu", type=str)
+     parser.add_argument(
+         '--snapshot', dest='snapshot', help='Path of model snapshot.',
+         default='output/snapshots/L2CS-gaze360-_loader-180-4/_epoch_55.pkl', type=str)
+     parser.add_argument(
+         '--cam', dest='cam_id', help='Camera device id to use [0]',
+         default=0, type=int)
+     parser.add_argument(
+         '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, ResNet50, ResNet101, ResNet152',
+         default='ResNet50', type=str)
+
+     args = parser.parse_args()
+     return args
+
+ if __name__ == '__main__':
+     args = parse_args()
+
+     cudnn.enabled = True
+     arch = args.arch
+     cam = args.cam_id
+     # snapshot_path = args.snapshot
+
+     gaze_pipeline = Pipeline(
+         weights=CWD / 'models' / 'L2CSNet_gaze360.pkl',
+         arch='ResNet50',
+         device=select_device(args.device, batch_size=1)
+     )
+
+     cap = cv2.VideoCapture(cam)
+
+     # Check if the webcam is opened correctly
+     if not cap.isOpened():
+         raise IOError("Cannot open webcam")
+
+     with torch.no_grad():
+         while True:
+
+             # Get frame
+             success, frame = cap.read()
+             start_fps = time.time()
+
+             if not success:
+                 print("Failed to obtain frame")
+                 time.sleep(0.1)
+
+             # Process frame
+             results = gaze_pipeline.step(frame)
+
+             # Visualize output
+             frame = render(frame, results)
+
+             myFPS = 1.0 / (time.time() - start_fps)
+             cv2.putText(frame, 'FPS: {:.1f}'.format(myFPS), (10, 20), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (0, 255, 0), 1, cv2.LINE_AA)
+
+             cv2.imshow("Demo", frame)
+             if cv2.waitKey(1) & 0xFF == ord('q'):
+                 break
+             success, frame = cap.read()
+
models/L2CS-Net/l2cs/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from .utils import select_device, natural_keys, gazeto3d, angular, getArch
+ from .vis import draw_gaze, render
+ from .model import L2CS
+ from .pipeline import Pipeline
+ from .datasets import Gaze360, Mpiigaze
+
+ __all__ = [
+     # Classes
+     'L2CS',
+     'Pipeline',
+     'Gaze360',
+     'Mpiigaze',
+     # Utils
+     'render',
+     'select_device',
+     'draw_gaze',
+     'natural_keys',
+     'gazeto3d',
+     'angular',
+     'getArch'
+ ]
models/L2CS-Net/l2cs/datasets.py ADDED
@@ -0,0 +1,157 @@
+ import os
+ import numpy as np
+ import cv2
+
+
+ import torch
+ from torch.utils.data.dataset import Dataset
+ from torchvision import transforms
+ from PIL import Image, ImageFilter
+
+
+ class Gaze360(Dataset):
+     def __init__(self, path, root, transform, angle, binwidth, train=True):
+         self.transform = transform
+         self.root = root
+         self.orig_list_len = 0
+         self.angle = angle
+         if train==False:
+             angle = 90
+         self.binwidth = binwidth
+         self.lines = []
+         if isinstance(path, list):
+             for i in path:
+                 with open(i) as f:
+                     print("here")
+                     line = f.readlines()
+                     line.pop(0)
+                     self.lines.extend(line)
+         else:
+             with open(path) as f:
+                 lines = f.readlines()
+                 lines.pop(0)
+                 self.orig_list_len = len(lines)
+                 for line in lines:
+                     gaze2d = line.strip().split(" ")[5]
+                     label = np.array(gaze2d.split(",")).astype("float")
+                     if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
+                         self.lines.append(line)
+
+
+         print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines), angle))
+
+     def __len__(self):
+         return len(self.lines)
+
+     def __getitem__(self, idx):
+         line = self.lines[idx]
+         line = line.strip().split(" ")
+
+         face = line[0]
+         lefteye = line[1]
+         righteye = line[2]
+         name = line[3]
+         gaze2d = line[5]
+         label = np.array(gaze2d.split(",")).astype("float")
+         label = torch.from_numpy(label).type(torch.FloatTensor)
+
+         pitch = label[0]* 180 / np.pi
+         yaw = label[1]* 180 / np.pi
+
+         img = Image.open(os.path.join(self.root, face))
+
+         # fimg = cv2.imread(os.path.join(self.root, face))
+         # fimg = cv2.resize(fimg, (448, 448))/255.0
+         # fimg = fimg.transpose(2, 0, 1)
+         # img=torch.from_numpy(fimg).type(torch.FloatTensor)
+
+         if self.transform:
+             img = self.transform(img)
+
+         # Bin values
+         bins = np.array(range(-1*self.angle, self.angle, self.binwidth))
+         binned_pose = np.digitize([pitch, yaw], bins) - 1
+
+         labels = binned_pose
+         cont_labels = torch.FloatTensor([pitch, yaw])
+
+
+         return img, labels, cont_labels, name
+
+ class Mpiigaze(Dataset):
+     def __init__(self, pathorg, root, transform, train, angle, fold=0):
+         self.transform = transform
+         self.root = root
+         self.orig_list_len = 0
+         self.lines = []
+         path = pathorg.copy()
+         if train==True:
+             path.pop(fold)
+         else:
+             path = path[fold]
+         if isinstance(path, list):
+             for i in path:
+                 with open(i) as f:
+                     lines = f.readlines()
+                     lines.pop(0)
+                     self.orig_list_len += len(lines)
+                     for line in lines:
+                         gaze2d = line.strip().split(" ")[7]
+                         label = np.array(gaze2d.split(",")).astype("float")
+                         if abs((label[0]*180/np.pi)) <= angle and abs((label[1]*180/np.pi)) <= angle:
+                             self.lines.append(line)
+         else:
+             with open(path) as f:
+                 lines = f.readlines()
+                 lines.pop(0)
+                 self.orig_list_len += len(lines)
+                 for line in lines:
+                     gaze2d = line.strip().split(" ")[7]
+                     label = np.array(gaze2d.split(",")).astype("float")
+                     if abs((label[0]*180/np.pi)) <= 42 and abs((label[1]*180/np.pi)) <= 42:
+                         self.lines.append(line)
+
+         print("{} items removed from dataset that have an angle > {}".format(self.orig_list_len-len(self.lines),angle))
+
+     def __len__(self):
+         return len(self.lines)
+
+     def __getitem__(self, idx):
+         line = self.lines[idx]
+         line = line.strip().split(" ")
+
+         name = line[3]
+         gaze2d = line[7]
+         head2d = line[8]
+         lefteye = line[1]
+         righteye = line[2]
+         face = line[0]
+
+         label = np.array(gaze2d.split(",")).astype("float")
+         label = torch.from_numpy(label).type(torch.FloatTensor)
+
+
+         pitch = label[0]* 180 / np.pi
+         yaw = label[1]* 180 / np.pi
+
+         img = Image.open(os.path.join(self.root, face))
+
+         # fimg = cv2.imread(os.path.join(self.root, face))
+         # fimg = cv2.resize(fimg, (448, 448))/255.0
+         # fimg = fimg.transpose(2, 0, 1)
+         # img=torch.from_numpy(fimg).type(torch.FloatTensor)
+
+         if self.transform:
+             img = self.transform(img)
+
+         # Bin values
+         bins = np.array(range(-42, 42, 3))
+         binned_pose = np.digitize([pitch, yaw], bins) - 1
+
+         labels = binned_pose
+         cont_labels = torch.FloatTensor([pitch, yaw])
+
+
+         return img, labels, cont_labels, name
+
+
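The binning step shared by both datasets turns continuous pitch/yaw in degrees into classification targets via `np.digitize(...) - 1`. A minimal sketch, assuming the Gaze360 settings `angle=180`, `binwidth=4` (which yields the 90 bins the model's heads expect; the variable names here are illustrative):

```python
import numpy as np

# Mirrors the Dataset's binning: edges every `binwidth` degrees from
# -angle to angle, then np.digitize(...) - 1 maps an angle to its bin.
angle, binwidth = 180, 4
bins = np.array(range(-angle, angle, binwidth))  # 90 edges: -180, -176, ..., 176

pitch_deg, yaw_deg = 10.0, -30.0
binned = np.digitize([pitch_deg, yaw_deg], bins) - 1
print(len(bins), binned)  # 90 [47 37]
```

Bin 45 is centered on 0 degrees, so an angle of 10 degrees lands two bins to its right (47), and -30 degrees eight bins to its left (37).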
models/L2CS-Net/l2cs/model.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ import math
+ import torch.nn.functional as F
+
+
+ class L2CS(nn.Module):
+     def __init__(self, block, layers, num_bins):
+         self.inplanes = 64
+         super(L2CS, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+         self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins)
+         self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins)
+
+         # Vestigial layer from previous experiments
+         self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+
+
+         # gaze
+         pre_yaw_gaze = self.fc_yaw_gaze(x)
+         pre_pitch_gaze = self.fc_pitch_gaze(x)
+         return pre_yaw_gaze, pre_pitch_gaze
+
+
models/L2CS-Net/l2cs/pipeline.py ADDED
@@ -0,0 +1,133 @@
+ import pathlib
+ from typing import Union
+
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from dataclasses import dataclass
+ from face_detection import RetinaFace
+
+ from .utils import prep_input_numpy, getArch
+ from .results import GazeResultContainer
+
+
+ class Pipeline:
+
+     def __init__(
+         self,
+         weights: pathlib.Path,
+         arch: str,
+         device: str = 'cpu',
+         include_detector: bool = True,
+         confidence_threshold: float = 0.5
+     ):
+
+         # Save input parameters
+         self.weights = weights
+         self.include_detector = include_detector
+         self.device = device
+         self.confidence_threshold = confidence_threshold
+
+         # Create L2CS model
+         self.model = getArch(arch, 90)
+         self.model.load_state_dict(torch.load(self.weights, map_location=device))
+         self.model.to(self.device)
+         self.model.eval()
+
+         # Create RetinaFace if requested
+         if self.include_detector:
+
+             if device.type == 'cpu':
+                 self.detector = RetinaFace()
+             else:
+                 self.detector = RetinaFace(gpu_id=device.index)
+
+         self.softmax = nn.Softmax(dim=1)
+         self.idx_tensor = [idx for idx in range(90)]
+         self.idx_tensor = torch.FloatTensor(self.idx_tensor).to(self.device)
+
+     def step(self, frame: np.ndarray) -> GazeResultContainer:
+
+         # Creating containers
+         face_imgs = []
+         bboxes = []
+         landmarks = []
+         scores = []
+
+         if self.include_detector:
+             faces = self.detector(frame)
+
+             if faces is not None:
+                 for box, landmark, score in faces:
+
+                     # Apply threshold
+                     if score < self.confidence_threshold:
+                         continue
+
+                     # Extract safe min and max of x,y
+                     x_min = int(box[0])
+                     if x_min < 0:
+                         x_min = 0
+                     y_min = int(box[1])
+                     if y_min < 0:
+                         y_min = 0
+                     x_max = int(box[2])
+                     y_max = int(box[3])
+
+                     # Crop image
+                     img = frame[y_min:y_max, x_min:x_max]
+                     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                     img = cv2.resize(img, (224, 224))
+                     face_imgs.append(img)
+
+                     # Save data
+                     bboxes.append(box)
+                     landmarks.append(landmark)
+                     scores.append(score)
+
+                 # Predict gaze
+                 pitch, yaw = self.predict_gaze(np.stack(face_imgs))
+
+             else:
+
+                 pitch = np.empty((0, 1))
+                 yaw = np.empty((0, 1))
+
+         else:
+             pitch, yaw = self.predict_gaze(frame)
+
+         # Save data
+         results = GazeResultContainer(
+             pitch=pitch,
+             yaw=yaw,
+             bboxes=np.stack(bboxes),
+             landmarks=np.stack(landmarks),
+             scores=np.stack(scores)
+         )
+
+         return results
+
+     def predict_gaze(self, frame: Union[np.ndarray, torch.Tensor]):
+
+         # Prepare input
+         if isinstance(frame, np.ndarray):
+             img = prep_input_numpy(frame, self.device)
+         elif isinstance(frame, torch.Tensor):
+             img = frame
+         else:
+             raise RuntimeError("Invalid dtype for input")
+
+         # Predict
+         gaze_pitch, gaze_yaw = self.model(img)
+         pitch_predicted = self.softmax(gaze_pitch)
+         yaw_predicted = self.softmax(gaze_yaw)
+
+         # Get continuous predictions in degrees.
+         pitch_predicted = torch.sum(pitch_predicted.data * self.idx_tensor, dim=1) * 4 - 180
+         yaw_predicted = torch.sum(yaw_predicted.data * self.idx_tensor, dim=1) * 4 - 180
+
+         pitch_predicted = pitch_predicted.cpu().detach().numpy() * np.pi / 180.0
+         yaw_predicted = yaw_predicted.cpu().detach().numpy() * np.pi / 180.0
+
+         return pitch_predicted, yaw_predicted
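`predict_gaze` decodes each 90-way classification head into a continuous angle by taking the softmax expectation over bin indices, scaling by the 4-degree bin width, shifting by -180, and converting to radians. A numpy-only sketch of that decoding (the `decode` helper is illustrative, not part of the package):

```python
import numpy as np

# Decoding sketch matching predict_gaze: softmax over 90 bins, then
# expectation * 4 - 180 gives degrees, finally degrees -> radians.
def decode(logits):  # logits: (batch, 90)
    e = np.exp(logits - logits.max(axis=1, keepdims=True))  # stable softmax
    probs = e / e.sum(axis=1, keepdims=True)
    deg = (probs * np.arange(90)).sum(axis=1) * 4 - 180
    return deg * np.pi / 180.0

# A logit vector sharply peaked at bin 45 decodes to ~0 radians,
# since bin 45 maps to 45*4 - 180 = 0 degrees.
logits = np.full((1, 90), -10.0)
logits[0, 45] = 10.0
rad = decode(logits)
```

The expectation (rather than argmax) is what makes the output continuous: probability mass spread over neighboring bins interpolates between bin centers.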
models/L2CS-Net/l2cs/results.py ADDED
@@ -0,0 +1,11 @@
+ from dataclasses import dataclass
+ import numpy as np
+
+ @dataclass
+ class GazeResultContainer:
+
+     pitch: np.ndarray
+     yaw: np.ndarray
+     bboxes: np.ndarray
+     landmarks: np.ndarray
+     scores: np.ndarray
models/L2CS-Net/l2cs/utils.py ADDED
@@ -0,0 +1,145 @@
+ import sys
+ import os
+ import math
+ from math import cos, sin
+ from pathlib import Path
+ import subprocess
+ import re
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import scipy.io as sio
+ import cv2
+ import torchvision
+ from torchvision import transforms
+
+ from .model import L2CS
+
+ transformations = transforms.Compose([
+     transforms.ToPILImage(),
+     transforms.Resize(448),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=[0.485, 0.456, 0.406],
+         std=[0.229, 0.224, 0.225]
+     )
+ ])
+
+ def atoi(text):
+     return int(text) if text.isdigit() else text
+
+ def natural_keys(text):
+     '''
+     alist.sort(key=natural_keys) sorts in human order
+     http://nedbatchelder.com/blog/200712/human_sorting.html
+     (See Toothy's implementation in the comments)
+     '''
+     return [ atoi(c) for c in re.split(r'(\d+)', text) ]
+
+ def prep_input_numpy(img: np.ndarray, device: str):
+     """Preparing a Numpy Array as input to L2CS-Net."""
+
+     if len(img.shape) == 4:
+         imgs = []
+         for im in img:
+             imgs.append(transformations(im))
+         img = torch.stack(imgs)
+     else:
+         img = transformations(img)
+
+     img = img.to(device)
+
+     if len(img.shape) == 3:
+         img = img.unsqueeze(0)
+
+     return img
+
+ def gazeto3d(gaze):
+     gaze_gt = np.zeros([3])
+     gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
+     gaze_gt[1] = -np.sin(gaze[1])
+     gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
+     return gaze_gt
+
+ def angular(gaze, label):
+     total = np.sum(gaze * label)
+     return np.arccos(min(total/(np.linalg.norm(gaze)* np.linalg.norm(label)), 0.9999999))*180/np.pi
+
+ def select_device(device='', batch_size=None):
+     # device = 'cpu' or '0' or '0,1,2,3'
+     s = f'YOLOv3 🚀 {git_describe() or date_modified()} torch {torch.__version__} '  # string
+     cpu = device.lower() == 'cpu'
+     if cpu:
+         os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
+     elif device:  # non-cpu device requested
+         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
+         # assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability
+
+     cuda = not cpu and torch.cuda.is_available()
+     if cuda:
+         devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
+         n = len(devices)  # device count
+         if n > 1 and batch_size:  # check batch_size is divisible by device_count
+             assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+         space = ' ' * len(s)
+         for i, d in enumerate(devices):
+             p = torch.cuda.get_device_properties(i)
+             s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n"  # bytes to MB
+     else:
+         s += 'CPU\n'
+
+     return torch.device('cuda:0' if cuda else 'cpu')
+
+ def spherical2cartesial(x):
+
+     output = torch.zeros(x.size(0), 3)
+     output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
+     output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
+     output[:,1] = torch.sin(x[:,1])
+
+     return output
+
+ def compute_angular_error(input, target):
+
+     input = spherical2cartesial(input)
+     target = spherical2cartesial(target)
+
+     input = input.view(-1,3,1)
+     target = target.view(-1,1,3)
+     output_dot = torch.bmm(target,input)
+     output_dot = output_dot.view(-1)
+     output_dot = torch.acos(output_dot)
+     output_dot = output_dot.data
+     output_dot = 180*torch.mean(output_dot)/math.pi
+     return output_dot
+
+ def softmax_temperature(tensor, temperature):
+     result = torch.exp(tensor / temperature)
+     result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))
+     return result
+
+ def git_describe(path=Path(__file__).parent):  # path must be a directory
+     # return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+     s = f'git -C {path} describe --tags --long --always'
+     try:
+         return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1]
+     except subprocess.CalledProcessError as e:
+         return ''  # not a git repository
+
+ def getArch(arch, bins):
+     # Base network structure
+     if arch == 'ResNet18':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
+     elif arch == 'ResNet34':
+         model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
+     elif arch == 'ResNet101':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
+     elif arch == 'ResNet152':
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], bins)
+     else:
+         if arch != 'ResNet50':
+             print('Invalid value for architecture is passed! '
+                   'The default value of ResNet50 will be used instead!')
+         model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
+     return model
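`gazeto3d` and `angular` together define the angular-error metric used in evaluation: project the (pitch, yaw) pair to a 3D unit vector, then take the arccos of the normalized dot product. Restated with numpy for a quick sanity check (note the `0.9999999` clamp keeps `arccos` finite, which also makes the self-error slightly nonzero rather than exactly 0):

```python
import numpy as np

# Same math as utils.py: spherical gaze -> 3D vector -> angular error in degrees.
def gazeto3d(gaze):
    return np.array([
        -np.cos(gaze[1]) * np.sin(gaze[0]),
        -np.sin(gaze[1]),
        -np.cos(gaze[1]) * np.cos(gaze[0]),
    ])

def angular(gaze, label):
    total = np.sum(gaze * label)
    return np.arccos(min(total / (np.linalg.norm(gaze) * np.linalg.norm(label)),
                         0.9999999)) * 180 / np.pi

g = gazeto3d([0.2, -0.1])
self_err = angular(g, g)            # ~0.026 degrees, from the clamp
ortho_err = angular(np.array([1.0, 0.0, 0.0]),
                    np.array([0.0, 1.0, 0.0]))  # exactly 90 degrees
```

The clamp exists because floating-point round-off can push the cosine fractionally above 1, where `arccos` would return NaN.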
models/L2CS-Net/l2cs/vis.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from .results import GazeResultContainer
4
+
5
+ def draw_gaze(a,b,c,d,image_in, pitchyaw, thickness=2, color=(255, 255, 0),sclae=2.0):
6
+ """Draw gaze angle on given image with a given eye positions."""
7
+ image_out = image_in
8
+ (h, w) = image_in.shape[:2]
9
+ length = c
10
+ pos = (int(a+c / 2.0), int(b+d / 2.0))
11
+ if len(image_out.shape) == 2 or image_out.shape[2] == 1:
12
+ image_out = cv2.cvtColor(image_out, cv2.COLOR_GRAY2BGR)
13
+ dx = -length * np.sin(pitchyaw[0]) * np.cos(pitchyaw[1])
14
+ dy = -length * np.sin(pitchyaw[1])
15
+ cv2.arrowedLine(image_out, tuple(np.round(pos).astype(np.int32)),
16
+ tuple(np.round([pos[0] + dx, pos[1] + dy]).astype(int)), color,
17
+ thickness, cv2.LINE_AA, tipLength=0.18)
18
+ return image_out
19
+
20
+ def draw_bbox(frame: np.ndarray, bbox: np.ndarray):
21
+
22
+ x_min=int(bbox[0])
23
+ if x_min < 0:
24
+ x_min = 0
25
+ y_min=int(bbox[1])
26
+ if y_min < 0:
27
+ y_min = 0
28
+ x_max=int(bbox[2])
29
+ y_max=int(bbox[3])
30
+
31
+ cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0,255,0), 1)
32
+
33
+ return frame
34
+
35
+ def render(frame: np.ndarray, results: GazeResultContainer):
36
+
37
+ # Draw bounding boxes
38
+ for bbox in results.bboxes:
39
+ frame = draw_bbox(frame, bbox)
40
+
41
+ # Draw Gaze
42
+ for i in range(results.pitch.shape[0]):
43
+
44
+ bbox = results.bboxes[i]
45
+ pitch = results.pitch[i]
46
+ yaw = results.yaw[i]
47
+
48
+ # Extract safe min and max of x,y
49
+ x_min=int(bbox[0])
50
+ if x_min < 0:
51
+ x_min = 0
52
+ y_min=int(bbox[1])
53
+ if y_min < 0:
54
+ y_min = 0
55
+ x_max=int(bbox[2])
56
+ y_max=int(bbox[3])
57
+
58
+ # Compute sizes
59
+ bbox_width = x_max - x_min
60
+ bbox_height = y_max - y_min
61
+
62
+ draw_gaze(x_min,y_min,bbox_width, bbox_height,frame,(pitch,yaw),color=(0,0,255))
63
+
64
+ return frame
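`draw_gaze` projects the (pitch, yaw) pair onto the image plane before drawing the arrow. The projection on its own (a sketch; `gaze_arrow` is a hypothetical helper with the arrow length fixed instead of taken from the box width):

```python
import numpy as np

def gaze_arrow(pitch, yaw, length=100.0):
    # Screen-plane displacement of the gaze arrow, as in draw_gaze:
    # dx from pitch scaled by cos(yaw), dy from yaw alone.
    dx = -length * np.sin(pitch) * np.cos(yaw)
    dy = -length * np.sin(yaw)
    return dx, dy
```

Looking straight at the camera (pitch = yaw = 0) yields a zero-length arrow.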
models/L2CS-Net/leave_one_out_eval.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import argparse
3
+
4
+
5
+
6
+ def parse_args():
7
+ """Parse input arguments."""
8
+ parser = argparse.ArgumentParser(
9
+ description='gaze estimation using binned loss function.')
10
+ parser.add_argument(
11
+ '--evalpath', dest='evalpath', help='path for evaluating gaze test.',
12
+ default="evaluation\L2CS-gaze360-_standard-10", type=str)
13
+ parser.add_argument(
14
+ '--respath', dest='respath', help='path for saving result.',
15
+ default="evaluation\L2CS-gaze360-_standard-10", type=str)
16
+
17
+ if __name__ == '__main__':
18
+
19
+ args = parse_args()
20
+ evalpath =args.evalpath
21
+ respath=args.respath
22
+ if not os.path.exists(respath):
23
+ os.makedirs(respath)
24
+ outfile = open(os.path.join(respath, "avg.log"), 'w')
25
+ outfile.write("Average equal\n")
26
+
27
+ min=10.0
28
+ dirlist = os.listdir(evalpath)
29
+ dirlist.sort()
30
+ l=0.0
31
+ for j in range(50):
32
+ j=20
33
+ avg=0.0
34
+ h=j+3
35
+ for i in dirlist:
36
+ with open(evalpath+"/"+i+"/mpiigaze_binned.log") as myfile:
37
+
38
+ x=list(myfile)[h]
39
+ str1 = ""
40
+
41
+ # traverse in the string
42
+ for ele in x:
43
+ str1 += ele
44
+ split_string = str1.split("MAE:",1)[1]
45
+ avg+=float(split_string)
46
+
47
+ avg=avg/15.0
48
+ if avg<min:
49
+ min=avg
50
+ l=j+1
51
+ outfile.write("epoch"+str(j+1)+"= "+str(avg)+"\n")
52
+
53
+ outfile.write("min angular error equal= "+str(min)+"at epoch= "+str(l)+"\n")
54
+ print(min)
models/L2CS-Net/models/L2CSNet_gaze360.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a7f3480d868dd48261e1d59f915b0ef0bb33ea12ea00938fb2168f212080665
3
+ size 95849977
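The `.pkl` above is committed as a Git LFS pointer; the real ~96 MB weight file is fetched by `git lfs pull` at clone/build time. The pointer format itself is three `key value` lines, which a small parser can read (a sketch; `parse_lfs_pointer` is a hypothetical helper, not a git-lfs API):

```python
def parse_lfs_pointer(text):
    # Each pointer line is "key value"; size is an integer byte count.
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(' ')
        fields[key] = value
    fields['size'] = int(fields['size'])
    return fields
```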
models/L2CS-Net/models/README.md ADDED
@@ -0,0 +1 @@
1
+ # Path to pre-trained models
models/L2CS-Net/pyproject.toml ADDED
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "l2cs"
3
+ version = "0.0.1"
4
+ description = "The official PyTorch implementation of L2CS-Net for gaze estimation and tracking"
5
+ authors = [
6
+ {name = "Ahmed Abderlrahman"},
7
+ {name = "Thorsten Hempel"}
8
+ ]
9
+ license = {file = "LICENSE.txt"}
10
+ readme = "README.md"
11
+ requires-python = ">3.6"
12
+
13
+ keywords = ["gaze", "estimation", "eye-tracking", "deep-learning", "pytorch"]
14
+
15
+ classifiers = [
16
+ "Programming Language :: Python :: 3"
17
+ ]
18
+
19
+ dependencies = [
20
+ 'matplotlib>=3.3.4',
21
+ 'numpy>=1.19.5',
22
+ 'opencv-python>=4.5.5',
23
+ 'pandas>=1.1.5',
24
+ 'Pillow>=8.4.0',
25
+ 'scipy>=1.5.4',
26
+ 'torch>=1.10.1',
27
+ 'torchvision>=0.11.2',
28
+ 'face_detection@git+https://github.com/elliottzheng/face-detection'
29
+ ]
30
+
31
+ [project.urls]
32
+ homepath = "https://github.com/Ahmednull/L2CS-Net"
33
+ repository = "https://github.com/Ahmednull/L2CS-Net"
34
+
35
+ [build-system]
36
+ requires = ["setuptools", "wheel"]
37
+ build-backend = "setuptools.build_meta"
38
+
39
+ # https://setuptools.pypa.io/en/stable/userguide/datafiles.html
40
+ [tool.setuptools]
41
+ include-package-data = true
42
+
43
+ [tool.setuptools.packages.find]
44
+ where = ["."]
models/L2CS-Net/test.py ADDED
@@ -0,0 +1,284 @@
1
+ import os, argparse
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.autograd import Variable
7
+ from torch.utils.data import DataLoader
8
+ from torchvision import transforms
9
+ import torch.backends.cudnn as cudnn
10
+ import torchvision
11
+
12
+ from l2cs import select_device, natural_keys, gazeto3d, angular, getArch, L2CS, Gaze360, Mpiigaze
13
+
14
+
15
+ def parse_args():
16
+ """Parse input arguments."""
17
+ parser = argparse.ArgumentParser(
18
+ description='Gaze estimation using L2CSNet.')
19
+ # Gaze360
20
+ parser.add_argument(
21
+ '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
22
+ default='datasets/Gaze360/Image', type=str)
23
+ parser.add_argument(
24
+ '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
25
+ default='datasets/Gaze360/Label/test.label', type=str)
26
+ # mpiigaze
27
+ parser.add_argument(
28
+ '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
29
+ default='datasets/MPIIFaceGaze/Image', type=str)
30
+ parser.add_argument(
31
+ '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
32
+ default='datasets/MPIIFaceGaze/Label', type=str)
33
+ # Important args -------------------------------------------------------------------------------------------------------
34
+ # ----------------------------------------------------------------------------------------------------------------------
35
+ parser.add_argument(
36
+ '--dataset', dest='dataset', help='gaze360, mpiigaze',
37
+ default= "gaze360", type=str)
38
+ parser.add_argument(
39
+ '--snapshot', dest='snapshot', help='Path to the folder contains models.',
40
+ default='output/snapshots/L2CS-gaze360-_loader-180-4-lr', type=str)
41
+ parser.add_argument(
42
+ '--evalpath', dest='evalpath', help='path for the output evaluating gaze test.',
43
+ default="evaluation/L2CS-gaze360-_loader-180-4-lr", type=str)
44
+ parser.add_argument(
45
+ '--gpu',dest='gpu_id', help='GPU device id to use [0]',
46
+ default="0", type=str)
47
+ parser.add_argument(
48
+ '--batch_size', dest='batch_size', help='Batch size.',
49
+ default=100, type=int)
50
+ parser.add_argument(
51
+ '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], ''ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
52
+ default='ResNet50', type=str)
+ parser.add_argument(
+ '--bins', dest='bins', help='Number of angle bins (90 for gaze360, 28 for mpiigaze).',
+ default=28, type=int)
+ parser.add_argument(
+ '--angle', dest='angle', help='Maximum gaze angle in degrees kept when loading MPIIGaze.',
+ default=180, type=int)
+ parser.add_argument(
+ '--bin_width', dest='bin_width', help='Bin width in degrees (4 for gaze360, 3 for mpiigaze).',
+ default=3, type=int)
53
+ # ---------------------------------------------------------------------------------------------------------------------
54
+ # Important args ------------------------------------------------------------------------------------------------------
55
+ args = parser.parse_args()
56
+ return args
57
+
58
+
59
+ def getArch(arch,bins):
60
+ # Base network structure
61
+ if arch == 'ResNet18':
62
+ model = L2CS( torchvision.models.resnet.BasicBlock,[2, 2, 2, 2], bins)
63
+ elif arch == 'ResNet34':
64
+ model = L2CS( torchvision.models.resnet.BasicBlock,[3, 4, 6, 3], bins)
65
+ elif arch == 'ResNet101':
66
+ model = L2CS( torchvision.models.resnet.Bottleneck,[3, 4, 23, 3], bins)
67
+ elif arch == 'ResNet152':
68
+ model = L2CS( torchvision.models.resnet.Bottleneck,[3, 8, 36, 3], bins)
69
+ else:
70
+ if arch != 'ResNet50':
71
+ print('Invalid value for architecture is passed! '
72
+ 'The default value of ResNet50 will be used instead!')
73
+ model = L2CS( torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
74
+ return model
75
+
76
+ if __name__ == '__main__':
77
+ args = parse_args()
78
+ cudnn.enabled = True
79
+ gpu = select_device(args.gpu_id, batch_size=args.batch_size)
80
+ batch_size=args.batch_size
81
+ arch=args.arch
82
+ data_set=args.dataset
83
+ evalpath =args.evalpath
84
+ snapshot_path = args.snapshot
85
+ bins=args.bins
86
+ angle=args.angle
87
+ bin_width=args.bin_width
88
+
89
+ transformations = transforms.Compose([
90
+ transforms.Resize(448),
91
+ transforms.ToTensor(),
92
+ transforms.Normalize(
93
+ mean=[0.485, 0.456, 0.406],
94
+ std=[0.229, 0.224, 0.225]
95
+ )
96
+ ])
97
+
98
+
99
+
100
+ if data_set=="gaze360":
101
+
102
+ gaze_dataset=Gaze360(args.gaze360label_dir,args.gaze360image_dir, transformations, 180, 4, train=False)
103
+ test_loader = torch.utils.data.DataLoader(
104
+ dataset=gaze_dataset,
105
+ batch_size=batch_size,
106
+ shuffle=False,
107
+ num_workers=4,
108
+ pin_memory=True)
109
+
110
+
111
+
112
+ if not os.path.exists(evalpath):
113
+ os.makedirs(evalpath)
114
+
115
+
116
+ # list all epochs for testing
117
+ folder = os.listdir(snapshot_path)
118
+ folder.sort(key=natural_keys)
119
+ softmax = nn.Softmax(dim=1)
120
+ with open(os.path.join(evalpath,data_set+".log"), 'w') as outfile:
121
+ configuration = f"\ntest configuration = gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}----------------------------------------\n"
122
+ print(configuration)
123
+ outfile.write(configuration)
124
+ epoch_list=[]
125
+ avg_yaw=[]
126
+ avg_pitch=[]
127
+ avg_MAE=[]
128
+ for epochs in folder:
129
+ # Base network structure
130
+ model=getArch(arch, 90)
131
+ saved_state_dict = torch.load(os.path.join(snapshot_path, epochs))
132
+ model.load_state_dict(saved_state_dict)
133
+ model.cuda(gpu)
134
+ model.eval()
135
+ total = 0
136
+ idx_tensor = [idx for idx in range(90)]
137
+ idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
138
+ avg_error = .0
139
+
140
+
141
+ with torch.no_grad():
142
+ for j, (images, labels, cont_labels, name) in enumerate(test_loader):
143
+ images = Variable(images).cuda(gpu)
144
+ total += cont_labels.size(0)
145
+
146
+ label_pitch = cont_labels[:,0].float()*np.pi/180
147
+ label_yaw = cont_labels[:,1].float()*np.pi/180
148
+
149
+
150
+ gaze_pitch, gaze_yaw = model(images)
151
+
152
+ # Binned predictions
153
+ _, pitch_bpred = torch.max(gaze_pitch.data, 1)
154
+ _, yaw_bpred = torch.max(gaze_yaw.data, 1)
155
+
156
+
157
+ # Continuous predictions
158
+ pitch_predicted = softmax(gaze_pitch)
159
+ yaw_predicted = softmax(gaze_yaw)
160
+
161
+ # mapping from binned indices (0 to 89) to angles (-180 to 180)
162
+ pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 4 - 180
163
+ yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 4 - 180
164
+
165
+ pitch_predicted = pitch_predicted*np.pi/180
166
+ yaw_predicted = yaw_predicted*np.pi/180
167
+
168
+ for p,y,pl,yl in zip(pitch_predicted,yaw_predicted,label_pitch,label_yaw):
169
+ avg_error += angular(gazeto3d([p,y]), gazeto3d([pl,yl]))
170
+
171
+
172
+
173
+ x = ''.join(filter(lambda i: i.isdigit(), epochs))
174
+ epoch_list.append(x)
175
+ avg_MAE.append(avg_error/total)
176
+ loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total}\n"
177
+ outfile.write(loger)
178
+ print(loger)
179
+
180
+ fig = plt.figure(figsize=(14, 8))
181
+ plt.xlabel('epoch')
182
+ plt.ylabel('avg')
183
+ plt.title('Gaze angular error')
184
+ plt.plot(epoch_list, avg_MAE, color='k', label='mae')
185
+ plt.legend()
186
+ fig.savefig(os.path.join(evalpath,data_set+".png"), format='png')
187
+ plt.show()
188
+
189
+
190
+
191
+ elif data_set=="mpiigaze":
192
+ model_used=getArch(arch, bins)
193
+
194
+ for fold in range(15):
195
+ folder = os.listdir(args.gazeMpiilabel_dir)
196
+ folder.sort()
197
+ testlabelpathombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
198
+ gaze_dataset=Mpiigaze(testlabelpathombined,args.gazeMpiimage_dir, transformations, False, angle, fold)
199
+
200
+ test_loader = torch.utils.data.DataLoader(
201
+ dataset=gaze_dataset,
202
+ batch_size=batch_size,
203
+ shuffle=True,
204
+ num_workers=4,
205
+ pin_memory=True)
206
+
207
+
208
+ if not os.path.exists(os.path.join(evalpath, f"fold"+str(fold))):
209
+ os.makedirs(os.path.join(evalpath, f"fold"+str(fold)))
210
+
211
+ # list all epochs for testing
212
+ folder = os.listdir(os.path.join(snapshot_path,"fold"+str(fold)))
213
+ folder.sort(key=natural_keys)
214
+
215
+ softmax = nn.Softmax(dim=1)
216
+ with open(os.path.join(evalpath, os.path.join("fold"+str(fold), data_set+".log")), 'w') as outfile:
217
+ configuration = f"\ntest configuration equal gpu_id={gpu}, batch_size={batch_size}, model_arch={arch}\nStart testing dataset={data_set}, fold={fold}---------------------------------------\n"
218
+ print(configuration)
219
+ outfile.write(configuration)
220
+ epoch_list=[]
221
+ avg_MAE=[]
222
+ for epochs in folder:
223
+ model=model_used
224
+ saved_state_dict = torch.load(os.path.join(snapshot_path+"/fold"+str(fold),epochs))
225
+ model= nn.DataParallel(model,device_ids=[0])
226
+ model.load_state_dict(saved_state_dict)
227
+ model.cuda(gpu)
228
+ model.eval()
229
+ total = 0
230
+ idx_tensor = [idx for idx in range(28)]
231
+ idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)
232
+ avg_error = .0
233
+ with torch.no_grad():
234
+ for j, (images, labels, cont_labels, name) in enumerate(test_loader):
235
+ images = Variable(images).cuda(gpu)
236
+ total += cont_labels.size(0)
237
+
238
+ label_pitch = cont_labels[:,0].float()*np.pi/180
239
+ label_yaw = cont_labels[:,1].float()*np.pi/180
240
+
241
+
242
+ gaze_pitch, gaze_yaw = model(images)
243
+
244
+ # Binned predictions
245
+ _, pitch_bpred = torch.max(gaze_pitch.data, 1)
246
+ _, yaw_bpred = torch.max(gaze_yaw.data, 1)
247
+
248
+
249
+ # Continuous predictions
250
+ pitch_predicted = softmax(gaze_pitch)
251
+ yaw_predicted = softmax(gaze_yaw)
252
+
253
+ # mapping from binned indices (0 to 27) to angles (-42 to 42)
254
+ pitch_predicted = \
255
+ torch.sum(pitch_predicted * idx_tensor, 1).cpu() * 3 - 42
256
+ yaw_predicted = \
257
+ torch.sum(yaw_predicted * idx_tensor, 1).cpu() * 3 - 42
258
+
259
+
260
+ pitch_predicted = pitch_predicted*np.pi/180
261
+ yaw_predicted = yaw_predicted*np.pi/180
262
+
263
+ for p,y,pl,yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
264
+ avg_error += angular(gazeto3d([p,y]), gazeto3d([pl,yl]))
265
+
266
+
267
+ x = ''.join(filter(lambda i: i.isdigit(), epochs))
268
+ epoch_list.append(x)
269
+ avg_MAE.append(avg_error/ total)
270
+ loger = f"[{epochs}---{args.dataset}] Total Num:{total},MAE:{avg_error/total} \n"
271
+ outfile.write(loger)
272
+ print(loger)
273
+
274
+ fig = plt.figure(figsize=(14, 8))
275
+ plt.xlabel('epoch')
276
+ plt.ylabel('avg')
277
+ plt.title('Gaze angular error')
278
+ plt.plot(epoch_list, avg_MAE, color='k', label='mae')
279
+ plt.legend()
280
+ fig.savefig(os.path.join(evalpath, os.path.join("fold"+str(fold), data_set+".png")), format='png')
281
+ # plt.show()
282
+
283
+
284
+
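Both evaluation branches decode angles the same way: softmax over the bin logits, expectation against the bin indices, then a linear map to degrees (90 bins × 4° − 180° for Gaze360, 28 bins × 3° − 42° for MPIIGaze). A NumPy sketch of that soft-argmax decoding (`decode_bins` is a hypothetical name):

```python
import numpy as np

def decode_bins(logits, bin_width, offset):
    # Soft-argmax: expectation of the bin index under softmax, mapped to degrees.
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return float((p * np.arange(len(logits))).sum() * bin_width + offset)
```

With uniform logits the expectation sits at the middle bin, e.g. `decode_bins(np.zeros(90), 4.0, -180.0)` gives −2.0 (index 44.5 × 4 − 180).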
models/L2CS-Net/train.py ADDED
@@ -0,0 +1,384 @@
1
+ import os
2
+ import argparse
3
+ import time
4
+
5
+ import torch.utils.model_zoo as model_zoo
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.autograd import Variable
9
+ from torch.utils.data import DataLoader
10
+ from torchvision import transforms
11
+ import torch.backends.cudnn as cudnn
12
+ import torchvision
13
+
14
+ from l2cs import L2CS, select_device, Gaze360, Mpiigaze
15
+
16
+
17
+ def parse_args():
18
+ """Parse input arguments."""
19
+ parser = argparse.ArgumentParser(description='Gaze estimation using L2CSNet.')
20
+ # Gaze360
21
+ parser.add_argument(
22
+ '--gaze360image_dir', dest='gaze360image_dir', help='Directory path for gaze images.',
23
+ default='datasets/Gaze360/Image', type=str)
24
+ parser.add_argument(
25
+ '--gaze360label_dir', dest='gaze360label_dir', help='Directory path for gaze labels.',
26
+ default='datasets/Gaze360/Label/train.label', type=str)
27
+ # mpiigaze
28
+ parser.add_argument(
29
+ '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
30
+ default='datasets/MPIIFaceGaze/Image', type=str)
31
+ parser.add_argument(
32
+ '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
33
+ default='datasets/MPIIFaceGaze/Label', type=str)
34
+
35
+ # Important args -------------------------------------------------------------------------------------------------------
36
+ # ----------------------------------------------------------------------------------------------------------------------
37
+ parser.add_argument(
38
+ '--dataset', dest='dataset', help='mpiigaze, rtgene, gaze360, ethgaze',
39
+ default= "gaze360", type=str)
40
+ parser.add_argument(
41
+ '--output', dest='output', help='Path of output models.',
42
+ default='output/snapshots/', type=str)
43
+ parser.add_argument(
44
+ '--snapshot', dest='snapshot', help='Path of model snapshot.',
45
+ default='', type=str)
46
+ parser.add_argument(
47
+ '--gpu', dest='gpu_id', help='GPU device id to use [0] or multiple 0,1,2,3',
48
+ default='0', type=str)
49
+ parser.add_argument(
50
+ '--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
51
+ default=60, type=int)
52
+ parser.add_argument(
53
+ '--batch_size', dest='batch_size', help='Batch size.',
54
+ default=1, type=int)
55
+ parser.add_argument(
56
+ '--arch', dest='arch', help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], ''ResNet101, ResNet152, Squeezenet_1_0, Squeezenet_1_1, MobileNetV2',
57
+ default='ResNet50', type=str)
58
+ parser.add_argument(
59
+ '--alpha', dest='alpha', help='Regression loss coefficient.',
60
+ default=1, type=float)
61
+ parser.add_argument(
62
+ '--lr', dest='lr', help='Base learning rate.',
63
+ default=0.00001, type=float)
64
+ # ---------------------------------------------------------------------------------------------------------------------
65
+ # Important args ------------------------------------------------------------------------------------------------------
66
+ args = parser.parse_args()
67
+ return args
68
+
69
+ def get_ignored_params(model):
70
+ # Generator function that yields ignored params.
71
+ b = [model.conv1, model.bn1, model.fc_finetune]
72
+ for i in range(len(b)):
73
+ for module_name, module in b[i].named_modules():
74
+ if 'bn' in module_name:
75
+ module.eval()
76
+ for name, param in module.named_parameters():
77
+ yield param
78
+
79
+ def get_non_ignored_params(model):
80
+ # Generator function that yields params that will be optimized.
81
+ b = [model.layer1, model.layer2, model.layer3, model.layer4]
82
+ for i in range(len(b)):
83
+ for module_name, module in b[i].named_modules():
84
+ if 'bn' in module_name:
85
+ module.eval()
86
+ for name, param in module.named_parameters():
87
+ yield param
88
+
89
+ def get_fc_params(model):
90
+ # Generator function that yields fc layer params.
91
+ b = [model.fc_yaw_gaze, model.fc_pitch_gaze]
92
+ for i in range(len(b)):
93
+ for module_name, module in b[i].named_modules():
94
+ for name, param in module.named_parameters():
95
+ yield param
96
+
97
+ def load_filtered_state_dict(model, snapshot):
98
+ # By user apaszke from discuss.pytorch.org
99
+ model_dict = model.state_dict()
100
+ snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
101
+ model_dict.update(snapshot)
102
+ model.load_state_dict(model_dict)
103
+
104
+
105
+ def getArch_weights(arch, bins):
106
+ if arch == 'ResNet18':
107
+ model = L2CS(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], bins)
108
+ pre_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
109
+ elif arch == 'ResNet34':
110
+ model = L2CS(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], bins)
111
+ pre_url = 'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
112
+ elif arch == 'ResNet101':
113
+ model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], bins)
114
+ pre_url = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
115
+ elif arch == 'ResNet152':
116
+ model = L2CS(torchvision.models.resnet.Bottleneck,[3, 8, 36, 3], bins)
117
+ pre_url = 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
118
+ else:
119
+ if arch != 'ResNet50':
120
+ print('Invalid value for architecture is passed! '
121
+ 'The default value of ResNet50 will be used instead!')
122
+ model = L2CS(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], bins)
123
+ pre_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
124
+
125
+ return model, pre_url
126
+
127
+ if __name__ == '__main__':
128
+ args = parse_args()
129
+ cudnn.enabled = True
130
+ num_epochs = args.num_epochs
131
+ batch_size = args.batch_size
132
+ gpu = select_device(args.gpu_id, batch_size=args.batch_size)
133
+ data_set=args.dataset
134
+ alpha = args.alpha
135
+ output=args.output
136
+
137
+
138
+ transformations = transforms.Compose([
139
+ transforms.Resize(448),
140
+ transforms.ToTensor(),
141
+ transforms.Normalize(
142
+ mean=[0.485, 0.456, 0.406],
143
+ std=[0.229, 0.224, 0.225]
144
+ )
145
+ ])
146
+
147
+
148
+
149
+ if data_set=="gaze360":
150
+ model, pre_url = getArch_weights(args.arch, 90)
151
+ if args.snapshot == '':
152
+ load_filtered_state_dict(model, model_zoo.load_url(pre_url))
153
+ else:
154
+ saved_state_dict = torch.load(args.snapshot)
155
+ model.load_state_dict(saved_state_dict)
156
+
157
+
158
+ model.cuda(gpu)
159
+ dataset=Gaze360(args.gaze360label_dir, args.gaze360image_dir, transformations, 180, 4)
160
+ print('Loading data.')
161
+ train_loader_gaze = DataLoader(
162
+ dataset=dataset,
163
+ batch_size=int(batch_size),
164
+ shuffle=True,
165
+ num_workers=0,
166
+ pin_memory=True)
167
+ torch.backends.cudnn.benchmark = True
168
+
169
+ summary_name = '{}_{}'.format('L2CS-gaze360-', int(time.time()))
170
+ output=os.path.join(output, summary_name)
171
+ if not os.path.exists(output):
172
+ os.makedirs(output)
173
+
174
+
175
+ criterion = nn.CrossEntropyLoss().cuda(gpu)
176
+ reg_criterion = nn.MSELoss().cuda(gpu)
177
+ softmax = nn.Softmax(dim=1).cuda(gpu)
178
+ idx_tensor = [idx for idx in range(90)]
179
+ idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
180
+
181
+
182
+ # Optimizer gaze
183
+ optimizer_gaze = torch.optim.Adam([
184
+ {'params': get_ignored_params(model), 'lr': 0},
185
+ {'params': get_non_ignored_params(model), 'lr': args.lr},
186
+ {'params': get_fc_params(model), 'lr': args.lr}
187
+ ], args.lr)
188
+
189
+
190
+ configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\nStart testing dataset={data_set}, loader={len(train_loader_gaze)}------------------------- \n"
191
+ print(configuration)
192
+ for epoch in range(num_epochs):
193
+ sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0
194
+
195
+
196
+ for i, (images_gaze, labels_gaze, cont_labels_gaze,name) in enumerate(train_loader_gaze):
197
+ images_gaze = Variable(images_gaze).cuda(gpu)
198
+
199
+ # Binned labels
200
+ label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
201
+ label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)
202
+
203
+ # Continuous labels
204
+ label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
205
+ label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)
206
+
207
+ pitch, yaw = model(images_gaze)
208
+
209
+ # Cross entropy loss
210
+ loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
211
+ loss_yaw_gaze = criterion(yaw, label_yaw_gaze)
212
+
213
+ # MSE loss
214
+ pitch_predicted = softmax(pitch)
215
+ yaw_predicted = softmax(yaw)
216
+
217
+ pitch_predicted = \
218
+ torch.sum(pitch_predicted * idx_tensor, 1) * 4 - 180
219
+ yaw_predicted = \
220
+ torch.sum(yaw_predicted * idx_tensor, 1) * 4 - 180
221
+
222
+ loss_reg_pitch = reg_criterion(
223
+ pitch_predicted, label_pitch_cont_gaze)
224
+ loss_reg_yaw = reg_criterion(
225
+ yaw_predicted, label_yaw_cont_gaze)
226
+
227
+ # Total loss
228
+ loss_pitch_gaze += alpha * loss_reg_pitch
229
+ loss_yaw_gaze += alpha * loss_reg_yaw
230
+
231
+ sum_loss_pitch_gaze += loss_pitch_gaze
232
+ sum_loss_yaw_gaze += loss_yaw_gaze
233
+
234
+ loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
235
+ grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
236
+ optimizer_gaze.zero_grad(set_to_none=True)
237
+ torch.autograd.backward(loss_seq, grad_seq)
238
+ optimizer_gaze.step()
239
+ # scheduler.step()
240
+
241
+ iter_gaze += 1
242
+
243
+ if (i+1) % 100 == 0:
244
+ print('Epoch [%d/%d], Iter [%d/%d] Losses: '
245
+ 'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
246
+ epoch+1,
247
+ num_epochs,
248
+ i+1,
249
+ len(dataset)//batch_size,
250
+ sum_loss_pitch_gaze/iter_gaze,
251
+ sum_loss_yaw_gaze/iter_gaze
252
+ )
253
+ )
254
+
255
+
256
+ if epoch % 1 == 0 and epoch < num_epochs:
257
+ print('Taking snapshot...',
258
+ torch.save(model.state_dict(),
259
+ output +'/'+
260
+ '_epoch_' + str(epoch+1) + '.pkl')
261
+ )
262
+
263
+
264
+
265
+ elif data_set=="mpiigaze":
266
+ folder = os.listdir(args.gazeMpiilabel_dir)
267
+ folder.sort()
268
+ testlabelpathombined = [os.path.join(args.gazeMpiilabel_dir, j) for j in folder]
269
+ for fold in range(15):
270
+ model, pre_url = getArch_weights(args.arch, 28)
271
+ load_filtered_state_dict(model, model_zoo.load_url(pre_url))
272
+ model = nn.DataParallel(model)
273
+ model.to(gpu)
274
+ print('Loading data.')
275
+ dataset=Mpiigaze(testlabelpathombined,args.gazeMpiimage_dir, transformations, True, fold)
276
+ train_loader_gaze = DataLoader(
277
+ dataset=dataset,
278
+ batch_size=int(batch_size),
279
+ shuffle=True,
280
+ num_workers=4,
281
+ pin_memory=True)
282
+ torch.backends.cudnn.benchmark = True
283
+
284
+ summary_name = '{}_{}'.format('L2CS-mpiigaze', int(time.time()))
285
+
286
+
287
+ if not os.path.exists(os.path.join(output+'/{}'.format(summary_name),'fold' + str(fold))):
288
+ os.makedirs(os.path.join(output+'/{}'.format(summary_name),'fold' + str(fold)))
289
+
290
+
291
+ criterion = nn.CrossEntropyLoss().cuda(gpu)
292
+ reg_criterion = nn.MSELoss().cuda(gpu)
293
+ softmax = nn.Softmax(dim=1).cuda(gpu)
294
+ idx_tensor = [idx for idx in range(28)]
295
+ idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)
296
+
297
+ # Optimizer gaze
298
+ optimizer_gaze = torch.optim.Adam([
299
+ {'params': get_ignored_params(model.module), 'lr': 0},
300
+ {'params': get_non_ignored_params(model.module), 'lr': args.lr},
301
+ {'params': get_fc_params(model.module), 'lr': args.lr}
302
+ ], args.lr)
303
+
304
+
305
+
306
+ configuration = f"\ntrain configuration, gpu_id={args.gpu_id}, batch_size={batch_size}, model_arch={args.arch}\n Start training dataset={data_set}, loader={len(train_loader_gaze)}, fold={fold}--------------\n"
307
+ print(configuration)
308
+ for epoch in range(num_epochs):
309
+ sum_loss_pitch_gaze = sum_loss_yaw_gaze = iter_gaze = 0
310
+
311
+
312
+ for i, (images_gaze, labels_gaze, cont_labels_gaze,name) in enumerate(train_loader_gaze):
313
+ images_gaze = Variable(images_gaze).cuda(gpu)
314
+
315
+ # Binned labels
316
+ label_pitch_gaze = Variable(labels_gaze[:, 0]).cuda(gpu)
317
+ label_yaw_gaze = Variable(labels_gaze[:, 1]).cuda(gpu)
318
+
319
+ # Continuous labels
320
+ label_pitch_cont_gaze = Variable(cont_labels_gaze[:, 0]).cuda(gpu)
321
+ label_yaw_cont_gaze = Variable(cont_labels_gaze[:, 1]).cuda(gpu)
322
+
323
+ pitch, yaw = model(images_gaze)
324
+
325
+ # Cross entropy loss
326
+ loss_pitch_gaze = criterion(pitch, label_pitch_gaze)
327
+ loss_yaw_gaze = criterion(yaw, label_yaw_gaze)
328
+
329
+ # MSE loss
330
+ pitch_predicted = softmax(pitch)
331
+ yaw_predicted = softmax(yaw)
332
+
333
+ pitch_predicted = \
334
+ torch.sum(pitch_predicted * idx_tensor, 1) * 3 - 42
335
+ yaw_predicted = \
336
+ torch.sum(yaw_predicted * idx_tensor, 1) * 3 - 42
337
+
338
+ loss_reg_pitch = reg_criterion(
339
+ pitch_predicted, label_pitch_cont_gaze)
340
+ loss_reg_yaw = reg_criterion(
341
+ yaw_predicted, label_yaw_cont_gaze)
342
+
343
+ # Total loss
344
+ loss_pitch_gaze += alpha * loss_reg_pitch
345
+ loss_yaw_gaze += alpha * loss_reg_yaw
346
+
347
+ sum_loss_pitch_gaze += loss_pitch_gaze
348
+ sum_loss_yaw_gaze += loss_yaw_gaze
349
+
350
+ loss_seq = [loss_pitch_gaze, loss_yaw_gaze]
351
+ grad_seq = \
352
+ [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
353
+
354
+ optimizer_gaze.zero_grad(set_to_none=True)
355
+ torch.autograd.backward(loss_seq, grad_seq)
356
+ optimizer_gaze.step()
357
+
358
+ iter_gaze += 1
359
+
360
+ if (i+1) % 100 == 0:
361
+ print('Epoch [%d/%d], Iter [%d/%d] Losses: '
362
+ 'Gaze Yaw %.4f,Gaze Pitch %.4f' % (
363
+ epoch+1,
364
+ num_epochs,
365
+ i+1,
366
+ len(dataset)//batch_size,
367
+ sum_loss_pitch_gaze/iter_gaze,
368
+ sum_loss_yaw_gaze/iter_gaze
369
+ )
370
+ )
371
+
372
+
373
+
374
+ # Save models at numbered epochs.
375
+ if epoch % 1 == 0 and epoch < num_epochs:
376
+ print('Taking snapshot...',
377
+ torch.save(model.state_dict(),
378
+ os.path.join(output, summary_name, 'fold' + str(fold),
379
+ '_epoch_' + str(epoch+1) + '.pkl'))
380
+ )
381
+
382
+
383
+
384
+
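The training objective pairs cross-entropy on the binned label with an MSE term on the soft-argmax angle (in degrees), weighted by `--alpha`. A NumPy sketch of the per-sample loss for the Gaze360 head, 90 bins × 4° − 180° (`l2cs_loss` is a hypothetical name):

```python
import numpy as np

def l2cs_loss(logits, bin_label, cont_label_deg, alpha=1.0,
              bin_width=4.0, offset=-180.0):
    # Softmax over bins.
    p = np.exp(logits - logits.max())
    p /= p.sum()
    # Cross-entropy against the binned label.
    ce = -np.log(p[bin_label])
    # Soft-argmax angle in degrees, regressed to the continuous label.
    angle = (p * np.arange(len(logits))).sum() * bin_width + offset
    return float(ce + alpha * (angle - cont_label_deg) ** 2)
```

When the logits are sharply peaked at the correct bin, both terms approach zero.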
models/gaze_calibration.py ADDED
@@ -0,0 +1,146 @@
1
+ # 9-point gaze calibration for L2CS-Net
2
+ # Maps raw gaze angles -> normalised screen coords via polynomial least-squares.
3
+ # Centre point is the bias reference (subtracted from all readings).
4
+
5
+ import numpy as np
6
+ from dataclasses import dataclass, field
7
+
8
+ # 3x3 grid, centre first (bias ref), then row by row
9
+ DEFAULT_TARGETS = [
10
+ (0.5, 0.5),
11
+ (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
12
+ (0.15, 0.50), (0.85, 0.50),
13
+ (0.15, 0.85), (0.50, 0.85), (0.85, 0.85),
14
+ ]
15
+
16
+
17
+ @dataclass
18
+ class _PointSamples:
19
+ target_x: float
20
+ target_y: float
21
+ yaws: list = field(default_factory=list)
22
+ pitches: list = field(default_factory=list)
23
+
24
+
25
+ def _iqr_filter(values):
26
+ if len(values) < 4:
27
+ return values
28
+ arr = np.array(values)
29
+ q1, q3 = np.percentile(arr, [25, 75])
30
+ iqr = q3 - q1
31
+ lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
32
+ return arr[(arr >= lo) & (arr <= hi)].tolist()
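`_iqr_filter` applies a standard Tukey fence: samples outside [Q1 − 1.5·IQR, Q3 + 1.5·IQR] are dropped, and lists too short to estimate quartiles pass through untouched. A standalone version with the same logic:

```python
import numpy as np

def iqr_filter(values):
    # Tukey fence: keep samples within [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
    if len(values) < 4:
        return list(values)
    arr = np.asarray(values, dtype=float)
    q1, q3 = np.percentile(arr, [25, 75])
    iqr = q3 - q1
    lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr
    return arr[(arr >= lo) & (arr <= hi)].tolist()
```

A single wild reading is rejected: `iqr_filter([1, 2, 3, 100])` returns `[1.0, 2.0, 3.0]`.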
33
+
34
+
35
+ class GazeCalibration:
36
+
37
+ def __init__(self, targets=None):
38
+ self._targets = targets or list(DEFAULT_TARGETS)
39
+ self._points = [_PointSamples(tx, ty) for tx, ty in self._targets]
40
+ self._current_idx = 0
41
+ self._fitted = False
42
+ self._W = None # (6, 2) polynomial weights
43
+ self._yaw_bias = 0.0
44
+ self._pitch_bias = 0.0
45
+
46
+ @property
47
+ def num_points(self):
48
+ return len(self._targets)
49
+
50
+ @property
51
+ def current_index(self):
52
+ return self._current_idx
53
+
54
+ @property
55
+ def current_target(self):
56
+ if self._current_idx < len(self._targets):
57
+ return self._targets[self._current_idx]
58
+ return self._targets[-1]
59
+
60
+ @property
61
+ def is_complete(self):
62
+ return self._current_idx >= len(self._targets)
63
+
64
+ @property
65
+ def is_fitted(self):
66
+ return self._fitted
67
+
68
+ def collect_sample(self, yaw_rad, pitch_rad):
69
+ if self._current_idx >= len(self._points):
70
+ return
71
+ pt = self._points[self._current_idx]
72
+ pt.yaws.append(float(yaw_rad))
73
+ pt.pitches.append(float(pitch_rad))
74
+
75
+ def advance(self):
76
+ self._current_idx += 1
77
+ return self._current_idx < len(self._targets)
78
+
79
+ @staticmethod
80
+ def _poly_features(yaw, pitch):
81
+ # [yaw^2, pitch^2, yaw*pitch, yaw, pitch, 1]
82
+ return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0],
83
+ dtype=np.float64)
84
+
85
+ def fit(self):
86
+ # bias from centre point (index 0)
87
+ center = self._points[0]
88
+ center_yaws = _iqr_filter(center.yaws)
89
+ center_pitches = _iqr_filter(center.pitches)
90
+ if len(center_yaws) < 2 or len(center_pitches) < 2:
91
+ return False
92
+ self._yaw_bias = float(np.median(center_yaws))
93
+ self._pitch_bias = float(np.median(center_pitches))
94
+
95
+ rows_A, rows_B = [], []
96
+ for pt in self._points:
97
+ clean_yaws = _iqr_filter(pt.yaws)
98
+ clean_pitches = _iqr_filter(pt.pitches)
99
+ if len(clean_yaws) < 2 or len(clean_pitches) < 2:
100
+ continue
101
+ med_yaw = float(np.median(clean_yaws)) - self._yaw_bias
102
+ med_pitch = float(np.median(clean_pitches)) - self._pitch_bias
103
+ rows_A.append(self._poly_features(med_yaw, med_pitch))
104
+ rows_B.append([pt.target_x, pt.target_y])
105
+
106
+ if len(rows_A) < 5:
107
+ return False
108
+
109
+ A = np.array(rows_A, dtype=np.float64)
110
+ B = np.array(rows_B, dtype=np.float64)
111
+ try:
112
+ W, _, _, _ = np.linalg.lstsq(A, B, rcond=None)
113
+ self._W = W
114
+ self._fitted = True
115
+ return True
116
+ except np.linalg.LinAlgError:
117
+ return False
118
+
119
+ def predict(self, yaw_rad, pitch_rad):
120
+ if not self._fitted or self._W is None:
121
+ return 0.5, 0.5
122
+ feat = self._poly_features(yaw_rad - self._yaw_bias, pitch_rad - self._pitch_bias)
123
+ xy = feat @ self._W
124
+ return float(np.clip(xy[0], 0, 1)), float(np.clip(xy[1], 0, 1))
125
+
126
+ def to_dict(self):
127
+ return {
128
+ "targets": self._targets,
129
+ "fitted": self._fitted,
130
+ "current_index": self._current_idx,
131
+ "W": self._W.tolist() if self._W is not None else None,
132
+ "yaw_bias": self._yaw_bias,
133
+ "pitch_bias": self._pitch_bias,
134
+ }
135
+
136
+ @classmethod
137
+ def from_dict(cls, d):
138
+ cal = cls(targets=d.get("targets", DEFAULT_TARGETS))
139
+ cal._fitted = d.get("fitted", False)
140
+ cal._current_idx = d.get("current_index", 0)
141
+ cal._yaw_bias = d.get("yaw_bias", 0.0)
142
+ cal._pitch_bias = d.get("pitch_bias", 0.0)
143
+ w = d.get("W")
144
+ if w is not None:
145
+ cal._W = np.array(w, dtype=np.float64)
146
+ return cal
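The fitting math in gaze_calibration.py can be exercised standalone. The sketch below uses the same 6-term polynomial basis and 3x3 target grid as the module, but the angle-to-target mapping is synthetic (a made-up user whose yaw tracks x and pitch tracks y); real L2CS angles are noisier:

```python
import numpy as np

def poly_features(yaw, pitch):
    # same basis as GazeCalibration._poly_features
    return np.array([yaw**2, pitch**2, yaw * pitch, yaw, pitch, 1.0])

# 3x3 target grid, centre first (matches DEFAULT_TARGETS)
targets = [(0.5, 0.5),
           (0.15, 0.15), (0.50, 0.15), (0.85, 0.15),
           (0.15, 0.50), (0.85, 0.50),
           (0.15, 0.85), (0.50, 0.85), (0.85, 0.85)]

# hypothetical user: yaw (rad) is proportional to x, pitch to y
A = np.array([poly_features((tx - 0.5) * 0.6, (ty - 0.5) * 0.4)
              for tx, ty in targets])
B = np.array(targets, dtype=float)

# least-squares fit, exactly as fit() does it
W, *_ = np.linalg.lstsq(A, B, rcond=None)   # (6, 2) weights

# looking dead ahead maps back to the screen centre
x, y = poly_features(0.0, 0.0) @ W
print(round(x, 3), round(y, 3))
```

Because the synthetic mapping is linear, it lies inside the polynomial basis and the 9-point system is solved exactly; with real samples the IQR filter and median per point do the denoising before this step.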
models/gaze_eye_fusion.py ADDED
@@ -0,0 +1,66 @@
1
+ # Fuses calibrated gaze position with eye openness (EAR) for focus detection.
2
+ # Takes L2CS gaze angles + MediaPipe landmarks, outputs screen coords + focus decision.
3
+
4
+ import math
5
+ import numpy as np
6
+
7
+ from .gaze_calibration import GazeCalibration
8
+ from .eye_scorer import compute_avg_ear
9
+
10
+ _EAR_BLINK = 0.18
11
+ _ON_SCREEN_MARGIN = 0.08
12
+
13
+
14
+ class GazeEyeFusion:
15
+
16
+ def __init__(self, calibration, ear_weight=0.3, gaze_weight=0.7, focus_threshold=0.52):
17
+ if not calibration.is_fitted:
18
+ raise ValueError("Calibration must be fitted first")
19
+ self._cal = calibration
20
+ self._ear_w = ear_weight
21
+ self._gaze_w = gaze_weight
22
+ self._threshold = focus_threshold
23
+ self._smooth_x = 0.5
24
+ self._smooth_y = 0.5
25
+ self._alpha = 0.5
26
+
27
+ def update(self, yaw_rad, pitch_rad, landmarks):
28
+ gx, gy = self._cal.predict(yaw_rad, pitch_rad)
29
+
30
+ # EMA smooth the gaze position
31
+ self._smooth_x += self._alpha * (gx - self._smooth_x)
32
+ self._smooth_y += self._alpha * (gy - self._smooth_y)
33
+ gx, gy = self._smooth_x, self._smooth_y
34
+
35
+ on_screen = (
36
+ -_ON_SCREEN_MARGIN <= gx <= 1.0 + _ON_SCREEN_MARGIN and
37
+ -_ON_SCREEN_MARGIN <= gy <= 1.0 + _ON_SCREEN_MARGIN
38
+ )
39
+
40
+ ear = None
41
+ ear_score = 1.0
42
+ if landmarks is not None:
43
+ ear = compute_avg_ear(landmarks)
44
+ ear_score = 0.0 if ear < _EAR_BLINK else min(ear / 0.30, 1.0)
45
+
46
+ # penalise gaze near screen edges
47
+ gaze_score = 1.0 if on_screen else 0.0
48
+ if on_screen:
49
+ dx = max(0.0, abs(gx - 0.5) - 0.3)
50
+ dy = max(0.0, abs(gy - 0.5) - 0.3)
51
+ gaze_score = max(0.0, 1.0 - math.sqrt(dx**2 + dy**2) * 5.0)
52
+
53
+ score = float(np.clip(self._gaze_w * gaze_score + self._ear_w * ear_score, 0, 1))
54
+
55
+ return {
56
+ "gaze_x": round(float(gx), 4),
57
+ "gaze_y": round(float(gy), 4),
58
+ "on_screen": on_screen,
59
+ "ear": round(ear, 4) if ear is not None else None,
60
+ "focus_score": round(score, 4),
61
+ "focused": score >= self._threshold,
62
+ }
63
+
64
+ def reset(self):
65
+ self._smooth_x = 0.5
66
+ self._smooth_y = 0.5
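The edge penalty and weighted blend in `update()` above can be checked in isolation. Constants are copied from this file; the EAR inputs are invented for the demo:

```python
import math
import numpy as np

MARGIN = 0.08            # _ON_SCREEN_MARGIN
GAZE_W, EAR_W = 0.7, 0.3 # default weights

def gaze_score(gx, gy):
    # full credit inside the central 60% of the screen, linear falloff outside
    on = -MARGIN <= gx <= 1 + MARGIN and -MARGIN <= gy <= 1 + MARGIN
    if not on:
        return 0.0
    dx = max(0.0, abs(gx - 0.5) - 0.3)
    dy = max(0.0, abs(gy - 0.5) - 0.3)
    return max(0.0, 1.0 - math.sqrt(dx**2 + dy**2) * 5.0)

def fused(gx, gy, ear):
    ear_score = 0.0 if ear < 0.18 else min(ear / 0.30, 1.0)
    return float(np.clip(GAZE_W * gaze_score(gx, gy) + EAR_W * ear_score, 0, 1))

print(fused(0.5, 0.5, 0.30))   # centred, eyes open: full score
print(fused(0.95, 0.5, 0.30))  # near the right edge: gaze term penalised
print(fused(0.5, 0.5, 0.10))   # blink: EAR term zeroed, gaze term survives
```

With the default `focus_threshold=0.52`, the edge-of-screen case scores below threshold while a momentary blink at screen centre stays above it, which is the intended asymmetry.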
requirements.txt CHANGED
@@ -20,3 +20,5 @@ xgboost>=2.0.0
20
  clearml>=2.0.2
21
  pytest>=9.0.0
22
  pytest-cov>=5.0.0
 
 
 
20
  clearml>=2.0.2
21
  pytest>=9.0.0
22
  pytest-cov>=5.0.0
23
+ face_detection @ git+https://github.com/elliottzheng/face-detection
24
+ gdown>=5.0.0
src/components/CalibrationOverlay.jsx ADDED
@@ -0,0 +1,146 @@
1
+ import React, { useState, useEffect, useRef, useCallback } from 'react';
2
+
3
+ const COLLECT_MS = 2000;
4
+ const CENTER_MS = 3000; // centre point gets extra time (bias reference)
5
+
6
+ function CalibrationOverlay({ calibration, videoManager }) {
7
+ const [progress, setProgress] = useState(0);
8
+ const timerRef = useRef(null);
9
+ const startRef = useRef(null);
10
+ const overlayRef = useRef(null);
11
+
12
+ const enterFullscreen = useCallback(() => {
13
+ const el = overlayRef.current;
14
+ if (!el) return;
15
+ const req = el.requestFullscreen || el.webkitRequestFullscreen || el.msRequestFullscreen;
16
+ if (req) { const p = req.call(el); if (p && p.catch) p.catch(() => {}); } // prefixed APIs may not return a promise
17
+ }, []);
18
+
19
+ const exitFullscreen = useCallback(() => {
20
+ if (document.fullscreenElement || document.webkitFullscreenElement) {
21
+ const exit = document.exitFullscreen || document.webkitExitFullscreen || document.msExitFullscreen;
22
+ if (exit) { const p = exit.call(document); if (p && p.catch) p.catch(() => {}); }
23
+ }
24
+ }, []);
25
+
26
+ useEffect(() => {
27
+ if (calibration && calibration.active && !calibration.done) {
28
+ const t = setTimeout(enterFullscreen, 100);
29
+ return () => clearTimeout(t);
30
+ }
31
+ }, [calibration?.active]);
32
+
33
+ useEffect(() => {
34
+ if (!calibration || !calibration.active) exitFullscreen();
35
+ }, [calibration?.active]);
36
+
37
+ useEffect(() => {
38
+ if (!calibration || !calibration.collecting || calibration.done) {
39
+ setProgress(0);
40
+ if (timerRef.current) cancelAnimationFrame(timerRef.current);
41
+ return;
42
+ }
43
+
44
+ startRef.current = performance.now();
45
+ const duration = calibration.index === 0 ? CENTER_MS : COLLECT_MS;
46
+
47
+ const tick = () => {
48
+ const pct = Math.min((performance.now() - startRef.current) / duration, 1);
49
+ setProgress(pct);
50
+ if (pct >= 1) {
51
+ if (videoManager) videoManager.nextCalibrationPoint();
52
+ startRef.current = performance.now();
53
+ setProgress(0);
54
+ }
55
+ timerRef.current = requestAnimationFrame(tick);
56
+ };
57
+ timerRef.current = requestAnimationFrame(tick);
58
+
59
+ return () => { if (timerRef.current) cancelAnimationFrame(timerRef.current); };
60
+ }, [calibration?.index, calibration?.collecting, calibration?.done]);
61
+
62
+ const handleCancel = () => {
63
+ if (videoManager) videoManager.cancelCalibration();
64
+ exitFullscreen();
65
+ };
66
+
67
+ if (!calibration || !calibration.active) return null;
68
+
69
+ if (calibration.done) {
70
+ return (
71
+ <div ref={overlayRef} style={overlayStyle}>
72
+ <div style={messageBoxStyle}>
73
+ <h2 style={{ margin: '0 0 10px', color: calibration.success ? '#4ade80' : '#f87171' }}>
74
+ {calibration.success ? 'Calibration Complete' : 'Calibration Failed'}
75
+ </h2>
76
+ <p style={{ color: '#ccc', margin: 0 }}>
77
+ {calibration.success
78
+ ? 'Gaze tracking is now active.'
79
+ : 'Not enough samples collected. Try again.'}
80
+ </p>
81
+ </div>
82
+ </div>
83
+ );
84
+ }
85
+
86
+ const [tx, ty] = calibration.target || [0.5, 0.5];
87
+
88
+ return (
89
+ <div ref={overlayRef} style={overlayStyle}>
90
+ <div style={{
91
+ position: 'absolute', top: '30px', left: '50%', transform: 'translateX(-50%)',
92
+ color: '#fff', fontSize: '16px', textAlign: 'center',
93
+ textShadow: '0 0 8px rgba(0,0,0,0.8)', pointerEvents: 'none',
94
+ }}>
95
+ <div style={{ fontWeight: 'bold', fontSize: '20px' }}>
96
+ Look at the dot ({calibration.index + 1}/{calibration.numPoints})
97
+ </div>
98
+ <div style={{ fontSize: '14px', color: '#aaa', marginTop: '6px' }}>
99
+ {calibration.index === 0
100
+ ? 'Look at the center dot - this sets your baseline'
101
+ : 'Hold your gaze steady on the target'}
102
+ </div>
103
+ </div>
104
+
105
+ <div style={{
106
+ position: 'absolute', left: `${tx * 100}%`, top: `${ty * 100}%`,
107
+ transform: 'translate(-50%, -50%)',
108
+ }}>
109
+ <svg width="60" height="60" style={{ position: 'absolute', left: '-30px', top: '-30px' }}>
110
+ <circle cx="30" cy="30" r="24" fill="none" stroke="rgba(255,255,255,0.15)" strokeWidth="3" />
111
+ <circle cx="30" cy="30" r="24" fill="none" stroke="#4ade80" strokeWidth="3"
112
+ strokeDasharray={`${progress * 150.8} 150.8`} strokeLinecap="round"
113
+ transform="rotate(-90, 30, 30)" />
114
+ </svg>
115
+ <div style={{
116
+ width: '20px', height: '20px', borderRadius: '50%',
117
+ background: 'radial-gradient(circle, #fff 30%, #4ade80 100%)',
118
+ boxShadow: '0 0 20px rgba(74, 222, 128, 0.8)',
119
+ }} />
120
+ </div>
121
+
122
+ <button onClick={handleCancel} style={{
123
+ position: 'absolute', bottom: '40px', left: '50%', transform: 'translateX(-50%)',
124
+ padding: '10px 28px', background: 'rgba(255,255,255,0.1)',
125
+ border: '1px solid rgba(255,255,255,0.3)', color: '#fff',
126
+ borderRadius: '20px', cursor: 'pointer', fontSize: '14px',
127
+ }}>
128
+ Cancel Calibration
129
+ </button>
130
+ </div>
131
+ );
132
+ }
133
+
134
+ const overlayStyle = {
135
+ position: 'fixed', top: 0, left: 0, width: '100vw', height: '100vh',
136
+ background: 'rgba(0, 0, 0, 0.92)', zIndex: 10000,
137
+ display: 'flex', alignItems: 'center', justifyContent: 'center',
138
+ };
139
+
140
+ const messageBoxStyle = {
141
+ textAlign: 'center', padding: '30px 40px',
142
+ background: 'rgba(30, 30, 50, 0.9)', borderRadius: '16px',
143
+ border: '1px solid rgba(255,255,255,0.1)',
144
+ };
145
+
146
+ export default CalibrationOverlay;
src/components/FocusPageLocal.jsx CHANGED
@@ -1,4 +1,5 @@
1
  import React, { useState, useEffect, useRef } from 'react';
 
2
 
3
  const FLOW_STEPS = {
4
  intro: 'intro',
@@ -48,6 +49,9 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
48
  const [isStarting, setIsStarting] = useState(false);
49
  const [focusState, setFocusState] = useState(FOCUS_STATES.pending);
50
  const [cameraError, setCameraError] = useState('');
 
 
 
51
 
52
  const localVideoRef = useRef(null);
53
  const displayCanvasRef = useRef(null);
@@ -127,6 +131,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
127
  setFocusState(FOCUS_STATES.pending);
128
  setCameraReady(false);
129
  if (originalOnSessionEnd) originalOnSessionEnd(summary);
 
 
130
  };
131
 
132
  const statsInterval = setInterval(() => {
@@ -136,8 +142,10 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
136
  }, 1000);
137
 
138
  return () => {
139
- videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
140
- videoManager.callbacks.onSessionEnd = originalOnSessionEnd;
 
 
141
  clearInterval(statsInterval);
142
  };
143
  }, [videoManager]);
@@ -149,6 +157,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
149
  .then((data) => {
150
  if (data.available) setAvailableModels(data.available);
151
  if (data.current) setCurrentModel(data.current);
 
 
152
  })
153
  .catch((err) => console.error('Failed to fetch models:', err));
154
  }, []);
@@ -204,6 +214,8 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
204
  const result = await res.json();
205
  if (result.updated) {
206
  setCurrentModel(modelName);
 
 
207
  }
208
  } catch (err) {
209
  console.error('Failed to switch model:', err);
@@ -225,6 +237,21 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
225
  console.error('Camera init error:', err);
226
  }
227
  };
228
  const handleStart = async () => {
229
  try {
230
  setIsStarting(true);
@@ -697,6 +724,65 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
697
  }}>
698
  <span title="Server CPU">CPU: <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong></span>
699
  <span title="Server memory">RAM: <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong> ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</span>
700
  </section>
701
  )}
702
 
@@ -787,6 +873,58 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
787
  </section>
788
  </>
789
  ) : null}
790
  </main>
791
  );
792
  }
 
1
  import React, { useState, useEffect, useRef } from 'react';
2
+ import CalibrationOverlay from './CalibrationOverlay';
3
 
4
  const FLOW_STEPS = {
5
  intro: 'intro',
 
49
  const [isStarting, setIsStarting] = useState(false);
50
  const [focusState, setFocusState] = useState(FOCUS_STATES.pending);
51
  const [cameraError, setCameraError] = useState('');
52
+ const [calibration, setCalibration] = useState(null);
53
+ const [l2csBoost, setL2csBoost] = useState(false);
54
+ const [l2csBoostAvailable, setL2csBoostAvailable] = useState(false);
55
 
56
  const localVideoRef = useRef(null);
57
  const displayCanvasRef = useRef(null);
 
131
  setFocusState(FOCUS_STATES.pending);
132
  setCameraReady(false);
133
  if (originalOnSessionEnd) originalOnSessionEnd(summary);
134
+ };
+ videoManager.callbacks.onCalibrationUpdate = (cal) => {
135
+ setCalibration(cal && cal.active ? { ...cal } : null);
136
  };
137
 
138
  const statsInterval = setInterval(() => {
 
142
  }, 1000);
143
 
144
  return () => {
145
+ if (videoManager) {
146
+ videoManager.callbacks.onStatusUpdate = originalOnStatusUpdate;
+ videoManager.callbacks.onSessionEnd = originalOnSessionEnd;
147
+ videoManager.callbacks.onCalibrationUpdate = null;
148
+ }
149
  clearInterval(statsInterval);
150
  };
151
  }, [videoManager]);
 
157
  .then((data) => {
158
  if (data.available) setAvailableModels(data.available);
159
  if (data.current) setCurrentModel(data.current);
160
+ if (data.l2cs_boost !== undefined) setL2csBoost(data.l2cs_boost);
161
+ if (data.l2cs_boost_available !== undefined) setL2csBoostAvailable(data.l2cs_boost_available);
162
  })
163
  .catch((err) => console.error('Failed to fetch models:', err));
164
  }, []);
 
214
  const result = await res.json();
215
  if (result.updated) {
216
  setCurrentModel(modelName);
217
+ setL2csBoostAvailable(modelName !== 'l2cs' && availableModels.includes('l2cs'));
218
+ if (modelName === 'l2cs') setL2csBoost(false);
219
  }
220
  } catch (err) {
221
  console.error('Failed to switch model:', err);
 
237
  console.error('Camera init error:', err);
238
  }
239
  };
240
+
241
+ const handleBoostToggle = async () => {
242
+ const next = !l2csBoost;
243
+ try {
244
+ const res = await fetch('/api/settings', {
245
+ method: 'PUT',
246
+ headers: { 'Content-Type': 'application/json' },
247
+ body: JSON.stringify({ l2cs_boost: next })
248
+ });
249
+ if (res.ok) setL2csBoost(next);
250
+ } catch (err) {
251
+ console.error('Failed to toggle L2CS boost:', err);
252
+ }
253
+ };
254
+
255
  const handleStart = async () => {
256
  try {
257
  setIsStarting(true);
 
724
  }}>
725
  <span title="Server CPU">CPU: <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong></span>
726
  <span title="Server memory">RAM: <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong> ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</span>
727
+ <span style={{ color: '#aaa', fontSize: '13px', marginRight: '4px' }}>Model:</span>
728
+ {availableModels.map(name => (
729
+ <button
730
+ key={name}
731
+ onClick={() => handleModelChange(name)}
732
+ style={{
733
+ padding: '5px 14px',
734
+ borderRadius: '16px',
735
+ border: currentModel === name ? '2px solid #007BFF' : '1px solid #555',
736
+ background: currentModel === name ? '#007BFF' : 'transparent',
737
+ color: currentModel === name ? '#fff' : '#ccc',
738
+ fontSize: '12px',
739
+ fontWeight: currentModel === name ? 'bold' : 'normal',
740
+ cursor: 'pointer',
741
+ textTransform: 'uppercase',
742
+ transition: 'all 0.2s'
743
+ }}
744
+ >
745
+ {name}
746
+ </button>
747
+ ))}
748
+ {l2csBoostAvailable && currentModel !== 'l2cs' && (
749
+ <button
750
+ onClick={handleBoostToggle}
751
+ style={{
752
+ padding: '5px 14px',
753
+ borderRadius: '16px',
754
+ border: l2csBoost ? '2px solid #f59e0b' : '1px solid #555',
755
+ background: l2csBoost ? 'rgba(245, 158, 11, 0.15)' : 'transparent',
756
+ color: l2csBoost ? '#f59e0b' : '#888',
757
+ fontSize: '11px',
758
+ fontWeight: l2csBoost ? 'bold' : 'normal',
759
+ cursor: 'pointer',
760
+ transition: 'all 0.2s',
761
+ marginLeft: '4px',
762
+ }}
763
+ >
764
+ {l2csBoost ? 'GAZE ON' : 'GAZE'}
765
+ </button>
766
+ )}
767
+ {(currentModel === 'l2cs' || l2csBoost) && stats && stats.isStreaming && (
768
+ <button
769
+ onClick={() => videoManager && videoManager.startCalibration()}
770
+ style={{
771
+ padding: '5px 14px',
772
+ borderRadius: '16px',
773
+ border: '1px solid #4ade80',
774
+ background: 'transparent',
775
+ color: '#4ade80',
776
+ fontSize: '12px',
777
+ fontWeight: 'bold',
778
+ cursor: 'pointer',
779
+ transition: 'all 0.2s',
780
+ marginLeft: '4px',
781
+ }}
782
+ >
783
+ Calibrate
784
+ </button>
785
+ )}
786
  </section>
787
  )}
788
 
 
873
  </section>
874
  </>
875
  ) : null}
876
+ ))}
877
+ </div>
878
+ <div id="timeline-line"></div>
879
+ </section>
880
+
881
+ {/* 4. Control Buttons */}
882
+ <section id="control-panel">
883
+ <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>
884
+ Start
885
+ </button>
886
+
887
+ <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>
888
+ Floating Window
889
+ </button>
890
+
891
+ <button
892
+ id="btn-preview"
893
+ className="action-btn"
894
+ style={{ backgroundColor: '#6c5ce7' }}
895
+ onClick={handlePreview}
896
+ >
897
+ Preview Result
898
+ </button>
899
+
900
+ <button id="btn-cam-stop" className="action-btn red" onClick={handleStop}>
901
+ Stop
902
+ </button>
903
+ </section>
904
+
905
+ {/* 5. Frame Control */}
906
+ <section id="frame-control">
907
+ <label htmlFor="frame-slider">Frame Rate (FPS)</label>
908
+ <input
909
+ type="range"
910
+ id="frame-slider"
911
+ min="10"
912
+ max="30"
913
+ value={currentFrame}
914
+ onChange={(e) => handleFrameChange(e.target.value)}
915
+ />
916
+ <input
917
+ type="number"
918
+ id="frame-input"
919
+ min="10"
920
+ max="30"
921
+ value={currentFrame}
922
+ onChange={(e) => handleFrameChange(e.target.value)}
923
+ />
924
+ </section>
925
+
926
+ {/* Calibration overlay (fixed fullscreen, must be outside overflow:hidden containers) */}
927
+ <CalibrationOverlay calibration={calibration} videoManager={videoManager} />
928
  </main>
929
  );
930
  }
src/utils/VideoManagerLocal.js CHANGED
@@ -40,6 +40,17 @@ export class VideoManagerLocal {
40
  this.lastNotificationTime = null;
41
  this.notificationCooldown = 60000;
42
 
43
  // Performance metrics
44
  this.stats = {
45
  framesSent: 0,
@@ -74,8 +85,8 @@ export class VideoManagerLocal {
74
 
75
  // Create a smaller capture canvas for faster encoding and transfer.
76
  this.canvas = document.createElement('canvas');
77
- this.canvas.width = 320;
78
- this.canvas.height = 240;
79
 
80
  console.log('Local camera initialized');
81
  return true;
@@ -247,7 +258,7 @@ export class VideoManagerLocal {
247
  this.ws.send(blob);
248
  this.stats.framesSent++;
249
  }
250
- }, 'image/jpeg', 0.5);
251
  } catch (error) {
252
  this._sendingBlob = false;
253
  console.error('Capture error:', error);
@@ -312,6 +323,19 @@ export class VideoManagerLocal {
312
  ctx.textAlign = 'left';
313
  }
314
  }
315
  // Performance stats
316
  ctx.fillStyle = 'rgba(0,0,0,0.5)';
317
  ctx.fillRect(0, h - 25, w, 25);
@@ -380,6 +404,9 @@ export class VideoManagerLocal {
380
  mar: data.mar,
381
  sf: data.sf,
382
  se: data.se,
 
 
 
383
  };
384
  this.drawDetectionResult(detectionData);
385
  break;
@@ -397,6 +424,51 @@ export class VideoManagerLocal {
397
  this.sessionStartTime = null;
398
  break;
399
 
400
  case 'error':
401
  console.error('Server error:', data.message);
402
  break;
@@ -406,6 +478,28 @@ export class VideoManagerLocal {
406
  }
407
  }
408
 
409
  // Face mesh landmark index groups (matches live_demo.py)
410
  static FACE_OVAL = [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109,10];
411
  static LEFT_EYE = [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246];
 
40
  this.lastNotificationTime = null;
41
  this.notificationCooldown = 60000;
42
 
43
+ // Calibration state
44
+ this.calibration = {
45
+ active: false,
46
+ collecting: false,
47
+ target: null,
48
+ index: 0,
49
+ numPoints: 0,
50
+ done: false,
51
+ success: false,
52
+ };
53
+
54
  // Performance metrics
55
  this.stats = {
56
  framesSent: 0,
 
85
 
86
  // Create a smaller capture canvas for faster encoding and transfer.
87
  this.canvas = document.createElement('canvas');
88
+ this.canvas.width = 640;
89
+ this.canvas.height = 480;
90
 
91
  console.log('Local camera initialized');
92
  return true;
 
258
  this.ws.send(blob);
259
  this.stats.framesSent++;
260
  }
261
+ }, 'image/jpeg', 0.75);
262
  } catch (error) {
263
  this._sendingBlob = false;
264
  console.error('Capture error:', error);
 
323
  ctx.textAlign = 'left';
324
  }
325
  }
326
+ // Gaze pointer (L2CS + calibration)
327
+ if (data && data.gaze_x !== undefined && data.gaze_y !== undefined) {
328
+ const gx = data.gaze_x * w;
329
+ const gy = data.gaze_y * h;
330
+ ctx.beginPath();
331
+ ctx.arc(gx, gy, 8, 0, 2 * Math.PI);
332
+ ctx.fillStyle = data.on_screen ? 'rgba(0, 200, 255, 0.7)' : 'rgba(255, 80, 80, 0.5)';
333
+ ctx.fill();
334
+ ctx.strokeStyle = '#FFFFFF';
335
+ ctx.lineWidth = 2;
336
+ ctx.stroke();
337
+ }
338
+
339
  // Performance stats
340
  ctx.fillStyle = 'rgba(0,0,0,0.5)';
341
  ctx.fillRect(0, h - 25, w, 25);
 
404
  mar: data.mar,
405
  sf: data.sf,
406
  se: data.se,
407
+ gaze_x: data.gaze_x,
408
+ gaze_y: data.gaze_y,
409
+ on_screen: data.on_screen,
410
  };
411
  this.drawDetectionResult(detectionData);
412
  break;
 
424
  this.sessionStartTime = null;
425
  break;
426
 
427
+ case 'calibration_started':
428
+ this.calibration = {
429
+ active: true,
430
+ collecting: true,
431
+ target: data.target,
432
+ index: data.index,
433
+ numPoints: data.num_points,
434
+ done: false,
435
+ success: false,
436
+ };
437
+ if (this.callbacks.onCalibrationUpdate) {
438
+ this.callbacks.onCalibrationUpdate({ ...this.calibration });
439
+ }
440
+ break;
441
+
442
+ case 'calibration_point':
443
+ this.calibration.target = data.target;
444
+ this.calibration.index = data.index;
445
+ if (this.callbacks.onCalibrationUpdate) {
446
+ this.callbacks.onCalibrationUpdate({ ...this.calibration });
447
+ }
448
+ break;
449
+
450
+ case 'calibration_done':
451
+ this.calibration.collecting = false;
452
+ this.calibration.done = true;
453
+ this.calibration.success = data.success;
454
+ if (this.callbacks.onCalibrationUpdate) {
455
+ this.callbacks.onCalibrationUpdate({ ...this.calibration });
456
+ }
457
+ setTimeout(() => {
458
+ this.calibration.active = false;
459
+ if (this.callbacks.onCalibrationUpdate) {
460
+ this.callbacks.onCalibrationUpdate({ ...this.calibration });
461
+ }
462
+ }, 2000);
463
+ break;
464
+
465
+ case 'calibration_cancelled':
466
+ this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
467
+ if (this.callbacks.onCalibrationUpdate) {
468
+ this.callbacks.onCalibrationUpdate({ ...this.calibration });
469
+ }
470
+ break;
471
+
472
  case 'error':
473
  console.error('Server error:', data.message);
474
  break;
 
478
  }
479
  }
480
 
481
+ startCalibration() {
482
+ if (this.ws && this.ws.readyState === WebSocket.OPEN) {
483
+ this.ws.send(JSON.stringify({ type: 'calibration_start' }));
484
+ }
485
+ }
486
+
487
+ nextCalibrationPoint() {
488
+ if (this.ws && this.ws.readyState === WebSocket.OPEN) {
489
+ this.ws.send(JSON.stringify({ type: 'calibration_next' }));
490
+ }
491
+ }
492
+
493
+ cancelCalibration() {
494
+ if (this.ws && this.ws.readyState === WebSocket.OPEN) {
495
+ this.ws.send(JSON.stringify({ type: 'calibration_cancel' }));
496
+ }
497
+ this.calibration = { active: false, collecting: false, target: null, index: 0, numPoints: 0, done: false, success: false };
498
+ if (this.callbacks.onCalibrationUpdate) {
499
+ this.callbacks.onCalibrationUpdate({ ...this.calibration });
500
+ }
501
+ }
502
+
503
  // Face mesh landmark index groups (matches live_demo.py)
504
  static FACE_OVAL = [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109,10];
505
  static LEFT_EYE = [33,7,163,144,145,153,154,155,133,173,157,158,159,160,161,246];
ui/pipeline.py CHANGED
@@ -5,6 +5,7 @@ import glob
5
  import json
6
  import math
7
  import os
 
8
  import sys
9
 
10
  import numpy as np
@@ -54,8 +55,12 @@ def _clip_features(vec):
54
 
55
 
56
  class _OutputSmoother:
57
- def __init__(self, alpha: float = 0.3, grace_frames: int = 15):
58
- self._alpha = alpha
 
 
 
 
59
  self._grace = grace_frames
60
  self._score = 0.5
61
  self._no_face = 0
@@ -64,14 +69,15 @@ class _OutputSmoother:
64
  self._score = 0.5
65
  self._no_face = 0
66
 
67
- def update(self, raw_score: float, face_detected: bool) -> float:
68
  if face_detected:
69
  self._no_face = 0
70
- self._score += self._alpha * (raw_score - self._score)
 
71
  else:
72
  self._no_face += 1
73
  if self._no_face > self._grace:
74
- self._score *= 0.85
75
  return self._score
76
 
77
 
@@ -645,3 +651,141 @@ class XGBoostPipeline:
645
 
646
  def __exit__(self, *args):
647
  self.close()
 
5
  import json
6
  import math
7
  import os
8
+ import pathlib
9
  import sys
10
 
11
  import numpy as np
 
55
 
56
 
57
  class _OutputSmoother:
58
+ # Asymmetric EMA: rises fast (recognise focus), falls slower (avoid flicker).
59
+ # Grace period holds score steady for a few frames when face is lost.
60
+
61
+ def __init__(self, alpha_up=0.55, alpha_down=0.45, grace_frames=10):
62
+ self._alpha_up = alpha_up
63
+ self._alpha_down = alpha_down
64
  self._grace = grace_frames
65
  self._score = 0.5
66
  self._no_face = 0
 
69
  self._score = 0.5
70
  self._no_face = 0
71
 
72
+ def update(self, raw_score, face_detected):
73
  if face_detected:
74
  self._no_face = 0
75
+ alpha = self._alpha_up if raw_score > self._score else self._alpha_down
76
+ self._score += alpha * (raw_score - self._score)
77
  else:
78
  self._no_face += 1
79
  if self._no_face > self._grace:
80
+ self._score *= 0.80
81
  return self._score
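A quick standalone check of the asymmetric behaviour described above (same constants; the class is copied here so the demo runs without the pipeline imports):

```python
class OutputSmoother:
    # mirrors _OutputSmoother above: fast rise, slower fall, grace on lost face
    def __init__(self, alpha_up=0.55, alpha_down=0.45, grace_frames=10):
        self._alpha_up, self._alpha_down = alpha_up, alpha_down
        self._grace = grace_frames
        self._score, self._no_face = 0.5, 0

    def update(self, raw_score, face_detected):
        if face_detected:
            self._no_face = 0
            alpha = self._alpha_up if raw_score > self._score else self._alpha_down
            self._score += alpha * (raw_score - self._score)
        else:
            self._no_face += 1
            if self._no_face > self._grace:
                self._score *= 0.80
        return self._score

s = OutputSmoother()
rise = s.update(1.0, True)   # 0.5 + 0.55 * 0.5 = 0.775 (fast rise)
fall = s.update(0.0, True)   # 0.775 - 0.45 * 0.775 = 0.42625 (slower fall)
held = [s.update(0.0, False) for _ in range(10)][-1]  # within grace: unchanged
```

The grace period means ten consecutive frames without a face leave the score untouched; only from the eleventh does the 0.80 decay kick in.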
82
 
83
 
 
651
 
652
  def __exit__(self, *args):
653
  self.close()
654
+
655
+
656
+ def _resolve_l2cs_weights():
657
+ for p in [
658
+ os.path.join(_PROJECT_ROOT, "models", "L2CS-Net", "models", "L2CSNet_gaze360.pkl"),
659
+ os.path.join(_PROJECT_ROOT, "models", "L2CSNet_gaze360.pkl"),
660
+ os.path.join(_PROJECT_ROOT, "checkpoints", "L2CSNet_gaze360.pkl"),
661
+ ]:
662
+ if os.path.isfile(p):
663
+ return p
664
+ return None
665
+
666
+
667
+ def is_l2cs_weights_available():
668
+ return _resolve_l2cs_weights() is not None
669
+
670
+
671
+ class L2CSPipeline:
672
+ # Uses in-tree l2cs.Pipeline (RetinaFace + ResNet50) for gaze estimation
673
+ # and MediaPipe for head pose, EAR, MAR, and roll de-rotation.
674
+
675
+ YAW_THRESHOLD = 22.0
676
+ PITCH_THRESHOLD = 20.0
677
+
678
+ def __init__(self, weights_path=None, arch="ResNet50", device="cpu",
679
+ threshold=0.52, detector=None):
680
+ resolved = weights_path or _resolve_l2cs_weights()
681
+ if resolved is None or not os.path.isfile(resolved):
682
+ raise FileNotFoundError(
683
+ "L2CS weights not found. Place L2CSNet_gaze360.pkl in "
684
+ "models/L2CS-Net/models/ or checkpoints/"
685
+ )
686
+
687
+ # add in-tree L2CS-Net to import path
688
+ l2cs_root = os.path.join(_PROJECT_ROOT, "models", "L2CS-Net")
689
+ if l2cs_root not in sys.path:
690
+ sys.path.insert(0, l2cs_root)
691
+ from l2cs import Pipeline as _L2CSPipeline
692
+
693
+ import torch
694
+ # bypass upstream select_device bug by constructing torch.device directly
695
+ self._pipeline = _L2CSPipeline(
696
+ weights=pathlib.Path(resolved), arch=arch, device=torch.device(device),
697
+ )
698
+
699
+ self._detector = detector or FaceMeshDetector()
700
+ self._owns_detector = detector is None
701
+ self._head_pose = HeadPoseEstimator()
702
+ self.head_pose = self._head_pose
703
+ self._eye_scorer = EyeBehaviourScorer()
704
+ self._threshold = threshold
705
+ self._smoother = _OutputSmoother()
706
+
707
+ print(
708
+ f"[L2CS] Loaded {resolved} | arch={arch} device={device} "
709
+ f"yaw_thresh={self.YAW_THRESHOLD} pitch_thresh={self.PITCH_THRESHOLD} "
710
+ f"threshold={threshold}"
711
+ )
712
+
713
+ @staticmethod
714
+ def _derotate_gaze(pitch_rad, yaw_rad, roll_deg):
715
+ # remove head roll so tilted-but-looking-at-screen reads as (0,0)
716
+ roll_rad = -math.radians(roll_deg)
717
+ cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
718
+ return (yaw_rad * sin_r + pitch_rad * cos_r,
719
+ yaw_rad * cos_r - pitch_rad * sin_r)
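The de-rotation above is a plain 2D rotation of the (yaw, pitch) vector by the negated head roll; a standalone sketch with the identical formula (no model or camera needed):

```python
import math

def derotate_gaze(pitch_rad, yaw_rad, roll_deg):
    # identical rotation to L2CSPipeline._derotate_gaze
    roll_rad = -math.radians(roll_deg)
    cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)
    return (yaw_rad * sin_r + pitch_rad * cos_r,
            yaw_rad * cos_r - pitch_rad * sin_r)

# no roll: angles pass through unchanged
p0, y0 = derotate_gaze(0.1, 0.2, 0.0)

# a pure 90-degree head roll maps pitch onto yaw (axes swap)
p90, y90 = derotate_gaze(0.1, 0.0, 90.0)
```

This is why a user with a tilted head but eyes on the screen centre still comes out near (0, 0) after de-rotation.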
720
+
721
+    def process_frame(self, bgr_frame):
+        landmarks = self._detector.process(bgr_frame)
+        h, w = bgr_frame.shape[:2]
+
+        out = {
+            "landmarks": landmarks, "is_focused": False, "raw_score": 0.0,
+            "s_face": 0.0, "s_eye": 0.0, "gaze_pitch": None, "gaze_yaw": None,
+            "yaw": None, "pitch": None, "roll": None, "mar": None, "is_yawning": False,
+        }
+
+        # MediaPipe: head pose, eye/mouth scores
+        roll_deg = 0.0
+        if landmarks is not None:
+            angles = self._head_pose.estimate(landmarks, w, h)
+            if angles is not None:
+                out["yaw"], out["pitch"], out["roll"] = angles
+                roll_deg = angles[2]
+            out["s_face"] = self._head_pose.score(landmarks, w, h)
+            out["s_eye"] = self._eye_scorer.score(landmarks)
+            out["mar"] = compute_mar(landmarks)
+            out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
+
743
+        # L2CS gaze (uses its own RetinaFace detector internally)
+        results = self._pipeline.step(bgr_frame)
+
+        if results is None or results.pitch.shape[0] == 0:
+            smoothed = self._smoother.update(0.0, landmarks is not None)
+            out["raw_score"] = smoothed
+            out["is_focused"] = smoothed >= self._threshold
+            return out
+
752
+        pitch_rad = float(results.pitch[0])
+        yaw_rad = float(results.yaw[0])
+
+        pitch_rad, yaw_rad = self._derotate_gaze(pitch_rad, yaw_rad, roll_deg)
+        out["gaze_pitch"] = pitch_rad
+        out["gaze_yaw"] = yaw_rad
+
+        yaw_deg = abs(math.degrees(yaw_rad))
+        pitch_deg = abs(math.degrees(pitch_rad))
+
+        # fall back to L2CS angles if MediaPipe didn't produce head pose
+        # (explicit None checks: `or` would discard a legitimate 0.0-degree reading)
+        if out["yaw"] is None:
+            out["yaw"] = math.degrees(yaw_rad)
+        if out["pitch"] is None:
+            out["pitch"] = math.degrees(pitch_rad)
+
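A subtlety in the fallback step: a head-pose yaw of exactly 0.0 degrees is a valid, perfectly centred reading, but it is falsy, so an `or`-based fallback would silently replace it with the L2CS angle; only an explicit `None` check keeps it. A minimal illustration (`fallback` is a hypothetical helper, not in the module):

```python
def fallback(primary, backup):
    # use the backup only when the primary estimate is truly absent
    return backup if primary is None else primary

# `or` conflates "missing" with "exactly zero"
via_or = 0.0 or 42.0              # the valid 0.0 reading is lost
via_none = fallback(0.0, 42.0)    # the 0.0 reading is kept
missing = fallback(None, 42.0)    # backup used only when truly absent
```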
766
+        # cosine scoring: 1.0 at centre, 0.0 at threshold
+        yaw_t = min(yaw_deg / self.YAW_THRESHOLD, 1.0)
+        pitch_t = min(pitch_deg / self.PITCH_THRESHOLD, 1.0)
+        yaw_score = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
+        pitch_score = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
+        gaze_score = 0.55 * yaw_score + 0.45 * pitch_score
+
+        if out["is_yawning"]:
+            gaze_score = 0.0
+
+        out["raw_score"] = self._smoother.update(float(gaze_score), True)
+        out["is_focused"] = out["raw_score"] >= self._threshold
+        return out
+
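The raised-cosine mapping above degrades smoothly from 1.0 for a centred gaze to 0.0 at or beyond the angular threshold, with yaw weighted slightly above pitch. A standalone sketch with hypothetical thresholds (the real `YAW_THRESHOLD`/`PITCH_THRESHOLD` class constants are defined elsewhere in the file):

```python
import math

YAW_THRESHOLD = 30.0    # hypothetical values; the class constants live elsewhere
PITCH_THRESHOLD = 25.0

def gaze_score(yaw_deg, pitch_deg):
    # raised cosine: 1.0 at centre, 0.0 at/beyond the threshold, smooth in between
    yaw_t = min(abs(yaw_deg) / YAW_THRESHOLD, 1.0)
    pitch_t = min(abs(pitch_deg) / PITCH_THRESHOLD, 1.0)
    yaw_score = 0.5 * (1.0 + math.cos(math.pi * yaw_t))
    pitch_score = 0.5 * (1.0 + math.cos(math.pi * pitch_t))
    # yaw weighted slightly higher than pitch
    return 0.55 * yaw_score + 0.45 * pitch_score

centre = gaze_score(0.0, 0.0)      # looking straight at the screen
far_off = gaze_score(45.0, 40.0)   # beyond both thresholds
halfway = gaze_score(15.0, 12.5)   # halfway to both thresholds
```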
780
+    def reset_session(self):
+        self._smoother.reset()
+
+    def close(self):
+        if self._owns_detector:
+            self._detector.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()