Abdelrahman Almatrooshi committed on
Commit
87209fb
·
1 Parent(s): 5627c54

Replace merged files with working IntegrationTest versions

Browse files

The merge created broken UI with duplicate sections and missing
closing braces. Replaced all L2CS-related files with the tested
versions from the original IntegrationTest Space deployment.

Files changed (4) hide show
  1. main.py +20 -177
  2. src/components/FocusPageLocal.jsx +157 -456
  3. src/utils/VideoManagerLocal.js +60 -310
  4. ui/pipeline.py +150 -157
main.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.responses import FileResponse
@@ -16,7 +14,6 @@ import math
16
  import os
17
  from pathlib import Path
18
  from typing import Callable
19
- from contextlib import asynccontextmanager
20
  import asyncio
21
  import concurrent.futures
22
  import threading
@@ -136,38 +133,6 @@ def _draw_hud(frame, result, model_name):
136
  if result.get("is_yawning"):
137
  cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
138
 
139
-
140
- def _draw_gaze_arrows(frame, result, lm, w, h):
141
- """Draw eyes, irises, and iris-based gaze lines matching live_demo.py."""
142
- if lm is None:
143
- return
144
- # Eye contours
145
- left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
146
- cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
147
- right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
148
- cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
149
- # EAR key points (yellow dots)
150
- for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
151
- for idx in indices:
152
- cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
153
- # Irises + gaze direction lines
154
- for iris_idx, eye_inner, eye_outer in [
155
- (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
156
- (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
157
- ]:
158
- iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
159
- center = iris_pts[0]
160
- if len(iris_pts) >= 5:
161
- radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
162
- radius = max(int(np.mean(radii)), 2)
163
- cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
164
- cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
165
- eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
166
- eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
167
- dx, dy = center[0] - eye_cx, center[1] - eye_cy
168
- cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
169
-
170
-
171
  # Landmark indices used for face mesh drawing on client (union of all groups).
172
  # Sending only these instead of all 478 saves ~60% of the landmarks payload.
173
  _MESH_INDICES = sorted(set(
@@ -186,57 +151,8 @@ _MESH_INDICES = sorted(set(
186
  # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
187
  _MESH_INDEX_SET = set(_MESH_INDICES)
188
 
189
- @asynccontextmanager
190
- async def lifespan(app):
191
- global _cached_model_name
192
- print(" Starting Focus Guard API...")
193
- await init_database()
194
- async with aiosqlite.connect(db_path) as db:
195
- cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
196
- row = await cursor.fetchone()
197
- if row:
198
- _cached_model_name = row[0]
199
- print("[OK] Database initialized")
200
- try:
201
- pipelines["geometric"] = FaceMeshPipeline()
202
- print("[OK] FaceMeshPipeline (geometric) loaded")
203
- except Exception as e:
204
- print(f"[WARN] FaceMeshPipeline unavailable: {e}")
205
- try:
206
- pipelines["mlp"] = MLPPipeline()
207
- print("[OK] MLPPipeline loaded")
208
- except Exception as e:
209
- print(f"[ERR] Failed to load MLPPipeline: {e}")
210
- try:
211
- pipelines["hybrid"] = HybridFocusPipeline()
212
- print("[OK] HybridFocusPipeline loaded")
213
- except Exception as e:
214
- print(f"[WARN] HybridFocusPipeline unavailable: {e}")
215
- try:
216
- pipelines["xgboost"] = XGBoostPipeline()
217
- print("[OK] XGBoostPipeline loaded")
218
- except Exception as e:
219
- print(f"[ERR] Failed to load XGBoostPipeline: {e}")
220
- if is_l2cs_weights_available():
221
- print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
222
- else:
223
- print("[WARN] L2CS weights not found — l2cs model unavailable")
224
- resolved_model = _first_available_pipeline_name(_cached_model_name)
225
- if resolved_model is not None and resolved_model != _cached_model_name:
226
- _cached_model_name = resolved_model
227
- async with aiosqlite.connect(db_path) as db:
228
- await db.execute(
229
- "UPDATE user_settings SET model_name = ? WHERE id = 1",
230
- (_cached_model_name,),
231
- )
232
- await db.commit()
233
- if resolved_model is not None:
234
- print(f"[OK] Active model set to {resolved_model}")
235
- yield
236
- _inference_executor.shutdown(wait=False)
237
- print(" Shutting down Focus Guard API...")
238
-
239
- app = FastAPI(title="Focus Guard API", lifespan=lifespan)
240
 
241
  # Add CORS middleware
242
  app.add_middleware(
@@ -250,8 +166,8 @@ app.add_middleware(
250
  # Global variables
251
  db_path = "focus_guard.db"
252
  pcs = set()
253
- _cached_model_name = "mlp"
254
- _l2cs_boost_enabled = False
255
 
256
  async def _wait_for_ice_gathering(pc: RTCPeerConnection):
257
  if pc.iceGatheringState == "complete":
@@ -357,37 +273,23 @@ class VideoTransformTrack(VideoStreamTrack):
357
 
358
  if do_infer:
359
  self.last_inference_time = now
360
-
361
  model_name = _cached_model_name
362
  if model_name == "l2cs" and pipelines.get("l2cs") is None:
363
  _ensure_l2cs()
364
  if model_name not in pipelines or pipelines.get(model_name) is None:
365
- model_name = "mlp"
366
  active_pipeline = pipelines.get(model_name)
367
 
368
  if active_pipeline is not None:
369
  loop = asyncio.get_event_loop()
370
- use_boost = (
371
- _l2cs_boost_enabled
372
- and model_name != "l2cs"
373
- and pipelines.get("l2cs") is not None
 
 
374
  )
375
- if use_boost:
376
- out = await loop.run_in_executor(
377
- _inference_executor,
378
- _process_frame_with_l2cs_boost,
379
- active_pipeline,
380
- img,
381
- model_name,
382
- )
383
- else:
384
- out = await loop.run_in_executor(
385
- _inference_executor,
386
- _process_frame_safe,
387
- active_pipeline,
388
- img,
389
- model_name,
390
- )
391
  is_focused = out["is_focused"]
392
  confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
393
  metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
@@ -395,10 +297,8 @@ class VideoTransformTrack(VideoStreamTrack):
395
  # Draw face mesh + HUD on the video frame
396
  h_f, w_f = img.shape[:2]
397
  lm = out.get("landmarks")
398
- if lm is not None and model_name != "l2cs":
399
  _draw_face_mesh(img, lm, w_f, h_f)
400
- if model_name == "l2cs" and lm is not None:
401
- _draw_gaze_arrows(img, out, lm, w_f, h_f)
402
  _draw_hud(img, out, model_name)
403
  else:
404
  is_focused = False
@@ -413,13 +313,7 @@ class VideoTransformTrack(VideoStreamTrack):
413
  channel = self.get_channel()
414
  if channel and channel.readyState == "open":
415
  try:
416
- channel.send(json.dumps({
417
- "type": "detection",
418
- "focused": is_focused,
419
- "confidence": round(confidence, 3),
420
- "detections": [],
421
- "model": model_name,
422
- }))
423
  except Exception:
424
  pass
425
 
@@ -611,15 +505,6 @@ def _process_frame_safe(pipeline, frame, model_name):
611
  return pipeline.process_frame(frame)
612
 
613
 
614
- def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
615
- if preferred and preferred in pipelines and pipelines.get(preferred) is not None:
616
- return preferred
617
- for name, pipeline in pipelines.items():
618
- if pipeline is not None:
619
- return name
620
- return None
621
-
622
-
623
  _BOOST_BASE_W = 0.35
624
  _BOOST_L2CS_W = 0.65
625
  _BOOST_VETO = 0.38 # L2CS below this -> forced not-focused
@@ -995,10 +880,8 @@ async def websocket_endpoint(websocket: WebSocket):
995
  "type": "detection",
996
  "focused": is_focused,
997
  "confidence": round(confidence, 3),
998
- "detections": [],
999
  "model": model_name,
1000
  "fc": frame_count,
1001
- "frame_count": frame_count,
1002
  }
1003
  if out is not None:
1004
  if out.get("yaw") is not None:
@@ -1029,14 +912,7 @@ async def websocket_endpoint(websocket: WebSocket):
1029
 
1030
  if landmarks_list is not None:
1031
  resp["lm"] = landmarks_list
1032
- try:
1033
- await websocket.send_json(resp)
1034
- except Exception as send_err:
1035
- # Connection can close between loop ticks; end cleanly.
1036
- if "Unexpected ASGI message 'websocket.send'" in str(send_err):
1037
- running = False
1038
- return
1039
- raise
1040
  frame_count += 1
1041
  except Exception as e:
1042
  print(f"[WS] process error: {e}")
@@ -1172,7 +1048,6 @@ async def get_settings():
1172
 
1173
  @app.put("/api/settings")
1174
  async def update_settings(settings: SettingsUpdate):
1175
- global _cached_model_name, _l2cs_boost_enabled
1176
  async with aiosqlite.connect(db_path) as db:
1177
  cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
1178
  exists = await cursor.fetchone()
@@ -1222,22 +1097,6 @@ async def update_settings(settings: SettingsUpdate):
1222
  await db.commit()
1223
  return {"status": "success", "updated": len(updates) > 0}
1224
 
1225
- @app.get("/api/stats/system")
1226
- async def get_system_stats():
1227
- """Return server CPU and memory usage for UI display."""
1228
- try:
1229
- import psutil
1230
- cpu = psutil.cpu_percent(interval=0.1)
1231
- mem = psutil.virtual_memory()
1232
- return {
1233
- "cpu_percent": round(cpu, 1),
1234
- "memory_percent": round(mem.percent, 1),
1235
- "memory_used_mb": round(mem.used / (1024 * 1024), 0),
1236
- "memory_total_mb": round(mem.total / (1024 * 1024), 0),
1237
- }
1238
- except ImportError:
1239
- return {"cpu_percent": None, "memory_percent": None, "memory_used_mb": None, "memory_total_mb": None}
1240
-
1241
  @app.get("/api/stats/summary")
1242
  async def get_stats_summary():
1243
  async with aiosqlite.connect(db_path) as db:
@@ -1265,14 +1124,6 @@ async def get_stats_summary():
1265
  'streak_days': streak_days
1266
  }
1267
 
1268
- @app.get("/api/l2cs/status")
1269
- async def get_l2cs_status():
1270
- return {
1271
- "weights_available": is_l2cs_weights_available(),
1272
- "loaded": pipelines.get("l2cs") is not None,
1273
- "error": _l2cs_error,
1274
- }
1275
-
1276
  @app.get("/api/models")
1277
  async def get_available_models():
1278
  """Return model names, statuses, and which is currently active."""
@@ -1337,13 +1188,9 @@ async def health_check():
1337
 
1338
  # ================ STATIC FILES (SPA SUPPORT) ================
1339
 
1340
- # Resolve frontend dir from this file so it works regardless of cwd.
1341
- # Prefer a built `dist/` app when present, otherwise fall back to `static/`.
1342
- _BASE_DIR = Path(__file__).resolve().parent
1343
- _DIST_DIR = _BASE_DIR / "dist"
1344
- _STATIC_DIR = _BASE_DIR / "static"
1345
- _FRONTEND_DIR = _DIST_DIR if (_DIST_DIR / "index.html").is_file() else _STATIC_DIR
1346
- _ASSETS_DIR = _FRONTEND_DIR / "assets"
1347
 
1348
  # 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
1349
  if _ASSETS_DIR.is_dir():
@@ -1358,11 +1205,7 @@ async def serve_react_app(full_path: str, request: Request):
1358
  if full_path.startswith("assets") or full_path.startswith("assets/"):
1359
  raise HTTPException(status_code=404, detail="Not Found")
1360
 
1361
- file_path = _FRONTEND_DIR / full_path
1362
- if full_path and file_path.is_file():
1363
- return FileResponse(str(file_path))
1364
-
1365
- index_path = _FRONTEND_DIR / "index.html"
1366
  if index_path.is_file():
1367
  return FileResponse(str(index_path))
1368
- return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}
 
 
 
1
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
 
14
  import os
15
  from pathlib import Path
16
  from typing import Callable
 
17
  import asyncio
18
  import concurrent.futures
19
  import threading
 
133
  if result.get("is_yawning"):
134
  cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  # Landmark indices used for face mesh drawing on client (union of all groups).
137
  # Sending only these instead of all 478 saves ~60% of the landmarks payload.
138
  _MESH_INDICES = sorted(set(
 
151
  # Build a lookup: original_index -> position in sparse array, so client can reconstruct.
152
  _MESH_INDEX_SET = set(_MESH_INDICES)
153
 
154
+ # Initialize FastAPI app
155
+ app = FastAPI(title="Focus Guard API")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # Add CORS middleware
158
  app.add_middleware(
 
166
  # Global variables
167
  db_path = "focus_guard.db"
168
  pcs = set()
169
+ _cached_model_name = "mlp" # in-memory cache, updated via /api/settings
170
+ _l2cs_boost_enabled = False # when True, L2CS runs alongside the base model
171
 
172
  async def _wait_for_ice_gathering(pc: RTCPeerConnection):
173
  if pc.iceGatheringState == "complete":
 
273
 
274
  if do_infer:
275
  self.last_inference_time = now
276
+
277
  model_name = _cached_model_name
278
  if model_name == "l2cs" and pipelines.get("l2cs") is None:
279
  _ensure_l2cs()
280
  if model_name not in pipelines or pipelines.get(model_name) is None:
281
+ model_name = 'mlp'
282
  active_pipeline = pipelines.get(model_name)
283
 
284
  if active_pipeline is not None:
285
  loop = asyncio.get_event_loop()
286
+ out = await loop.run_in_executor(
287
+ _inference_executor,
288
+ _process_frame_safe,
289
+ active_pipeline,
290
+ img,
291
+ model_name,
292
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  is_focused = out["is_focused"]
294
  confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
295
  metadata = {"s_face": out.get("s_face", 0.0), "s_eye": out.get("s_eye", 0.0), "mar": out.get("mar", 0.0), "model": model_name}
 
297
  # Draw face mesh + HUD on the video frame
298
  h_f, w_f = img.shape[:2]
299
  lm = out.get("landmarks")
300
+ if lm is not None:
301
  _draw_face_mesh(img, lm, w_f, h_f)
 
 
302
  _draw_hud(img, out, model_name)
303
  else:
304
  is_focused = False
 
313
  channel = self.get_channel()
314
  if channel and channel.readyState == "open":
315
  try:
316
+ channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
 
 
 
 
 
 
317
  except Exception:
318
  pass
319
 
 
505
  return pipeline.process_frame(frame)
506
 
507
 
 
 
 
 
 
 
 
 
 
508
  _BOOST_BASE_W = 0.35
509
  _BOOST_L2CS_W = 0.65
510
  _BOOST_VETO = 0.38 # L2CS below this -> forced not-focused
 
880
  "type": "detection",
881
  "focused": is_focused,
882
  "confidence": round(confidence, 3),
 
883
  "model": model_name,
884
  "fc": frame_count,
 
885
  }
886
  if out is not None:
887
  if out.get("yaw") is not None:
 
912
 
913
  if landmarks_list is not None:
914
  resp["lm"] = landmarks_list
915
+ await websocket.send_json(resp)
 
 
 
 
 
 
 
916
  frame_count += 1
917
  except Exception as e:
918
  print(f"[WS] process error: {e}")
 
1048
 
1049
  @app.put("/api/settings")
1050
  async def update_settings(settings: SettingsUpdate):
 
1051
  async with aiosqlite.connect(db_path) as db:
1052
  cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
1053
  exists = await cursor.fetchone()
 
1097
  await db.commit()
1098
  return {"status": "success", "updated": len(updates) > 0}
1099
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1100
  @app.get("/api/stats/summary")
1101
  async def get_stats_summary():
1102
  async with aiosqlite.connect(db_path) as db:
 
1124
  'streak_days': streak_days
1125
  }
1126
 
 
 
 
 
 
 
 
 
1127
  @app.get("/api/models")
1128
  async def get_available_models():
1129
  """Return model names, statuses, and which is currently active."""
 
1188
 
1189
  # ================ STATIC FILES (SPA SUPPORT) ================
1190
 
1191
+ # Resolve static dir from this file so it works regardless of cwd
1192
+ _STATIC_DIR = Path(__file__).resolve().parent / "static"
1193
+ _ASSETS_DIR = _STATIC_DIR / "assets"
 
 
 
 
1194
 
1195
  # 1. Mount the assets folder (JS/CSS) first so /assets/* is never caught by catch-all
1196
  if _ASSETS_DIR.is_dir():
 
1205
  if full_path.startswith("assets") or full_path.startswith("assets/"):
1206
  raise HTTPException(status_code=404, detail="Not Found")
1207
 
1208
+ index_path = _STATIC_DIR / "index.html"
 
 
 
 
1209
  if index_path.is_file():
1210
  return FileResponse(str(index_path))
1211
+ return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
src/components/FocusPageLocal.jsx CHANGED
@@ -1,138 +1,42 @@
1
  import React, { useState, useEffect, useRef } from 'react';
2
  import CalibrationOverlay from './CalibrationOverlay';
3
 
4
- const FLOW_STEPS = {
5
- intro: 'intro',
6
- permission: 'permission',
7
- ready: 'ready'
8
- };
9
-
10
- const FOCUS_STATES = {
11
- pending: 'pending',
12
- focused: 'focused',
13
- notFocused: 'not-focused'
14
- };
15
-
16
- function HelloIcon() {
17
- return (
18
- <svg width="96" height="96" viewBox="0 0 96 96" aria-hidden="true">
19
- <circle cx="48" cy="48" r="40" fill="#007BFF" />
20
- <path d="M30 38c0-4 2.7-7 6-7s6 3 6 7" fill="none" stroke="#fff" strokeWidth="6" strokeLinecap="round" />
21
- <path d="M54 38c0-4 2.7-7 6-7s6 3 6 7" fill="none" stroke="#fff" strokeWidth="6" strokeLinecap="round" />
22
- <path d="M30 52c3 11 10 17 18 17s15-6 18-17" fill="none" stroke="#fff" strokeWidth="6" strokeLinecap="round" />
23
- </svg>
24
- );
25
- }
26
-
27
- function CameraIcon() {
28
- return (
29
- <svg width="110" height="110" viewBox="0 0 110 110" aria-hidden="true">
30
- <rect x="30" y="36" width="50" height="34" rx="5" fill="none" stroke="#007BFF" strokeWidth="6" />
31
- <path d="M24 72h62c0 9-7 16-16 16H40c-9 0-16-7-16-16Z" fill="none" stroke="#007BFF" strokeWidth="6" />
32
- <path d="M55 28v8" stroke="#007BFF" strokeWidth="6" strokeLinecap="round" />
33
- <circle cx="55" cy="36" r="14" fill="none" stroke="#007BFF" strokeWidth="6" />
34
- <circle cx="55" cy="36" r="4" fill="#007BFF" />
35
- <path d="M46 83h18" stroke="#007BFF" strokeWidth="6" strokeLinecap="round" />
36
- </svg>
37
- );
38
- }
39
-
40
- function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActive, role }) {
41
  const [currentFrame, setCurrentFrame] = useState(15);
42
  const [timelineEvents, setTimelineEvents] = useState([]);
43
  const [stats, setStats] = useState(null);
44
- const [systemStats, setSystemStats] = useState(null);
45
  const [availableModels, setAvailableModels] = useState([]);
46
  const [currentModel, setCurrentModel] = useState('mlp');
47
- const [flowStep, setFlowStep] = useState(FLOW_STEPS.intro);
48
- const [cameraReady, setCameraReady] = useState(false);
49
- const [isStarting, setIsStarting] = useState(false);
50
- const [focusState, setFocusState] = useState(FOCUS_STATES.pending);
51
- const [cameraError, setCameraError] = useState('');
52
  const [calibration, setCalibration] = useState(null);
53
  const [l2csBoost, setL2csBoost] = useState(false);
54
  const [l2csBoostAvailable, setL2csBoostAvailable] = useState(false);
55
 
56
  const localVideoRef = useRef(null);
57
  const displayCanvasRef = useRef(null);
58
- const pipVideoRef = useRef(null);
59
  const pipStreamRef = useRef(null);
60
- const previewFrameRef = useRef(null);
61
 
 
62
  const formatDuration = (seconds) => {
63
- if (seconds === 0) return '0s';
64
  const mins = Math.floor(seconds / 60);
65
  const secs = Math.floor(seconds % 60);
66
  return `${mins}m ${secs}s`;
67
  };
68
-
69
- const stopPreviewLoop = () => {
70
- if (previewFrameRef.current) {
71
- cancelAnimationFrame(previewFrameRef.current);
72
- previewFrameRef.current = null;
73
- }
74
- };
75
-
76
- const startPreviewLoop = () => {
77
- stopPreviewLoop();
78
- const renderPreview = () => {
79
- const canvas = displayCanvasRef.current;
80
- const video = localVideoRef.current;
81
-
82
- if (!canvas || !video || !cameraReady || videoManager?.isStreaming) {
83
- previewFrameRef.current = null;
84
- return;
85
- }
86
-
87
- if (video.readyState >= 2) {
88
- const ctx = canvas.getContext('2d');
89
- ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
90
- }
91
-
92
- previewFrameRef.current = requestAnimationFrame(renderPreview);
93
- };
94
-
95
- previewFrameRef.current = requestAnimationFrame(renderPreview);
96
- };
97
-
98
- const getErrorMessage = (err) => {
99
- if (err?.name === 'NotAllowedError') {
100
- return 'Camera permission denied. Please allow camera access.';
101
- }
102
- if (err?.name === 'NotFoundError') {
103
- return 'No camera found. Please connect a camera.';
104
- }
105
- if (err?.name === 'NotReadableError') {
106
- return 'Camera is already in use by another application.';
107
- }
108
- if (err?.target?.url) {
109
- return `WebSocket connection failed: ${err.target.url}. Check that the backend server is running.`;
110
- }
111
- return err?.message || 'Failed to start focus session.';
112
- };
113
 
114
  useEffect(() => {
115
  if (!videoManager) return;
116
 
117
  const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
118
- const originalOnSessionEnd = videoManager.callbacks.onSessionEnd;
119
-
120
  videoManager.callbacks.onStatusUpdate = (isFocused) => {
121
- setTimelineEvents((prev) => {
122
  const newEvents = [...prev, { isFocused, timestamp: Date.now() }];
123
  if (newEvents.length > 60) newEvents.shift();
124
  return newEvents;
125
  });
126
- setFocusState(isFocused ? FOCUS_STATES.focused : FOCUS_STATES.notFocused);
127
  if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
128
  };
129
 
130
- videoManager.callbacks.onSessionEnd = (summary) => {
131
- setFocusState(FOCUS_STATES.pending);
132
- setCameraReady(false);
133
- if (originalOnSessionEnd) originalOnSessionEnd(summary);
134
- };
135
-
136
  videoManager.callbacks.onCalibrationUpdate = (cal) => {
137
  setCalibration(cal && cal.active ? { ...cal } : null);
138
  };
@@ -155,55 +59,14 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
155
  // Fetch available models on mount
156
  useEffect(() => {
157
  fetch('/api/models')
158
- .then((res) => res.json())
159
- .then((data) => {
160
  if (data.available) setAvailableModels(data.available);
161
  if (data.current) setCurrentModel(data.current);
162
  if (data.l2cs_boost !== undefined) setL2csBoost(data.l2cs_boost);
163
  if (data.l2cs_boost_available !== undefined) setL2csBoostAvailable(data.l2cs_boost_available);
164
  })
165
- .catch((err) => console.error('Failed to fetch models:', err));
166
- }, []);
167
-
168
- useEffect(() => {
169
- if (flowStep === FLOW_STEPS.ready && cameraReady && !videoManager?.isStreaming) {
170
- startPreviewLoop();
171
- return;
172
- }
173
- stopPreviewLoop();
174
- }, [cameraReady, flowStep, videoManager?.isStreaming]);
175
-
176
- useEffect(() => {
177
- if (!isActive) {
178
- stopPreviewLoop();
179
- }
180
- }, [isActive]);
181
-
182
- useEffect(() => {
183
- return () => {
184
- stopPreviewLoop();
185
- if (pipVideoRef.current) {
186
- pipVideoRef.current.pause();
187
- pipVideoRef.current.srcObject = null;
188
- }
189
- if (pipStreamRef.current) {
190
- pipStreamRef.current.getTracks().forEach((t) => t.stop());
191
- pipStreamRef.current = null;
192
- }
193
- };
194
- }, []);
195
-
196
- // Poll server CPU/memory for UI
197
- useEffect(() => {
198
- const fetchSystem = () => {
199
- fetch('/api/stats/system')
200
- .then(res => res.json())
201
- .then(data => setSystemStats(data))
202
- .catch(() => setSystemStats(null));
203
- };
204
- fetchSystem();
205
- const interval = setInterval(fetchSystem, 3000);
206
- return () => clearInterval(interval);
207
  }, []);
208
 
209
  const handleModelChange = async (modelName) => {
@@ -224,22 +87,6 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
224
  }
225
  };
226
 
227
- const handleEnableCamera = async () => {
228
- if (!videoManager) return;
229
-
230
- try {
231
- setCameraError('');
232
- await videoManager.initCamera(localVideoRef.current, displayCanvasRef.current);
233
- setCameraReady(true);
234
- setFlowStep(FLOW_STEPS.ready);
235
- setFocusState(FOCUS_STATES.pending);
236
- } catch (err) {
237
- const errorMessage = getErrorMessage(err);
238
- setCameraError(errorMessage);
239
- console.error('Camera init error:', err);
240
- }
241
- };
242
-
243
  const handleBoostToggle = async () => {
244
  const next = !l2csBoost;
245
  try {
@@ -256,33 +103,39 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
256
 
257
  const handleStart = async () => {
258
  try {
259
- setIsStarting(true);
260
- setSessionResult(null);
261
- setTimelineEvents([]);
262
- setFocusState(FOCUS_STATES.pending);
263
- setCameraError('');
264
 
265
- if (!cameraReady) {
266
  await videoManager.initCamera(localVideoRef.current, displayCanvasRef.current);
267
- setCameraReady(true);
268
- setFlowStep(FLOW_STEPS.ready);
269
- }
270
 
271
- await videoManager.startStreaming();
 
 
 
272
  } catch (err) {
273
- const errorMessage = getErrorMessage(err);
274
- setCameraError(errorMessage);
275
- setFocusState(FOCUS_STATES.pending);
276
  console.error('Start error:', err);
277
- alert(`Failed to start: ${errorMessage}\n\nCheck browser console for details.`);
278
- } finally {
279
- setIsStarting(false);
 
 
 
 
 
 
 
 
 
 
280
  }
281
  };
282
 
283
  const handleStop = async () => {
284
  if (videoManager) {
285
- await videoManager.stopStreaming();
286
  }
287
  try {
288
  if (document.pictureInPictureElement === pipVideoRef.current) {
@@ -294,17 +147,14 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
294
  pipVideoRef.current.srcObject = null;
295
  }
296
  if (pipStreamRef.current) {
297
- pipStreamRef.current.getTracks().forEach((t) => t.stop());
298
  pipStreamRef.current = null;
299
  }
300
- stopPreviewLoop();
301
- setFocusState(FOCUS_STATES.pending);
302
- setCameraReady(false);
303
  };
304
 
305
  const handlePiP = async () => {
306
  try {
307
- //
308
  if (!videoManager || !videoManager.isStreaming) {
309
  alert('Please start the video first.');
310
  return;
@@ -315,20 +165,20 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
315
  return;
316
  }
317
 
318
- //
319
  if (document.pictureInPictureElement === pipVideoRef.current) {
320
  await document.exitPictureInPicture();
321
  console.log('PiP exited');
322
  return;
323
  }
324
 
325
- //
326
  if (!document.pictureInPictureEnabled) {
327
  alert('Picture-in-Picture is not supported in this browser.');
328
  return;
329
  }
330
 
331
- //
332
  const pipVideo = pipVideoRef.current;
333
  if (!pipVideo) {
334
  alert('PiP video element not ready.');
@@ -337,7 +187,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
337
 
338
  const isSafariPiP = typeof pipVideo.webkitSetPresentationMode === 'function';
339
 
340
- //
341
  let stream = pipStreamRef.current;
342
  if (!stream) {
343
  const capture = displayCanvasRef.current.captureStream;
@@ -355,7 +205,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
355
  pipStreamRef.current = stream;
356
  }
357
 
358
- //
359
  if (!stream || stream.getTracks().length === 0) {
360
  alert('Failed to capture video stream from canvas.');
361
  return;
@@ -363,7 +213,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
363
 
364
  pipVideo.srcObject = stream;
365
 
366
- //
367
  if (pipVideo.readyState < 2) {
368
  await new Promise((resolve) => {
369
  const onReady = () => {
@@ -373,23 +223,25 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
373
  };
374
  pipVideo.addEventListener('loadeddata', onReady);
375
  pipVideo.addEventListener('canplay', onReady);
376
- //
377
  setTimeout(resolve, 600);
378
  });
379
  }
380
 
381
  try {
382
  await pipVideo.play();
383
- } catch (_) {}
 
 
384
 
385
- //
386
  if (isSafariPiP) {
387
  try {
388
  pipVideo.webkitSetPresentationMode('picture-in-picture');
389
  console.log('PiP activated (Safari)');
390
  return;
391
  } catch (e) {
392
- //
393
  const cameraStream = localVideoRef.current?.srcObject;
394
  if (cameraStream && cameraStream !== pipVideo.srcObject) {
395
  pipVideo.srcObject = cameraStream;
@@ -404,7 +256,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
404
  }
405
  }
406
 
407
- //
408
  if (typeof pipVideo.requestPictureInPicture === 'function') {
409
  await pipVideo.requestPictureInPicture();
410
  console.log('PiP activated');
@@ -414,7 +266,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
414
 
415
  } catch (err) {
416
  console.error('PiP error:', err);
417
- alert(`Failed to enter Picture-in-Picture: ${err.message}`);
418
  }
419
  };
420
 
@@ -423,7 +275,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
423
  };
424
 
425
  const handleFrameChange = (val) => {
426
- const rate = parseInt(val, 10);
427
  setCurrentFrame(rate);
428
  if (videoManager) {
429
  videoManager.setFrameRate(rate);
@@ -436,7 +288,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
436
  return;
437
  }
438
 
439
- //
440
  const currentStats = videoManager.getStats();
441
 
442
  if (!currentStats.sessionId) {
@@ -444,15 +296,15 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
444
  return;
445
  }
446
 
447
- //
448
  const sessionDuration = Math.floor((Date.now() - (videoManager.sessionStartTime || Date.now())) / 1000);
449
 
450
- //
451
  const focusScore = currentStats.framesProcessed > 0
452
  ? (currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed
453
  : 0;
454
 
455
- //
456
  setSessionResult({
457
  duration_seconds: sessionDuration,
458
  focus_score: focusScore,
@@ -476,142 +328,24 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
476
  pointerEvents: 'none'
477
  };
478
 
479
- const focusStateLabel = {
480
- [FOCUS_STATES.pending]: 'Pending',
481
- [FOCUS_STATES.focused]: 'Focused',
482
- [FOCUS_STATES.notFocused]: 'Not Focused'
483
- }[focusState];
484
-
485
- const introHighlights = [
486
- {
487
- title: 'Live focus tracking',
488
- text: 'Head pose, gaze, and eye openness are read continuously during the session.'
489
- },
490
- {
491
- title: 'Quick setup',
492
- text: 'Front-facing light and a stable camera angle give the cleanest preview.'
493
- },
494
- {
495
- title: 'Private by default',
496
- text: 'Only session metadata is stored, not the raw camera footage.'
497
- }
498
- ];
499
-
500
- const permissionSteps = [
501
- {
502
- title: 'Allow browser access',
503
- text: 'Approve the camera prompt so the preview can appear immediately.'
504
- },
505
- {
506
- title: 'Check your framing',
507
- text: 'Keep your face visible and centered for more stable landmark detection.'
508
- },
509
- {
510
- title: 'Start when ready',
511
- text: 'After the preview appears, use the page controls to begin or stop.'
512
- }
513
- ];
514
-
515
- const renderIntroCard = () => {
516
- if (flowStep === FLOW_STEPS.intro) {
517
- return (
518
- <div className="focus-flow-overlay">
519
- <div className="focus-flow-card">
520
- <div className="focus-flow-header">
521
- <div>
522
- <div className="focus-flow-eyebrow">Focus Session</div>
523
- <h2>Before you begin</h2>
524
- </div>
525
- <div className="focus-flow-icon">
526
- <HelloIcon />
527
- </div>
528
- </div>
529
-
530
- <p className="focus-flow-lead">
531
- The focus page uses your live camera preview to estimate attention in real time.
532
- Review the setup notes below, then continue to camera access.
533
- </p>
534
-
535
- <div className="focus-flow-grid">
536
- {introHighlights.map((item) => (
537
- <article key={item.title} className="focus-flow-panel">
538
- <h3>{item.title}</h3>
539
- <p>{item.text}</p>
540
- </article>
541
- ))}
542
- </div>
543
-
544
- <div className="focus-flow-footer">
545
- <div className="focus-flow-note">
546
- You can still change frame rate and available model options after the preview loads.
547
- </div>
548
- <button className="focus-flow-button" onClick={() => setFlowStep(FLOW_STEPS.permission)}>
549
- Continue
550
- </button>
551
- </div>
552
- </div>
553
- </div>
554
- );
555
- }
556
-
557
- if (flowStep === FLOW_STEPS.permission && !cameraReady) {
558
- return (
559
- <div className="focus-flow-overlay">
560
- <div className="focus-flow-card">
561
- <div className="focus-flow-header">
562
- <div>
563
- <div className="focus-flow-eyebrow">Camera Setup</div>
564
- <h2>Enable camera access</h2>
565
- </div>
566
- <div className="focus-flow-icon">
567
- <CameraIcon />
568
- </div>
569
- </div>
570
-
571
- <p className="focus-flow-lead">
572
- Once access is granted, your preview appears here and the rest of the Focus page
573
- behaves like the other dashboard screens.
574
- </p>
575
-
576
- <div className="focus-flow-steps">
577
- {permissionSteps.map((item, index) => (
578
- <div key={item.title} className="focus-flow-step">
579
- <div className="focus-flow-step-number">{index + 1}</div>
580
- <div className="focus-flow-step-copy">
581
- <h3>{item.title}</h3>
582
- <p>{item.text}</p>
583
- </div>
584
- </div>
585
- ))}
586
- </div>
587
-
588
- {cameraError ? <div className="focus-inline-error">{cameraError}</div> : null}
589
-
590
- <div className="focus-flow-footer">
591
- <button
592
- type="button"
593
- className="focus-flow-secondary"
594
- onClick={() => setFlowStep(FLOW_STEPS.intro)}
595
- >
596
- Back
597
- </button>
598
- <button className="focus-flow-button" onClick={handleEnableCamera}>
599
- Enable Camera
600
- </button>
601
- </div>
602
- </div>
603
- </div>
604
- );
605
- }
606
-
607
- return null;
608
- };
609
 
610
  return (
611
  <main id="page-b" className="page" style={pageStyle}>
612
- {renderIntroCard()}
613
-
614
- <section id="display-area" className="focus-display-shell">
615
  <video
616
  ref={pipVideoRef}
617
  muted
@@ -625,7 +359,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
625
  pointerEvents: 'none'
626
  }}
627
  />
628
- {/* local video (hidden, for capture) */}
629
  <video
630
  ref={localVideoRef}
631
  muted
@@ -634,7 +368,7 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
634
  style={{ display: 'none' }}
635
  />
636
 
637
- {/* processed video (canvas) */}
638
  <canvas
639
  ref={displayCanvasRef}
640
  width={640}
@@ -643,25 +377,11 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
643
  width: '100%',
644
  height: '100%',
645
  objectFit: 'contain',
646
- backgroundColor: '#101010'
647
  }}
648
  />
649
 
650
- {flowStep === FLOW_STEPS.ready ? (
651
- <>
652
- <div className={`focus-state-pill ${focusState}`}>
653
- <span className="focus-state-dot" />
654
- {focusStateLabel}
655
- </div>
656
- {!cameraReady && !videoManager?.isStreaming ? (
657
- <div className="focus-idle-overlay">
658
- <p>Camera is paused.</p>
659
- <span>Use Start to enable the camera and begin detection.</span>
660
- </div>
661
- ) : null}
662
- </>
663
- ) : null}
664
-
665
  {sessionResult && (
666
  <div className="session-result-overlay">
667
  <h3>Session Complete!</h3>
@@ -691,41 +411,42 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
691
  </div>
692
  )}
693
 
694
- {role === 'admin' && stats && stats.isStreaming ? (
695
- <div className="focus-debug-panel">
 
 
 
 
 
 
 
 
 
 
 
696
  <div>Session: {stats.sessionId}</div>
697
  <div>Sent: {stats.framesSent}</div>
698
  <div>Processed: {stats.framesProcessed}</div>
699
  <div>Latency: {stats.avgLatency.toFixed(0)}ms</div>
700
  <div>Status: {stats.currentStatus ? 'Focused' : 'Not Focused'}</div>
701
  <div>Confidence: {(stats.lastConfidence * 100).toFixed(1)}%</div>
702
- {systemStats && systemStats.cpu_percent != null && (
703
- <div style={{ marginTop: '6px', borderTop: '1px solid #444', paddingTop: '4px' }}>
704
- <div>CPU: {systemStats.cpu_percent}%</div>
705
- <div>RAM: {systemStats.memory_percent}% ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</div>
706
- </div>
707
- )}
708
  </div>
709
- ) : null}
710
  </section>
711
 
712
- {/* Server CPU / Memory (always visible) */}
713
- {systemStats && (systemStats.cpu_percent != null || systemStats.memory_percent != null) && (
714
  <section style={{
715
  display: 'flex',
716
  alignItems: 'center',
717
  justifyContent: 'center',
718
- gap: '16px',
719
- padding: '6px 12px',
720
- background: 'rgba(0,0,0,0.3)',
721
  borderRadius: '8px',
722
- margin: '6px auto',
723
- maxWidth: '400px',
724
- fontSize: '13px',
725
- color: '#aaa'
726
  }}>
727
- <span title="Server CPU">CPU: <strong style={{ color: '#8f8' }}>{systemStats.cpu_percent}%</strong></span>
728
- <span title="Server memory">RAM: <strong style={{ color: '#8af' }}>{systemStats.memory_percent}%</strong> ({systemStats.memory_used_mb}/{systemStats.memory_total_mb} MB)</span>
729
  <span style={{ color: '#aaa', fontSize: '13px', marginRight: '4px' }}>Model:</span>
730
  {availableModels.map(name => (
731
  <button
@@ -788,93 +509,73 @@ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActiv
788
  </section>
789
  )}
790
 
791
- {flowStep === FLOW_STEPS.ready ? (
792
- <>
793
- {availableModels.length > 0 ? (
794
- <section className="focus-model-strip">
795
- <span className="focus-model-label">Model:</span>
796
- {availableModels.map((name) => (
797
- <button
798
- key={name}
799
- onClick={() => handleModelChange(name)}
800
- className={`focus-model-button ${currentModel === name ? 'active' : ''}`}
801
- >
802
- {name}
803
- </button>
804
- ))}
805
- </section>
806
- ) : null}
807
-
808
- <section id="timeline-area">
809
- <div className="timeline-label">Timeline</div>
810
- <div id="timeline-visuals">
811
- {timelineEvents.map((event, index) => (
812
- <div
813
- key={index}
814
- className="timeline-block"
815
- style={{
816
- backgroundColor: event.isFocused ? '#00FF00' : '#FF0000',
817
- width: '10px',
818
- height: '20px',
819
- display: 'inline-block',
820
- marginRight: '2px',
821
- borderRadius: '2px'
822
- }}
823
- title={event.isFocused ? 'Focused' : 'Distracted'}
824
- />
825
- ))}
826
- </div>
827
- <div id="timeline-line" />
828
- </section>
829
-
830
- <section id="control-panel">
831
- <button id="btn-cam-start" className="action-btn green" onClick={handleStart} disabled={isStarting}>
832
- {isStarting ? 'Starting...' : 'Start'}
833
- </button>
834
-
835
- <button id="btn-floating" className="action-btn yellow" onClick={handlePiP}>
836
- Floating Window
837
- </button>
838
- <button
839
- id="btn-preview"
840
- className="action-btn"
841
- style={{ backgroundColor: '#ff7a52' }}
842
- onClick={handlePreview}
843
- >
844
- Preview Result
845
- </button>
846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
847
 
848
- <button id="btn-cam-stop" className="action-btn red" onClick={handleStop}>
849
- Stop
850
- </button>
851
- </section>
852
-
853
- {cameraError ? (
854
- <div className="focus-inline-error focus-inline-error-standalone">{cameraError}</div>
855
- ) : null}
856
-
857
- <section id="frame-control">
858
- <label htmlFor="frame-slider">Frame Rate (FPS)</label>
859
- <input
860
- type="range"
861
- id="frame-slider"
862
- min="10"
863
- max="30"
864
- value={currentFrame}
865
- onChange={(e) => handleFrameChange(e.target.value)}
866
- />
867
- <input
868
- type="number"
869
- id="frame-input"
870
- min="10"
871
- max="30"
872
- value={currentFrame}
873
- onChange={(e) => handleFrameChange(e.target.value)}
874
- />
875
- </section>
876
- </>
877
- ) : null}
878
 
879
  {/* Calibration overlay (fixed fullscreen, must be outside overflow:hidden containers) */}
880
  <CalibrationOverlay calibration={calibration} videoManager={videoManager} />
 
1
  import React, { useState, useEffect, useRef } from 'react';
2
  import CalibrationOverlay from './CalibrationOverlay';
3
 
4
+ function FocusPageLocal({ videoManager, sessionResult, setSessionResult, isActive }) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  const [currentFrame, setCurrentFrame] = useState(15);
6
  const [timelineEvents, setTimelineEvents] = useState([]);
7
  const [stats, setStats] = useState(null);
 
8
  const [availableModels, setAvailableModels] = useState([]);
9
  const [currentModel, setCurrentModel] = useState('mlp');
 
 
 
 
 
10
  const [calibration, setCalibration] = useState(null);
11
  const [l2csBoost, setL2csBoost] = useState(false);
12
  const [l2csBoostAvailable, setL2csBoostAvailable] = useState(false);
13
 
14
  const localVideoRef = useRef(null);
15
  const displayCanvasRef = useRef(null);
16
+ const pipVideoRef = useRef(null); // 用于 PiP 的隐藏 video 元素
17
  const pipStreamRef = useRef(null);
 
18
 
19
+ // 辅助函数:格式化时间
20
  const formatDuration = (seconds) => {
21
+ if (seconds === 0) return "0s";
22
  const mins = Math.floor(seconds / 60);
23
  const secs = Math.floor(seconds % 60);
24
  return `${mins}m ${secs}s`;
25
  };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  useEffect(() => {
28
  if (!videoManager) return;
29
 
30
  const originalOnStatusUpdate = videoManager.callbacks.onStatusUpdate;
 
 
31
  videoManager.callbacks.onStatusUpdate = (isFocused) => {
32
+ setTimelineEvents(prev => {
33
  const newEvents = [...prev, { isFocused, timestamp: Date.now() }];
34
  if (newEvents.length > 60) newEvents.shift();
35
  return newEvents;
36
  });
 
37
  if (originalOnStatusUpdate) originalOnStatusUpdate(isFocused);
38
  };
39
 
 
 
 
 
 
 
40
  videoManager.callbacks.onCalibrationUpdate = (cal) => {
41
  setCalibration(cal && cal.active ? { ...cal } : null);
42
  };
 
59
  // Fetch available models on mount
60
  useEffect(() => {
61
  fetch('/api/models')
62
+ .then(res => res.json())
63
+ .then(data => {
64
  if (data.available) setAvailableModels(data.available);
65
  if (data.current) setCurrentModel(data.current);
66
  if (data.l2cs_boost !== undefined) setL2csBoost(data.l2cs_boost);
67
  if (data.l2cs_boost_available !== undefined) setL2csBoostAvailable(data.l2cs_boost_available);
68
  })
69
+ .catch(err => console.error('Failed to fetch models:', err));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  }, []);
71
 
72
  const handleModelChange = async (modelName) => {
 
87
  }
88
  };
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  const handleBoostToggle = async () => {
91
  const next = !l2csBoost;
92
  try {
 
103
 
104
  const handleStart = async () => {
105
  try {
106
+ if (videoManager) {
107
+ setSessionResult(null);
108
+ setTimelineEvents([]);
 
 
109
 
110
+ console.log('Initializing local camera...');
111
  await videoManager.initCamera(localVideoRef.current, displayCanvasRef.current);
112
+ console.log('Camera initialized');
 
 
113
 
114
+ console.log('Starting local streaming...');
115
+ await videoManager.startStreaming();
116
+ console.log('Streaming started successfully');
117
+ }
118
  } catch (err) {
 
 
 
119
  console.error('Start error:', err);
120
+ let errorMessage = "Failed to start: ";
121
+
122
+ if (err.name === 'NotAllowedError') {
123
+ errorMessage += "Camera permission denied. Please allow camera access.";
124
+ } else if (err.name === 'NotFoundError') {
125
+ errorMessage += "No camera found. Please connect a camera.";
126
+ } else if (err.name === 'NotReadableError') {
127
+ errorMessage += "Camera is already in use by another application.";
128
+ } else {
129
+ errorMessage += err.message || "Unknown error occurred.";
130
+ }
131
+
132
+ alert(errorMessage + "\n\nCheck browser console for details.");
133
  }
134
  };
135
 
136
  const handleStop = async () => {
137
  if (videoManager) {
138
+ videoManager.stopStreaming();
139
  }
140
  try {
141
  if (document.pictureInPictureElement === pipVideoRef.current) {
 
147
  pipVideoRef.current.srcObject = null;
148
  }
149
  if (pipStreamRef.current) {
150
+ pipStreamRef.current.getTracks().forEach(t => t.stop());
151
  pipStreamRef.current = null;
152
  }
 
 
 
153
  };
154
 
155
  const handlePiP = async () => {
156
  try {
157
+ // 检查是否有视频管理器和是否在运行
158
  if (!videoManager || !videoManager.isStreaming) {
159
  alert('Please start the video first.');
160
  return;
 
165
  return;
166
  }
167
 
168
+ // 如果已经在 PiP 模式,且是本视频,退出
169
  if (document.pictureInPictureElement === pipVideoRef.current) {
170
  await document.exitPictureInPicture();
171
  console.log('PiP exited');
172
  return;
173
  }
174
 
175
+ // 检查浏览器支持
176
  if (!document.pictureInPictureEnabled) {
177
  alert('Picture-in-Picture is not supported in this browser.');
178
  return;
179
  }
180
 
181
+ // 创建或获取 PiP video 元素
182
  const pipVideo = pipVideoRef.current;
183
  if (!pipVideo) {
184
  alert('PiP video element not ready.');
 
187
 
188
  const isSafariPiP = typeof pipVideo.webkitSetPresentationMode === 'function';
189
 
190
+ // 优先用画布流(带检测框),失败再回退到摄像头流
191
  let stream = pipStreamRef.current;
192
  if (!stream) {
193
  const capture = displayCanvasRef.current.captureStream;
 
205
  pipStreamRef.current = stream;
206
  }
207
 
208
+ // 确保流有轨道
209
  if (!stream || stream.getTracks().length === 0) {
210
  alert('Failed to capture video stream from canvas.');
211
  return;
 
213
 
214
  pipVideo.srcObject = stream;
215
 
216
+ // 播放视频(Safari 可能不会触发 onloadedmetadata)
217
  if (pipVideo.readyState < 2) {
218
  await new Promise((resolve) => {
219
  const onReady = () => {
 
223
  };
224
  pipVideo.addEventListener('loadeddata', onReady);
225
  pipVideo.addEventListener('canplay', onReady);
226
+ // 兜底:短延迟后继续尝试
227
  setTimeout(resolve, 600);
228
  });
229
  }
230
 
231
  try {
232
  await pipVideo.play();
233
+ } catch (_) {
234
+ // Safari 可能拒绝自动播放,但仍可进入 PiP
235
+ }
236
 
237
+ // Safari 支持(优先)
238
  if (isSafariPiP) {
239
  try {
240
  pipVideo.webkitSetPresentationMode('picture-in-picture');
241
  console.log('PiP activated (Safari)');
242
  return;
243
  } catch (e) {
244
+ // 如果画布流失败,回退到摄像头流再试一次
245
  const cameraStream = localVideoRef.current?.srcObject;
246
  if (cameraStream && cameraStream !== pipVideo.srcObject) {
247
  pipVideo.srcObject = cameraStream;
 
256
  }
257
  }
258
 
259
+ // 标准 API
260
  if (typeof pipVideo.requestPictureInPicture === 'function') {
261
  await pipVideo.requestPictureInPicture();
262
  console.log('PiP activated');
 
266
 
267
  } catch (err) {
268
  console.error('PiP error:', err);
269
+ alert('Failed to enter Picture-in-Picture: ' + err.message);
270
  }
271
  };
272
 
 
275
  };
276
 
277
  const handleFrameChange = (val) => {
278
+ const rate = parseInt(val);
279
  setCurrentFrame(rate);
280
  if (videoManager) {
281
  videoManager.setFrameRate(rate);
 
288
  return;
289
  }
290
 
291
+ // 获取当前统计数据
292
  const currentStats = videoManager.getStats();
293
 
294
  if (!currentStats.sessionId) {
 
296
  return;
297
  }
298
 
299
+ // 计算当前持续时间(从 session 开始到现在)
300
  const sessionDuration = Math.floor((Date.now() - (videoManager.sessionStartTime || Date.now())) / 1000);
301
 
302
+ // 计算当前专注分数
303
  const focusScore = currentStats.framesProcessed > 0
304
  ? (currentStats.framesProcessed * (currentStats.currentStatus ? 1 : 0)) / currentStats.framesProcessed
305
  : 0;
306
 
307
+ // 显示当前实时数据
308
  setSessionResult({
309
  duration_seconds: sessionDuration,
310
  focus_score: focusScore,
 
328
  pointerEvents: 'none'
329
  };
330
 
331
+ useEffect(() => {
332
+ return () => {
333
+ if (pipVideoRef.current) {
334
+ pipVideoRef.current.pause();
335
+ pipVideoRef.current.srcObject = null;
336
+ }
337
+ if (pipStreamRef.current) {
338
+ pipStreamRef.current.getTracks().forEach(t => t.stop());
339
+ pipStreamRef.current = null;
340
+ }
341
+ };
342
+ }, []);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
  return (
345
  <main id="page-b" className="page" style={pageStyle}>
346
+ {/* 1. Camera / Display Area */}
347
+ <section id="display-area" style={{ position: 'relative', overflow: 'hidden' }}>
348
+ {/* 用于 PiP 的隐藏 video 元素(保持在 DOM 以提高兼容性) */}
349
  <video
350
  ref={pipVideoRef}
351
  muted
 
359
  pointerEvents: 'none'
360
  }}
361
  />
362
+ {/* 本地视频流(隐藏,仅用于截图) */}
363
  <video
364
  ref={localVideoRef}
365
  muted
 
368
  style={{ display: 'none' }}
369
  />
370
 
371
+ {/* 显示处理后的视频(使用 Canvas) */}
372
  <canvas
373
  ref={displayCanvasRef}
374
  width={640}
 
377
  width: '100%',
378
  height: '100%',
379
  objectFit: 'contain',
380
+ backgroundColor: '#000'
381
  }}
382
  />
383
 
384
+ {/* 结果覆盖层 */}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  {sessionResult && (
386
  <div className="session-result-overlay">
387
  <h3>Session Complete!</h3>
 
411
  </div>
412
  )}
413
 
414
+ {/* 性能统计显示(开发模式) */}
415
+ {stats && stats.isStreaming && (
416
+ <div style={{
417
+ position: 'absolute',
418
+ top: '10px',
419
+ right: '10px',
420
+ background: 'rgba(0,0,0,0.7)',
421
+ color: 'white',
422
+ padding: '10px',
423
+ borderRadius: '5px',
424
+ fontSize: '12px',
425
+ fontFamily: 'monospace'
426
+ }}>
427
  <div>Session: {stats.sessionId}</div>
428
  <div>Sent: {stats.framesSent}</div>
429
  <div>Processed: {stats.framesProcessed}</div>
430
  <div>Latency: {stats.avgLatency.toFixed(0)}ms</div>
431
  <div>Status: {stats.currentStatus ? 'Focused' : 'Not Focused'}</div>
432
  <div>Confidence: {(stats.lastConfidence * 100).toFixed(1)}%</div>
 
 
 
 
 
 
433
  </div>
434
+ )}
435
  </section>
436
 
437
+ {/* 2. Model Selector */}
438
+ {availableModels.length > 0 && (
439
  <section style={{
440
  display: 'flex',
441
  alignItems: 'center',
442
  justifyContent: 'center',
443
+ gap: '8px',
444
+ padding: '8px 16px',
445
+ background: '#1a1a2e',
446
  borderRadius: '8px',
447
+ margin: '8px auto',
448
+ maxWidth: '600px'
 
 
449
  }}>
 
 
450
  <span style={{ color: '#aaa', fontSize: '13px', marginRight: '4px' }}>Model:</span>
451
  {availableModels.map(name => (
452
  <button
 
509
  </section>
510
  )}
511
 
512
+ {/* 3. Timeline Area */}
513
+ <section id="timeline-area">
514
+ <div className="timeline-label">Timeline</div>
515
+ <div id="timeline-visuals">
516
+ {timelineEvents.map((event, index) => (
517
+ <div
518
+ key={index}
519
+ className="timeline-block"
520
+ style={{
521
+ backgroundColor: event.isFocused ? '#00FF00' : '#FF0000',
522
+ width: '10px',
523
+ height: '20px',
524
+ display: 'inline-block',
525
+ marginRight: '2px',
526
+ borderRadius: '2px'
527
+ }}
528
+ title={event.isFocused ? 'Focused' : 'Distracted'}
529
+ />
530
+ ))}
531
+ </div>
532
+ <div id="timeline-line"></div>
533
+ </section>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
+ {/* 4. Control Buttons */}
536
+ <section id="control-panel">
537
+ <button id="btn-cam-start" className="action-btn green" onClick={handleStart}>
538
+ Start
539
+ </button>
540
+
541
+ <button id="btn-floating" className="action-btn yellow" onClick={handleFloatingWindow}>
542
+ Floating Window
543
+ </button>
544
+
545
+ <button
546
+ id="btn-preview"
547
+ className="action-btn"
548
+ style={{ backgroundColor: '#6c5ce7' }}
549
+ onClick={handlePreview}
550
+ >
551
+ Preview Result
552
+ </button>
553
+
554
+ <button id="btn-cam-stop" className="action-btn red" onClick={handleStop}>
555
+ Stop
556
+ </button>
557
+ </section>
558
 
559
+ {/* 5. Frame Control */}
560
+ <section id="frame-control">
561
+ <label htmlFor="frame-slider">Frame Rate (FPS)</label>
562
+ <input
563
+ type="range"
564
+ id="frame-slider"
565
+ min="10"
566
+ max="30"
567
+ value={currentFrame}
568
+ onChange={(e) => handleFrameChange(e.target.value)}
569
+ />
570
+ <input
571
+ type="number"
572
+ id="frame-input"
573
+ min="10"
574
+ max="30"
575
+ value={currentFrame}
576
+ onChange={(e) => handleFrameChange(e.target.value)}
577
+ />
578
+ </section>
 
 
 
 
 
 
 
 
 
 
579
 
580
  {/* Calibration overlay (fixed fullscreen, must be outside overflow:hidden containers) */}
581
  <CalibrationOverlay calibration={calibration} videoManager={videoManager} />
src/utils/VideoManagerLocal.js CHANGED
@@ -1,12 +1,12 @@
1
  // src/utils/VideoManagerLocal.js
2
- // Local video processing implementation using WebSocket + Canvas, without WebRTC.
3
 
4
  export class VideoManagerLocal {
5
  constructor(callbacks) {
6
  this.callbacks = callbacks || {};
7
 
8
- this.localVideoElement = null; // Local camera preview element.
9
- this.displayVideoElement = null; // Processed output display element.
10
  this.canvas = null;
11
  this.stream = null;
12
  this.ws = null;
@@ -14,16 +14,15 @@ export class VideoManagerLocal {
14
  this.isStreaming = false;
15
  this.sessionId = null;
16
  this.sessionStartTime = null;
17
- this.frameRate = 15; // Lower FPS reduces transfer and processing load.
18
  this.captureInterval = null;
19
- this.reconnectTimeout = null;
20
 
21
- // Status smoothing
22
  this.currentStatus = false;
23
  this.statusBuffer = [];
24
  this.bufferSize = 3;
25
 
26
- // Detection data
27
  this.latestDetectionData = null;
28
  this.lastConfidence = 0;
29
 
@@ -33,7 +32,7 @@ export class VideoManagerLocal {
33
  // Continuous render loop
34
  this._animFrameId = null;
35
 
36
- // Notification state
37
  this.notificationEnabled = true;
38
  this.notificationThreshold = 30;
39
  this.unfocusedStartTime = null;
@@ -51,27 +50,16 @@ export class VideoManagerLocal {
51
  success: false,
52
  };
53
 
54
- // Performance metrics
55
  this.stats = {
56
  framesSent: 0,
57
  framesProcessed: 0,
58
  avgLatency: 0,
59
  lastLatencies: []
60
  };
61
-
62
- // Calibration state (9-point gaze calibration)
63
- this.calibrationState = {
64
- active: false,
65
- collecting: false,
66
- done: false,
67
- success: false,
68
- target: [0.5, 0.5],
69
- index: 0,
70
- numPoints: 9
71
- };
72
  }
73
 
74
- // Initialize the camera
75
  async initCamera(localVideoRef, displayCanvasRef) {
76
  try {
77
  console.log('Initializing local camera...');
@@ -88,13 +76,13 @@ export class VideoManagerLocal {
88
  this.localVideoElement = localVideoRef;
89
  this.displayCanvas = displayCanvasRef;
90
 
91
- // Show the local camera stream
92
  if (this.localVideoElement) {
93
  this.localVideoElement.srcObject = this.stream;
94
  this.localVideoElement.play();
95
  }
96
 
97
- // Capture at 640x480 for L2CS / gaze (matches HF commit 2eba0cc).
98
  this.canvas = document.createElement('canvas');
99
  this.canvas.width = 640;
100
  this.canvas.height = 480;
@@ -107,7 +95,7 @@ export class VideoManagerLocal {
107
  }
108
  }
109
 
110
- // Start streaming
111
  async startStreaming() {
112
  if (!this.stream) {
113
  throw new Error('Camera not initialized');
@@ -121,64 +109,35 @@ export class VideoManagerLocal {
121
  console.log('Starting WebSocket streaming...');
122
  this.isStreaming = true;
123
 
124
- try {
125
- // Fetch tessellation topology (once)
126
- if (!this._tessellation) {
127
- try {
128
- const res = await fetch('/api/mesh-topology');
129
- const data = await res.json();
130
- this._tessellation = data.tessellation; // [[start, end], ...]
131
- } catch (e) {
132
- console.warn('Failed to fetch mesh topology:', e);
133
- }
134
  }
 
135
 
136
- // Request notification permission
137
- await this.requestNotificationPermission();
138
- await this.loadNotificationSettings();
139
-
140
- // Open the WebSocket connection
141
- await this.connectWebSocket();
142
-
143
- // Start sending captured frames on a timer
144
- this.startCapture();
145
-
146
- // Start continuous render loop for smooth video
147
- this._lastDetection = null;
148
- this._startRenderLoop();
149
-
150
- console.log('Streaming started');
151
- } catch (error) {
152
- this.isStreaming = false;
153
- this._stopRenderLoop();
154
- this._lastDetection = null;
155
 
156
- if (this.captureInterval) {
157
- clearInterval(this.captureInterval);
158
- this.captureInterval = null;
159
- }
160
 
161
- if (this.reconnectTimeout) {
162
- clearTimeout(this.reconnectTimeout);
163
- this.reconnectTimeout = null;
164
- }
165
 
166
- if (this.ws) {
167
- this.ws.onopen = null;
168
- this.ws.onmessage = null;
169
- this.ws.onerror = null;
170
- this.ws.onclose = null;
171
- try {
172
- this.ws.close();
173
- } catch (_) {}
174
- this.ws = null;
175
- }
176
 
177
- throw error instanceof Error ? error : new Error('Failed to start video streaming.');
178
- }
179
  }
180
 
181
- // Connect the WebSocket
182
  async connectWebSocket() {
183
  return new Promise((resolve, reject) => {
184
  const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
@@ -186,28 +145,17 @@ export class VideoManagerLocal {
186
 
187
  console.log('Connecting to WebSocket:', wsUrl);
188
 
189
- const socket = new WebSocket(wsUrl);
190
- this.ws = socket;
191
-
192
- let settled = false;
193
- let opened = false;
194
- const rejectWithMessage = (message) => {
195
- if (settled) return;
196
- settled = true;
197
- reject(new Error(message));
198
- };
199
 
200
- socket.onopen = () => {
201
- opened = true;
202
- settled = true;
203
  console.log('WebSocket connected');
204
 
205
- // Send the start-session control message
206
- socket.send(JSON.stringify({ type: 'start_session' }));
207
  resolve();
208
  };
209
 
210
- socket.onmessage = (event) => {
211
  try {
212
  const data = JSON.parse(event.data);
213
  this.handleServerMessage(data);
@@ -216,40 +164,22 @@ export class VideoManagerLocal {
216
  }
217
  };
218
 
219
- socket.onerror = () => {
220
- console.error('WebSocket error:', { url: wsUrl, readyState: socket.readyState });
221
- rejectWithMessage(`Failed to connect to ${wsUrl}. Check that the backend server is running and reachable.`);
222
  };
223
 
224
- socket.onclose = (event) => {
225
- console.log('WebSocket disconnected', event.code, event.reason);
226
- if (this.ws === socket) {
227
- this.ws = null;
228
- }
229
-
230
- if (!opened) {
231
- rejectWithMessage(`WebSocket closed before connection was established (${event.code || 'no code'}). Check that the backend server is running on the expected port.`);
232
- return;
233
- }
234
-
235
  if (this.isStreaming) {
236
  console.log('Attempting to reconnect...');
237
- if (this.reconnectTimeout) {
238
- clearTimeout(this.reconnectTimeout);
239
- }
240
- this.reconnectTimeout = setTimeout(() => {
241
- this.reconnectTimeout = null;
242
- if (!this.isStreaming) return;
243
- this.connectWebSocket().catch((error) => {
244
- console.error('Reconnect failed:', error);
245
- });
246
- }, 2000);
247
  }
248
  };
249
  });
250
  }
251
 
252
- // Capture and send frames (binary blobs for speed)
253
  startCapture() {
254
  const interval = 1000 / this.frameRate;
255
  this._sendingBlob = false; // prevent overlapping toBlob calls
@@ -294,8 +224,7 @@ export class VideoManagerLocal {
294
  // Overlay last known detection results
295
  const data = this._lastDetection;
296
  if (data) {
297
- const isL2cs = data.model === 'l2cs';
298
- if (data.landmarks && !isL2cs) {
299
  this.drawFaceMesh(ctx, data.landmarks, w, h);
300
  }
301
  // Top HUD bar (matching live_demo.py)
@@ -334,82 +263,6 @@ export class VideoManagerLocal {
334
  ctx.fillText(`yaw:${data.yaw > 0 ? '+' : ''}${data.yaw.toFixed(0)} pitch:${data.pitch > 0 ? '+' : ''}${data.pitch.toFixed(0)} roll:${data.roll > 0 ? '+' : ''}${data.roll.toFixed(0)}`, w - 10, 48);
335
  ctx.textAlign = 'left';
336
  }
337
-
338
- // Gaze pointer removed from camera — shown in mini-map only.
339
-
340
- // Eye gaze (L2CS): iris-based arrows matching live_demo.py
341
- if (isL2cs && data.landmarks) {
342
- const lm = data.landmarks;
343
- const getPt = (idx) => {
344
- if (!lm) return null;
345
- if (Array.isArray(lm)) return lm[idx] || null;
346
- return lm[String(idx)] || null;
347
- };
348
-
349
- // Draw eye contours (green)
350
- this._drawPolyline(ctx, lm, VideoManagerLocal.LEFT_EYE, w, h, '#00FF00', 2, true);
351
- this._drawPolyline(ctx, lm, VideoManagerLocal.RIGHT_EYE, w, h, '#00FF00', 2, true);
352
-
353
- // EAR key points (yellow)
354
- for (const earIndices of [VideoManagerLocal.LEFT_EAR_POINTS, VideoManagerLocal.RIGHT_EAR_POINTS]) {
355
- for (const idx of earIndices) {
356
- const pt = getPt(idx);
357
- if (!pt) continue;
358
- ctx.beginPath();
359
- ctx.arc(pt[0] * w, pt[1] * h, 3, 0, 2 * Math.PI);
360
- ctx.fillStyle = '#FFFF00';
361
- ctx.fill();
362
- }
363
- }
364
-
365
- // Irises + gaze lines (matching live_demo.py)
366
- const irisSets = [
367
- { iris: VideoManagerLocal.LEFT_IRIS, inner: 133, outer: 33 },
368
- { iris: VideoManagerLocal.RIGHT_IRIS, inner: 362, outer: 263 },
369
- ];
370
- for (const { iris, inner, outer } of irisSets) {
371
- const centerPt = getPt(iris[0]);
372
- if (!centerPt) continue;
373
- const cx = centerPt[0] * w, cy = centerPt[1] * h;
374
-
375
- // Iris circle (magenta)
376
- let radiusSum = 0, count = 0;
377
- for (let i = 1; i < iris.length; i++) {
378
- const pt = getPt(iris[i]);
379
- if (!pt) continue;
380
- radiusSum += Math.hypot(pt[0] * w - cx, pt[1] * h - cy);
381
- count++;
382
- }
383
- const radius = Math.max(count > 0 ? radiusSum / count : 3, 2);
384
- ctx.beginPath();
385
- ctx.arc(cx, cy, radius, 0, 2 * Math.PI);
386
- ctx.strokeStyle = '#FF00FF';
387
- ctx.lineWidth = 2;
388
- ctx.stroke();
389
-
390
- // Iris center dot (white)
391
- ctx.beginPath();
392
- ctx.arc(cx, cy, 2, 0, 2 * Math.PI);
393
- ctx.fillStyle = '#FFFFFF';
394
- ctx.fill();
395
-
396
- // Gaze direction line (red) — from iris center, 3x displacement
397
- const innerPt = getPt(inner);
398
- const outerPt = getPt(outer);
399
- if (innerPt && outerPt) {
400
- const eyeCx = (innerPt[0] + outerPt[0]) / 2.0 * w;
401
- const eyeCy = (innerPt[1] + outerPt[1]) / 2.0 * h;
402
- const dx = cx - eyeCx;
403
- const dy = cy - eyeCy;
404
- ctx.beginPath();
405
- ctx.moveTo(cx, cy);
406
- ctx.lineTo(cx + dx * 3, cy + dy * 3);
407
- ctx.strokeStyle = '#FF0000';
408
- ctx.lineWidth = 1;
409
- ctx.stroke();
410
- }
411
- }
412
- }
413
  }
414
  // Gaze pointer (L2CS + calibration)
415
  if (data && data.gaze_x !== undefined && data.gaze_y !== undefined) {
@@ -443,7 +296,7 @@ export class VideoManagerLocal {
443
  }
444
  }
445
 
446
- // Handle messages from the server
447
  handleServerMessage(data) {
448
  switch (data.type) {
449
  case 'session_started':
@@ -497,70 +350,6 @@ export class VideoManagerLocal {
497
  on_screen: data.on_screen,
498
  };
499
  this.drawDetectionResult(detectionData);
500
-
501
- // Emit gaze data for mini-map
502
- if (this.callbacks.onGazeData) {
503
- this.callbacks.onGazeData({
504
- gaze_x: data.gaze_x != null ? data.gaze_x : null,
505
- gaze_y: data.gaze_y != null ? data.gaze_y : null,
506
- on_screen: data.on_screen != null ? data.on_screen : null,
507
- });
508
- }
509
- break;
510
-
511
- case 'calibration_started':
512
- this.calibrationState = {
513
- active: true,
514
- collecting: true,
515
- done: false,
516
- success: false,
517
- target: data.target || [0.5, 0.5],
518
- index: data.index ?? 0,
519
- numPoints: data.num_points ?? 9,
520
- };
521
- if (this.callbacks.onCalibrationUpdate) {
522
- this.callbacks.onCalibrationUpdate(this.calibrationState);
523
- }
524
- break;
525
-
526
- case 'calibration_point':
527
- this.calibrationState = {
528
- ...this.calibrationState,
529
- target: data.target || [0.5, 0.5],
530
- index: data.index ?? this.calibrationState.index,
531
- };
532
- if (this.callbacks.onCalibrationUpdate) {
533
- this.callbacks.onCalibrationUpdate(this.calibrationState);
534
- }
535
- break;
536
-
537
- case 'calibration_done':
538
- this.calibrationState = {
539
- ...this.calibrationState,
540
- active: true,
541
- collecting: false,
542
- done: true,
543
- success: data.success === true,
544
- error: data.error || null,
545
- };
546
- if (this.callbacks.onCalibrationUpdate) {
547
- this.callbacks.onCalibrationUpdate(this.calibrationState);
548
- }
549
- break;
550
-
551
- case 'calibration_cancelled':
552
- this.calibrationState = {
553
- active: false,
554
- collecting: false,
555
- done: false,
556
- success: false,
557
- target: [0.5, 0.5],
558
- index: 0,
559
- numPoints: 9,
560
- };
561
- if (this.callbacks.onCalibrationUpdate) {
562
- this.callbacks.onCalibrationUpdate(this.calibrationState);
563
- }
564
  break;
565
 
566
  case 'session_ended':
@@ -891,26 +680,21 @@ export class VideoManagerLocal {
891
 
892
  this.isStreaming = false;
893
 
894
- if (this.reconnectTimeout) {
895
- clearTimeout(this.reconnectTimeout);
896
- this.reconnectTimeout = null;
897
- }
898
-
899
- // Stop the render loop
900
  this._stopRenderLoop();
901
  this._lastDetection = null;
902
 
903
- // Stop frame capture
904
  if (this.captureInterval) {
905
  clearInterval(this.captureInterval);
906
  this.captureInterval = null;
907
  }
908
 
909
- // Send the end-session request and wait for the response
910
  if (this.ws && this.ws.readyState === WebSocket.OPEN && this.sessionId) {
911
  const sessionId = this.sessionId;
912
 
913
- // Wait for the session_ended message
914
  const waitForSessionEnd = new Promise((resolve) => {
915
  const originalHandler = this.ws.onmessage;
916
  const timeout = setTimeout(() => {
@@ -928,7 +712,7 @@ export class VideoManagerLocal {
928
  this.ws.onmessage = originalHandler;
929
  resolve();
930
  } else {
931
- // Continue handling non-terminal messages
932
  this.handleServerMessage(data);
933
  }
934
  } catch (e) {
@@ -943,37 +727,37 @@ export class VideoManagerLocal {
943
  session_id: sessionId
944
  }));
945
 
946
- // Wait for the response or a timeout
947
  await waitForSessionEnd;
948
  }
949
 
950
- // Delay socket shutdown briefly so pending messages can flush
951
  await new Promise(resolve => setTimeout(resolve, 200));
952
 
953
- // Close the WebSocket
954
  if (this.ws) {
955
  this.ws.close();
956
  this.ws = null;
957
  }
958
 
959
- // Stop the camera
960
  if (this.stream) {
961
  this.stream.getTracks().forEach(track => track.stop());
962
  this.stream = null;
963
  }
964
 
965
- // Clear the video element
966
  if (this.localVideoElement) {
967
  this.localVideoElement.srcObject = null;
968
  }
969
 
970
- // Clear the canvas
971
  if (this.displayCanvas) {
972
  const ctx = this.displayCanvas.getContext('2d');
973
  ctx.clearRect(0, 0, this.displayCanvas.width, this.displayCanvas.height);
974
  }
975
 
976
- // Reset transient state
977
  this.unfocusedStartTime = null;
978
  this.lastNotificationTime = null;
979
 
@@ -985,47 +769,13 @@ export class VideoManagerLocal {
985
  this.frameRate = Math.max(10, Math.min(30, rate));
986
  console.log(`Frame rate set to ${this.frameRate} FPS`);
987
 
988
- // Restart capture if streaming is already active
989
  if (this.isStreaming && this.captureInterval) {
990
  clearInterval(this.captureInterval);
991
  this.startCapture();
992
  }
993
  }
994
 
995
- startCalibration() {
996
- if (!this.ws || this.ws.readyState !== WebSocket.OPEN) return;
997
- this.ws.send(JSON.stringify({ type: 'calibration_start' }));
998
- }
999
-
1000
- nextCalibrationPoint() {
1001
- if (!this.ws || this.ws.readyState !== WebSocket.OPEN) return;
1002
- this.ws.send(JSON.stringify({ type: 'calibration_next' }));
1003
- }
1004
-
1005
- cancelCalibration() {
1006
- if (!this.ws || this.ws.readyState !== WebSocket.OPEN) return;
1007
- this.ws.send(JSON.stringify({ type: 'calibration_cancel' }));
1008
- }
1009
-
1010
- getCalibrationState() {
1011
- return this.calibrationState;
1012
- }
1013
-
1014
- dismissCalibrationDone() {
1015
- this.calibrationState = {
1016
- active: false,
1017
- collecting: false,
1018
- done: false,
1019
- success: false,
1020
- target: [0.5, 0.5],
1021
- index: 0,
1022
- numPoints: 9,
1023
- };
1024
- if (this.callbacks.onCalibrationUpdate) {
1025
- this.callbacks.onCalibrationUpdate(this.calibrationState);
1026
- }
1027
- }
1028
-
1029
  getStats() {
1030
  return {
1031
  ...this.stats,
 
1
  // src/utils/VideoManagerLocal.js
2
+ // 本地视频处理版本 - 使用 WebSocket + Canvas,不依赖 WebRTC
3
 
4
  export class VideoManagerLocal {
5
  constructor(callbacks) {
6
  this.callbacks = callbacks || {};
7
 
8
+ this.localVideoElement = null; // 显示本地摄像头
9
+ this.displayVideoElement = null; // 显示处理后的视频
10
  this.canvas = null;
11
  this.stream = null;
12
  this.ws = null;
 
14
  this.isStreaming = false;
15
  this.sessionId = null;
16
  this.sessionStartTime = null;
17
+ this.frameRate = 15; // 降低帧率以减少网络负载
18
  this.captureInterval = null;
 
19
 
20
+ // 状态平滑处理
21
  this.currentStatus = false;
22
  this.statusBuffer = [];
23
  this.bufferSize = 3;
24
 
25
+ // 检测数据
26
  this.latestDetectionData = null;
27
  this.lastConfidence = 0;
28
 
 
32
  // Continuous render loop
33
  this._animFrameId = null;
34
 
35
+ // 通知系统
36
  this.notificationEnabled = true;
37
  this.notificationThreshold = 30;
38
  this.unfocusedStartTime = null;
 
50
  success: false,
51
  };
52
 
53
+ // 性能统计
54
  this.stats = {
55
  framesSent: 0,
56
  framesProcessed: 0,
57
  avgLatency: 0,
58
  lastLatencies: []
59
  };
 
 
 
 
 
 
 
 
 
 
 
60
  }
61
 
62
+ // 初始化摄像头
63
  async initCamera(localVideoRef, displayCanvasRef) {
64
  try {
65
  console.log('Initializing local camera...');
 
76
  this.localVideoElement = localVideoRef;
77
  this.displayCanvas = displayCanvasRef;
78
 
79
+ // 显示本地视频流
80
  if (this.localVideoElement) {
81
  this.localVideoElement.srcObject = this.stream;
82
  this.localVideoElement.play();
83
  }
84
 
85
+ // 创建用于截图的 canvas (smaller for faster encode + transfer)
86
  this.canvas = document.createElement('canvas');
87
  this.canvas.width = 640;
88
  this.canvas.height = 480;
 
95
  }
96
  }
97
 
98
+ // 开始流式处理
99
  async startStreaming() {
100
  if (!this.stream) {
101
  throw new Error('Camera not initialized');
 
109
  console.log('Starting WebSocket streaming...');
110
  this.isStreaming = true;
111
 
112
+ // Fetch tessellation topology (once)
113
+ if (!this._tessellation) {
114
+ try {
115
+ const res = await fetch('/api/mesh-topology');
116
+ const data = await res.json();
117
+ this._tessellation = data.tessellation; // [[start, end], ...]
118
+ } catch (e) {
119
+ console.warn('Failed to fetch mesh topology:', e);
 
 
120
  }
121
+ }
122
 
123
+ // 请求通知权限
124
+ await this.requestNotificationPermission();
125
+ await this.loadNotificationSettings();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ // 建立 WebSocket 连接
128
+ await this.connectWebSocket();
 
 
129
 
130
+ // 开始定期截图并发送
131
+ this.startCapture();
 
 
132
 
133
+ // Start continuous render loop for smooth video
134
+ this._lastDetection = null;
135
+ this._startRenderLoop();
 
 
 
 
 
 
 
136
 
137
+ console.log('Streaming started');
 
138
  }
139
 
140
+ // 建立 WebSocket 连接
141
  async connectWebSocket() {
142
  return new Promise((resolve, reject) => {
143
  const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
 
145
 
146
  console.log('Connecting to WebSocket:', wsUrl);
147
 
148
+ this.ws = new WebSocket(wsUrl);
 
 
 
 
 
 
 
 
 
149
 
150
+ this.ws.onopen = () => {
 
 
151
  console.log('WebSocket connected');
152
 
153
+ // 发送开始会话请求
154
+ this.ws.send(JSON.stringify({ type: 'start_session' }));
155
  resolve();
156
  };
157
 
158
+ this.ws.onmessage = (event) => {
159
  try {
160
  const data = JSON.parse(event.data);
161
  this.handleServerMessage(data);
 
164
  }
165
  };
166
 
167
+ this.ws.onerror = (error) => {
168
+ console.error('WebSocket error:', error);
169
+ reject(error);
170
  };
171
 
172
+ this.ws.onclose = () => {
173
+ console.log('WebSocket disconnected');
 
 
 
 
 
 
 
 
 
174
  if (this.isStreaming) {
175
  console.log('Attempting to reconnect...');
176
+ setTimeout(() => this.connectWebSocket(), 2000);
 
 
 
 
 
 
 
 
 
177
  }
178
  };
179
  });
180
  }
181
 
182
+ // 开始截图并发送 (binary blobs for speed)
183
  startCapture() {
184
  const interval = 1000 / this.frameRate;
185
  this._sendingBlob = false; // prevent overlapping toBlob calls
 
224
  // Overlay last known detection results
225
  const data = this._lastDetection;
226
  if (data) {
227
+ if (data.landmarks) {
 
228
  this.drawFaceMesh(ctx, data.landmarks, w, h);
229
  }
230
  // Top HUD bar (matching live_demo.py)
 
263
  ctx.fillText(`yaw:${data.yaw > 0 ? '+' : ''}${data.yaw.toFixed(0)} pitch:${data.pitch > 0 ? '+' : ''}${data.pitch.toFixed(0)} roll:${data.roll > 0 ? '+' : ''}${data.roll.toFixed(0)}`, w - 10, 48);
264
  ctx.textAlign = 'left';
265
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  }
267
  // Gaze pointer (L2CS + calibration)
268
  if (data && data.gaze_x !== undefined && data.gaze_y !== undefined) {
 
296
  }
297
  }
298
 
299
+ // 处理服务器消息
300
  handleServerMessage(data) {
301
  switch (data.type) {
302
  case 'session_started':
 
350
  on_screen: data.on_screen,
351
  };
352
  this.drawDetectionResult(detectionData);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  break;
354
 
355
  case 'session_ended':
 
680
 
681
  this.isStreaming = false;
682
 
683
+ // Stop render loop
 
 
 
 
 
684
  this._stopRenderLoop();
685
  this._lastDetection = null;
686
 
687
+ // 停止截图
688
  if (this.captureInterval) {
689
  clearInterval(this.captureInterval);
690
  this.captureInterval = null;
691
  }
692
 
693
+ // 发送结束会话请求并等待响应
694
  if (this.ws && this.ws.readyState === WebSocket.OPEN && this.sessionId) {
695
  const sessionId = this.sessionId;
696
 
697
+ // 等待 session_ended 消息
698
  const waitForSessionEnd = new Promise((resolve) => {
699
  const originalHandler = this.ws.onmessage;
700
  const timeout = setTimeout(() => {
 
712
  this.ws.onmessage = originalHandler;
713
  resolve();
714
  } else {
715
+ // 仍然处理其他消息
716
  this.handleServerMessage(data);
717
  }
718
  } catch (e) {
 
727
  session_id: sessionId
728
  }));
729
 
730
+ // 等待响应或超时
731
  await waitForSessionEnd;
732
  }
733
 
734
+ // 延迟关闭 WebSocket 确保消息发送完成
735
  await new Promise(resolve => setTimeout(resolve, 200));
736
 
737
+ // 关闭 WebSocket
738
  if (this.ws) {
739
  this.ws.close();
740
  this.ws = null;
741
  }
742
 
743
+ // 停止摄像头
744
  if (this.stream) {
745
  this.stream.getTracks().forEach(track => track.stop());
746
  this.stream = null;
747
  }
748
 
749
+ // 清空视频
750
  if (this.localVideoElement) {
751
  this.localVideoElement.srcObject = null;
752
  }
753
 
754
+ // 清空 canvas
755
  if (this.displayCanvas) {
756
  const ctx = this.displayCanvas.getContext('2d');
757
  ctx.clearRect(0, 0, this.displayCanvas.width, this.displayCanvas.height);
758
  }
759
 
760
+ // 清理状态
761
  this.unfocusedStartTime = null;
762
  this.lastNotificationTime = null;
763
 
 
769
  this.frameRate = Math.max(10, Math.min(30, rate));
770
  console.log(`Frame rate set to ${this.frameRate} FPS`);
771
 
772
+ // 重启截图(如果正在运行)
773
  if (this.isStreaming && this.captureInterval) {
774
  clearInterval(this.captureInterval);
775
  this.startCapture();
776
  }
777
  }
778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
  getStats() {
780
  return {
781
  ...this.stats,
ui/pipeline.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import collections
4
  import glob
5
  import json
@@ -10,26 +8,23 @@ import sys
10
 
11
  import numpy as np
12
  import joblib
13
- import torch
14
- import torch.nn as nn
15
 
16
  _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
  if _PROJECT_ROOT not in sys.path:
18
  sys.path.insert(0, _PROJECT_ROOT)
19
 
20
- from data_preparation.prepare_dataset import SELECTED_FEATURES
21
  from models.face_mesh import FaceMeshDetector
22
  from models.head_pose import HeadPoseEstimator
23
  from models.eye_scorer import EyeBehaviourScorer, compute_mar, MAR_YAWN_THRESHOLD
 
 
24
  from models.collect_features import FEATURE_NAMES, TemporalTracker, extract_features
25
 
26
- # Same 10 features used for MLP training (prepare_dataset) and inference
27
- MLP_FEATURE_NAMES = SELECTED_FEATURES["face_orientation"]
28
-
29
  _FEAT_IDX = {name: i for i, name in enumerate(FEATURE_NAMES)}
30
 
31
 
32
  def _clip_features(vec):
 
33
  out = vec.copy()
34
  _i = _FEAT_IDX
35
 
@@ -82,21 +77,19 @@ class _OutputSmoother:
82
 
83
 
84
  DEFAULT_HYBRID_CONFIG = {
85
- "use_xgb": False,
86
- "w_mlp": 0.3,
87
- "w_xgb": 0.0,
88
- "w_geo": 0.7,
89
- "threshold": 0.35,
90
  "use_yawn_veto": True,
91
- "geo_face_weight": 0.7,
92
- "geo_eye_weight": 0.3,
93
  "mar_yawn_threshold": float(MAR_YAWN_THRESHOLD),
94
- "combiner": None,
95
- "combiner_path": None,
96
  }
97
 
98
 
99
  class _RuntimeFeatureEngine:
 
 
100
  _MAG_FEATURES = ["pitch", "yaw", "head_deviation", "gaze_offset", "v_gaze", "h_gaze"]
101
  _VEL_FEATURES = ["pitch", "yaw", "h_gaze", "v_gaze", "head_deviation", "gaze_offset"]
102
  _VAR_FEATURES = ["h_gaze", "v_gaze", "pitch"]
@@ -182,9 +175,12 @@ class FaceMeshPipeline:
182
  def __init__(
183
  self,
184
  max_angle: float = 22.0,
185
- alpha: float = 0.7,
186
- beta: float = 0.3,
187
  threshold: float = 0.55,
 
 
 
188
  detector=None,
189
  ):
190
  self.detector = detector or FaceMeshDetector()
@@ -194,6 +190,16 @@ class FaceMeshPipeline:
194
  self.alpha = alpha
195
  self.beta = beta
196
  self.threshold = threshold
 
 
 
 
 
 
 
 
 
 
197
  self._smoother = _OutputSmoother()
198
 
199
  def process_frame(self, bgr_frame: np.ndarray) -> dict:
@@ -225,7 +231,17 @@ class FaceMeshPipeline:
225
  if angles is not None:
226
  out["yaw"], out["pitch"], out["roll"] = angles
227
  out["s_face"] = self.head_pose.score(landmarks, w, h)
228
- out["s_eye"] = self.eye_scorer.score(landmarks)
 
 
 
 
 
 
 
 
 
 
229
  out["mar"] = compute_mar(landmarks)
230
  out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
231
 
@@ -237,6 +253,10 @@ class FaceMeshPipeline:
237
 
238
  return out
239
 
 
 
 
 
240
  def reset_session(self):
241
  self._smoother.reset()
242
 
@@ -251,45 +271,23 @@ class FaceMeshPipeline:
251
  self.close()
252
 
253
 
254
- # PyTorch MLP matching models/mlp/train.py BaseModel (10 -> 64 -> 32 -> 2)
255
- class _FocusMLP(nn.Module):
256
- def __init__(self, num_features: int, num_classes: int = 2):
257
- super().__init__()
258
- self.network = nn.Sequential(
259
- nn.Linear(num_features, 64),
260
- nn.ReLU(),
261
- nn.Linear(64, 32),
262
- nn.ReLU(),
263
- nn.Linear(32, num_classes),
264
- )
265
-
266
- def forward(self, x):
267
- return self.network(x)
268
-
269
-
270
- def _mlp_artifacts_available(model_dir: str) -> bool:
271
- pt_path = os.path.join(model_dir, "mlp_best.pt")
272
- scaler_path = os.path.join(model_dir, "scaler_mlp.joblib")
273
- return os.path.isfile(pt_path) and os.path.isfile(scaler_path)
274
-
275
-
276
- def _load_mlp_artifacts(model_dir: str):
277
- """Load PyTorch MLP + scaler from checkpoints. Returns (model, scaler, feature_names)."""
278
- pt_path = os.path.join(model_dir, "mlp_best.pt")
279
- scaler_path = os.path.join(model_dir, "scaler_mlp.joblib")
280
- if not os.path.isfile(pt_path):
281
- raise FileNotFoundError(f"No MLP checkpoint at {pt_path}")
282
- if not os.path.isfile(scaler_path):
283
- raise FileNotFoundError(f"No scaler at {scaler_path}")
284
-
285
- num_features = len(MLP_FEATURE_NAMES)
286
- num_classes = 2
287
- model = _FocusMLP(num_features, num_classes)
288
- model.load_state_dict(torch.load(pt_path, map_location="cpu", weights_only=True))
289
- model.eval()
290
-
291
- scaler = joblib.load(scaler_path)
292
- return model, scaler, list(MLP_FEATURE_NAMES)
293
 
294
 
295
  def _load_hybrid_config(model_dir: str, config_path: str | None = None):
@@ -306,41 +304,43 @@ def _load_hybrid_config(model_dir: str, config_path: str | None = None):
306
  if key in file_cfg:
307
  cfg[key] = file_cfg[key]
308
 
309
- cfg["use_xgb"] = bool(cfg.get("use_xgb", False))
310
- cfg["w_mlp"] = float(cfg.get("w_mlp", 0.3))
311
- cfg["w_xgb"] = float(cfg.get("w_xgb", 0.0))
312
  cfg["w_geo"] = float(cfg["w_geo"])
313
- if cfg["use_xgb"]:
314
- weight_sum = cfg["w_xgb"] + cfg["w_geo"]
315
- if weight_sum <= 0:
316
- raise ValueError("[HYBRID] Invalid config: w_xgb + w_geo must be > 0")
317
- cfg["w_xgb"] /= weight_sum
318
- cfg["w_geo"] /= weight_sum
319
- else:
320
- weight_sum = cfg["w_mlp"] + cfg["w_geo"]
321
- if weight_sum <= 0:
322
- raise ValueError("[HYBRID] Invalid config: w_mlp + w_geo must be > 0")
323
- cfg["w_mlp"] /= weight_sum
324
- cfg["w_geo"] /= weight_sum
325
  cfg["threshold"] = float(cfg["threshold"])
326
  cfg["use_yawn_veto"] = bool(cfg["use_yawn_veto"])
327
  cfg["geo_face_weight"] = float(cfg["geo_face_weight"])
328
  cfg["geo_eye_weight"] = float(cfg["geo_eye_weight"])
329
  cfg["mar_yawn_threshold"] = float(cfg["mar_yawn_threshold"])
330
- cfg["combiner"] = cfg.get("combiner") or None
331
- cfg["combiner_path"] = cfg.get("combiner_path") or None
332
 
333
  print(f"[HYBRID] Loaded config: {resolved}")
334
  return cfg, resolved
335
 
336
 
337
  class MLPPipeline:
338
- def __init__(self, model_dir=None, detector=None, threshold=0.23):
339
  if model_dir is None:
340
- model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
341
-
342
- self._mlp, self._scaler, self._feature_names = _load_mlp_artifacts(model_dir)
343
- self._indices = [FEATURE_NAMES.index(n) for n in self._feature_names]
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  self._detector = detector or FaceMeshDetector()
346
  self._owns_detector = detector is None
@@ -350,7 +350,7 @@ class MLPPipeline:
350
  self._temporal = TemporalTracker()
351
  self._smoother = _OutputSmoother()
352
  self._threshold = threshold
353
- print(f"[MLP] Loaded PyTorch MLP from {model_dir} | {len(self._feature_names)} features | threshold={threshold}")
354
 
355
  def process_frame(self, bgr_frame):
356
  landmarks = self._detector.process(bgr_frame)
@@ -382,13 +382,13 @@ class MLPPipeline:
382
  out["s_eye"] = float(vec[_FEAT_IDX["s_eye"]])
383
  out["mar"] = float(vec[_FEAT_IDX["mar"]])
384
 
385
- X = vec[self._indices].reshape(1, -1).astype(np.float32)
386
- X_sc = self._scaler.transform(X) if self._scaler is not None else X
387
- with torch.no_grad():
388
- x_t = torch.from_numpy(X_sc).float()
389
- logits = self._mlp(x_t)
390
- probs = torch.softmax(logits, dim=1)
391
- mlp_prob = float(probs[0, 1])
392
  out["mlp_prob"] = float(np.clip(mlp_prob, 0.0, 1.0))
393
  out["raw_score"] = self._smoother.update(out["mlp_prob"], True)
394
  out["is_focused"] = out["raw_score"] >= self._threshold
@@ -409,66 +409,62 @@ class MLPPipeline:
409
  self.close()
410
 
411
 
412
- def _resolve_xgb_path():
413
- return os.path.join(_PROJECT_ROOT, "checkpoints", "xgboost_face_orientation_best.json")
414
-
415
-
416
  class HybridFocusPipeline:
417
  def __init__(
418
  self,
419
  model_dir=None,
420
  config_path: str | None = None,
 
 
 
421
  max_angle: float = 22.0,
422
  detector=None,
423
  ):
424
  if model_dir is None:
425
  model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  self._cfg, self._cfg_path = _load_hybrid_config(model_dir=model_dir, config_path=config_path)
427
- self._use_xgb = self._cfg["use_xgb"]
428
 
429
  self._detector = detector or FaceMeshDetector()
430
  self._owns_detector = detector is None
431
  self._head_pose = HeadPoseEstimator(max_angle=max_angle)
432
  self._eye_scorer = EyeBehaviourScorer()
433
  self._temporal = TemporalTracker()
 
 
 
 
 
 
 
 
 
 
434
  self.head_pose = self._head_pose
435
  self._smoother = _OutputSmoother()
436
 
437
- self._combiner = None
438
- combiner_path = self._cfg.get("combiner_path")
439
- if combiner_path and self._cfg.get("combiner") == "logistic":
440
- resolved_combiner = combiner_path if os.path.isabs(combiner_path) else os.path.join(model_dir, combiner_path)
441
- if not os.path.isfile(resolved_combiner):
442
- resolved_combiner = os.path.join(_PROJECT_ROOT, combiner_path)
443
- if os.path.isfile(resolved_combiner):
444
- blob = joblib.load(resolved_combiner)
445
- self._combiner = blob.get("combiner")
446
- if self._combiner is None:
447
- self._combiner = blob
448
- print(f"[HYBRID] LR combiner loaded from {resolved_combiner}")
449
- else:
450
- print(f"[HYBRID] combiner_path not found: {resolved_combiner}, using heuristic weights")
451
- if self._use_xgb:
452
- from xgboost import XGBClassifier
453
- xgb_path = _resolve_xgb_path()
454
- if not os.path.isfile(xgb_path):
455
- raise FileNotFoundError(f"No XGBoost checkpoint at {xgb_path}")
456
- self._xgb_model = XGBClassifier()
457
- self._xgb_model.load_model(xgb_path)
458
- self._xgb_indices = [FEATURE_NAMES.index(n) for n in XGBoostPipeline.SELECTED]
459
- self._mlp = None
460
- self._scaler = None
461
- self._indices = None
462
- self._feature_names = list(XGBoostPipeline.SELECTED)
463
- mode = "LR combiner" if self._combiner else f"w_xgb={self._cfg['w_xgb']:.2f}, w_geo={self._cfg['w_geo']:.2f}"
464
- print(f"[HYBRID] XGBoost+geo | {xgb_path} | {mode}, threshold={self._cfg['threshold']:.2f}")
465
- else:
466
- self._mlp, self._scaler, self._feature_names = _load_mlp_artifacts(model_dir)
467
- self._indices = [FEATURE_NAMES.index(n) for n in self._feature_names]
468
- self._xgb_model = None
469
- self._xgb_indices = None
470
- mode = "LR combiner" if self._combiner else f"w_mlp={self._cfg['w_mlp']:.2f}, w_geo={self._cfg['w_geo']:.2f}"
471
- print(f"[HYBRID] MLP+geo | {len(self._feature_names)} features | {mode}, threshold={self._cfg['threshold']:.2f}")
472
 
473
  @property
474
  def config(self) -> dict:
@@ -506,8 +502,15 @@ class HybridFocusPipeline:
506
  out["yaw"], out["pitch"], out["roll"] = angles
507
 
508
  out["s_face"] = self._head_pose.score(landmarks, w, h)
509
- out["s_eye"] = self._eye_scorer.score(landmarks)
510
- s_eye_geo = out["s_eye"]
 
 
 
 
 
 
 
511
 
512
  geo_score = (
513
  self._cfg["geo_face_weight"] * out["s_face"] +
@@ -529,32 +532,16 @@ class HybridFocusPipeline:
529
  }
530
  vec = extract_features(landmarks, w, h, self._head_pose, self._eye_scorer, self._temporal, _pre=pre)
531
  vec = _clip_features(vec)
532
-
533
- if self._use_xgb:
534
- X = vec[self._xgb_indices].reshape(1, -1).astype(np.float32)
535
- prob = self._xgb_model.predict_proba(X)[0]
536
- model_prob = float(np.clip(prob[1], 0.0, 1.0))
537
- out["mlp_prob"] = model_prob
538
- if self._combiner is not None:
539
- meta = np.array([[model_prob, out["geo_score"]]], dtype=np.float32)
540
- focus_score = float(self._combiner.predict_proba(meta)[0, 1])
541
- else:
542
- focus_score = self._cfg["w_xgb"] * model_prob + self._cfg["w_geo"] * out["geo_score"]
543
  else:
544
- X = vec[self._indices].reshape(1, -1).astype(np.float32)
545
- X_sc = self._scaler.transform(X) if self._scaler is not None else X
546
- with torch.no_grad():
547
- x_t = torch.from_numpy(X_sc).float()
548
- logits = self._mlp(x_t)
549
- probs = torch.softmax(logits, dim=1)
550
- mlp_prob = float(probs[0, 1])
551
- out["mlp_prob"] = float(np.clip(mlp_prob, 0.0, 1.0))
552
- if self._combiner is not None:
553
- meta = np.array([[out["mlp_prob"], out["geo_score"]]], dtype=np.float32)
554
- focus_score = float(self._combiner.predict_proba(meta)[0, 1])
555
- else:
556
- focus_score = self._cfg["w_mlp"] * out["mlp_prob"] + self._cfg["w_geo"] * out["geo_score"]
557
 
 
558
  out["focus_score"] = self._smoother.update(float(np.clip(focus_score, 0.0, 1.0)), True)
559
  out["raw_score"] = out["focus_score"]
560
  out["is_focused"] = out["focus_score"] >= self._cfg["threshold"]
@@ -576,16 +563,22 @@ class HybridFocusPipeline:
576
 
577
 
578
  class XGBoostPipeline:
 
 
 
579
  SELECTED = [
580
  'head_deviation', 's_face', 's_eye', 'h_gaze', 'pitch',
581
  'ear_left', 'ear_avg', 'ear_right', 'gaze_offset', 'perclos',
582
  ]
583
 
584
- def __init__(self, model_path=None, threshold=0.38):
585
  from xgboost import XGBClassifier
586
 
587
  if model_path is None:
588
- model_path = os.path.join(_PROJECT_ROOT, "checkpoints", "xgboost_face_orientation_best.json")
 
 
 
589
  if not os.path.isfile(model_path):
590
  raise FileNotFoundError(f"No XGBoost checkpoint at {model_path}")
591
 
 
 
 
1
  import collections
2
  import glob
3
  import json
 
8
 
9
  import numpy as np
10
  import joblib
 
 
11
 
12
  _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
  if _PROJECT_ROOT not in sys.path:
14
  sys.path.insert(0, _PROJECT_ROOT)
15
 
 
16
  from models.face_mesh import FaceMeshDetector
17
  from models.head_pose import HeadPoseEstimator
18
  from models.eye_scorer import EyeBehaviourScorer, compute_mar, MAR_YAWN_THRESHOLD
19
+ from models.eye_crop import extract_eye_crops
20
+ from models.eye_classifier import load_eye_classifier, GeometricOnlyClassifier
21
  from models.collect_features import FEATURE_NAMES, TemporalTracker, extract_features
22
 
 
 
 
23
  _FEAT_IDX = {name: i for i, name in enumerate(FEATURE_NAMES)}
24
 
25
 
26
  def _clip_features(vec):
27
+ """Clip raw features to the same ranges used during training."""
28
  out = vec.copy()
29
  _i = _FEAT_IDX
30
 
 
77
 
78
 
79
  DEFAULT_HYBRID_CONFIG = {
80
+ "w_mlp": 0.7,
81
+ "w_geo": 0.3,
82
+ "threshold": 0.55,
 
 
83
  "use_yawn_veto": True,
84
+ "geo_face_weight": 0.4,
85
+ "geo_eye_weight": 0.6,
86
  "mar_yawn_threshold": float(MAR_YAWN_THRESHOLD),
 
 
87
  }
88
 
89
 
90
  class _RuntimeFeatureEngine:
91
+ """Runtime feature engineering (magnitudes, velocities, variances) with EMA baselines."""
92
+
93
  _MAG_FEATURES = ["pitch", "yaw", "head_deviation", "gaze_offset", "v_gaze", "h_gaze"]
94
  _VEL_FEATURES = ["pitch", "yaw", "h_gaze", "v_gaze", "head_deviation", "gaze_offset"]
95
  _VAR_FEATURES = ["h_gaze", "v_gaze", "pitch"]
 
175
  def __init__(
176
  self,
177
  max_angle: float = 22.0,
178
+ alpha: float = 0.4,
179
+ beta: float = 0.6,
180
  threshold: float = 0.55,
181
+ eye_model_path: str | None = None,
182
+ eye_backend: str = "yolo",
183
+ eye_blend: float = 0.5,
184
  detector=None,
185
  ):
186
  self.detector = detector or FaceMeshDetector()
 
190
  self.alpha = alpha
191
  self.beta = beta
192
  self.threshold = threshold
193
+ self.eye_blend = eye_blend
194
+
195
+ self.eye_classifier = load_eye_classifier(
196
+ path=eye_model_path if eye_model_path and os.path.exists(eye_model_path) else None,
197
+ backend=eye_backend,
198
+ device="cpu",
199
+ )
200
+ self._has_eye_model = not isinstance(self.eye_classifier, GeometricOnlyClassifier)
201
+ if self._has_eye_model:
202
+ print(f"[PIPELINE] Eye model: {self.eye_classifier.name}")
203
  self._smoother = _OutputSmoother()
204
 
205
  def process_frame(self, bgr_frame: np.ndarray) -> dict:
 
231
  if angles is not None:
232
  out["yaw"], out["pitch"], out["roll"] = angles
233
  out["s_face"] = self.head_pose.score(landmarks, w, h)
234
+
235
+ s_eye_geo = self.eye_scorer.score(landmarks)
236
+ if self._has_eye_model:
237
+ left_crop, right_crop, left_bbox, right_bbox = extract_eye_crops(bgr_frame, landmarks)
238
+ out["left_bbox"] = left_bbox
239
+ out["right_bbox"] = right_bbox
240
+ s_eye_model = self.eye_classifier.predict_score([left_crop, right_crop])
241
+ out["s_eye"] = (1.0 - self.eye_blend) * s_eye_geo + self.eye_blend * s_eye_model
242
+ else:
243
+ out["s_eye"] = s_eye_geo
244
+
245
  out["mar"] = compute_mar(landmarks)
246
  out["is_yawning"] = out["mar"] > MAR_YAWN_THRESHOLD
247
 
 
253
 
254
  return out
255
 
256
+ @property
257
+ def has_eye_model(self) -> bool:
258
+ return self._has_eye_model
259
+
260
  def reset_session(self):
261
  self._smoother.reset()
262
 
 
271
  self.close()
272
 
273
 
274
+ def _latest_model_artifacts(model_dir):
275
+ model_files = sorted(glob.glob(os.path.join(model_dir, "model_*.joblib")))
276
+ if not model_files:
277
+ model_files = sorted(glob.glob(os.path.join(model_dir, "mlp_*.joblib")))
278
+ if not model_files:
279
+ return None, None, None
280
+ basename = os.path.basename(model_files[-1])
281
+ tag = ""
282
+ for prefix in ("model_", "mlp_"):
283
+ if basename.startswith(prefix):
284
+ tag = basename[len(prefix) :].replace(".joblib", "")
285
+ break
286
+ scaler_path = os.path.join(model_dir, f"scaler_{tag}.joblib")
287
+ meta_path = os.path.join(model_dir, f"meta_{tag}.npz")
288
+ if not os.path.isfile(scaler_path) or not os.path.isfile(meta_path):
289
+ return None, None, None
290
+ return model_files[-1], scaler_path, meta_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
 
293
  def _load_hybrid_config(model_dir: str, config_path: str | None = None):
 
304
  if key in file_cfg:
305
  cfg[key] = file_cfg[key]
306
 
307
+ cfg["w_mlp"] = float(cfg["w_mlp"])
 
 
308
  cfg["w_geo"] = float(cfg["w_geo"])
309
+ weight_sum = cfg["w_mlp"] + cfg["w_geo"]
310
+ if weight_sum <= 0:
311
+ raise ValueError("[HYBRID] Invalid config: w_mlp + w_geo must be > 0")
312
+ cfg["w_mlp"] /= weight_sum
313
+ cfg["w_geo"] /= weight_sum
 
 
 
 
 
 
 
314
  cfg["threshold"] = float(cfg["threshold"])
315
  cfg["use_yawn_veto"] = bool(cfg["use_yawn_veto"])
316
  cfg["geo_face_weight"] = float(cfg["geo_face_weight"])
317
  cfg["geo_eye_weight"] = float(cfg["geo_eye_weight"])
318
  cfg["mar_yawn_threshold"] = float(cfg["mar_yawn_threshold"])
 
 
319
 
320
  print(f"[HYBRID] Loaded config: {resolved}")
321
  return cfg, resolved
322
 
323
 
324
  class MLPPipeline:
325
+ def __init__(self, model_dir=None, detector=None, threshold=0.5):
326
  if model_dir is None:
327
+ # Check primary location
328
+ model_dir = os.path.join(_PROJECT_ROOT, "MLP", "models")
329
+ if not os.path.exists(model_dir):
330
+ model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
331
+
332
+ mlp_path, scaler_path, meta_path = _latest_model_artifacts(model_dir)
333
+ if mlp_path is None:
334
+ raise FileNotFoundError(f"No MLP artifacts in {model_dir}")
335
+ self._mlp = joblib.load(mlp_path)
336
+ self._scaler = joblib.load(scaler_path)
337
+ meta = np.load(meta_path, allow_pickle=True)
338
+ self._feature_names = list(meta["feature_names"])
339
+
340
+ norm_feats = list(meta["norm_features"]) if "norm_features" in meta else []
341
+ self._engine = _RuntimeFeatureEngine(FEATURE_NAMES, norm_features=norm_feats)
342
+ ext_names = self._engine.extended_names
343
+ self._indices = [ext_names.index(n) for n in self._feature_names]
344
 
345
  self._detector = detector or FaceMeshDetector()
346
  self._owns_detector = detector is None
 
350
  self._temporal = TemporalTracker()
351
  self._smoother = _OutputSmoother()
352
  self._threshold = threshold
353
+ print(f"[MLP] Loaded {mlp_path} | {len(self._feature_names)} features | threshold={threshold}")
354
 
355
  def process_frame(self, bgr_frame):
356
  landmarks = self._detector.process(bgr_frame)
 
382
  out["s_eye"] = float(vec[_FEAT_IDX["s_eye"]])
383
  out["mar"] = float(vec[_FEAT_IDX["mar"]])
384
 
385
+ ext_vec = self._engine.transform(vec)
386
+ X = ext_vec[self._indices].reshape(1, -1).astype(np.float64)
387
+ X_sc = self._scaler.transform(X)
388
+ if hasattr(self._mlp, "predict_proba"):
389
+ mlp_prob = float(self._mlp.predict_proba(X_sc)[0, 1])
390
+ else:
391
+ mlp_prob = float(self._mlp.predict(X_sc)[0] == 1)
392
  out["mlp_prob"] = float(np.clip(mlp_prob, 0.0, 1.0))
393
  out["raw_score"] = self._smoother.update(out["mlp_prob"], True)
394
  out["is_focused"] = out["raw_score"] >= self._threshold
 
409
  self.close()
410
 
411
 
 
 
 
 
412
  class HybridFocusPipeline:
413
  def __init__(
414
  self,
415
  model_dir=None,
416
  config_path: str | None = None,
417
+ eye_model_path: str | None = None,
418
+ eye_backend: str = "yolo",
419
+ eye_blend: float = 0.5,
420
  max_angle: float = 22.0,
421
  detector=None,
422
  ):
423
  if model_dir is None:
424
  model_dir = os.path.join(_PROJECT_ROOT, "checkpoints")
425
+ mlp_path, scaler_path, meta_path = _latest_model_artifacts(model_dir)
426
+ if mlp_path is None:
427
+ raise FileNotFoundError(f"No MLP artifacts in {model_dir}")
428
+
429
+ self._mlp = joblib.load(mlp_path)
430
+ self._scaler = joblib.load(scaler_path)
431
+ meta = np.load(meta_path, allow_pickle=True)
432
+ self._feature_names = list(meta["feature_names"])
433
+
434
+ norm_feats = list(meta["norm_features"]) if "norm_features" in meta else []
435
+ self._engine = _RuntimeFeatureEngine(FEATURE_NAMES, norm_features=norm_feats)
436
+ ext_names = self._engine.extended_names
437
+ self._indices = [ext_names.index(n) for n in self._feature_names]
438
+
439
  self._cfg, self._cfg_path = _load_hybrid_config(model_dir=model_dir, config_path=config_path)
 
440
 
441
  self._detector = detector or FaceMeshDetector()
442
  self._owns_detector = detector is None
443
  self._head_pose = HeadPoseEstimator(max_angle=max_angle)
444
  self._eye_scorer = EyeBehaviourScorer()
445
  self._temporal = TemporalTracker()
446
+ self._eye_blend = eye_blend
447
+ self.eye_classifier = load_eye_classifier(
448
+ path=eye_model_path if eye_model_path and os.path.exists(eye_model_path) else None,
449
+ backend=eye_backend,
450
+ device="cpu",
451
+ )
452
+ self._has_eye_model = not isinstance(self.eye_classifier, GeometricOnlyClassifier)
453
+ if self._has_eye_model:
454
+ print(f"[HYBRID] Eye model: {self.eye_classifier.name}")
455
+
456
  self.head_pose = self._head_pose
457
  self._smoother = _OutputSmoother()
458
 
459
+ print(
460
+ f"[HYBRID] Loaded {mlp_path} | {len(self._feature_names)} features | "
461
+ f"w_mlp={self._cfg['w_mlp']:.2f}, w_geo={self._cfg['w_geo']:.2f}, "
462
+ f"threshold={self._cfg['threshold']:.2f}"
463
+ )
464
+
465
+ @property
466
+ def has_eye_model(self) -> bool:
467
+ return self._has_eye_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
469
  @property
470
  def config(self) -> dict:
 
502
  out["yaw"], out["pitch"], out["roll"] = angles
503
 
504
  out["s_face"] = self._head_pose.score(landmarks, w, h)
505
+ s_eye_geo = self._eye_scorer.score(landmarks)
506
+ if self._has_eye_model:
507
+ left_crop, right_crop, left_bbox, right_bbox = extract_eye_crops(bgr_frame, landmarks)
508
+ out["left_bbox"] = left_bbox
509
+ out["right_bbox"] = right_bbox
510
+ s_eye_model = self.eye_classifier.predict_score([left_crop, right_crop])
511
+ out["s_eye"] = (1.0 - self._eye_blend) * s_eye_geo + self._eye_blend * s_eye_model
512
+ else:
513
+ out["s_eye"] = s_eye_geo
514
 
515
  geo_score = (
516
  self._cfg["geo_face_weight"] * out["s_face"] +
 
532
  }
533
  vec = extract_features(landmarks, w, h, self._head_pose, self._eye_scorer, self._temporal, _pre=pre)
534
  vec = _clip_features(vec)
535
+ ext_vec = self._engine.transform(vec)
536
+ X = ext_vec[self._indices].reshape(1, -1).astype(np.float64)
537
+ X_sc = self._scaler.transform(X)
538
+ if hasattr(self._mlp, "predict_proba"):
539
+ mlp_prob = float(self._mlp.predict_proba(X_sc)[0, 1])
 
 
 
 
 
 
540
  else:
541
+ mlp_prob = float(self._mlp.predict(X_sc)[0] == 1)
542
+ out["mlp_prob"] = float(np.clip(mlp_prob, 0.0, 1.0))
 
 
 
 
 
 
 
 
 
 
 
543
 
544
+ focus_score = self._cfg["w_mlp"] * out["mlp_prob"] + self._cfg["w_geo"] * out["geo_score"]
545
  out["focus_score"] = self._smoother.update(float(np.clip(focus_score, 0.0, 1.0)), True)
546
  out["raw_score"] = out["focus_score"]
547
  out["is_focused"] = out["focus_score"] >= self._cfg["threshold"]
 
563
 
564
 
565
  class XGBoostPipeline:
566
+ """Real-time XGBoost inference pipeline using the same feature extraction as MLPPipeline."""
567
+
568
+ # Same 10 features used during training (data_preparation.prepare_dataset.SELECTED_FEATURES)
569
  SELECTED = [
570
  'head_deviation', 's_face', 's_eye', 'h_gaze', 'pitch',
571
  'ear_left', 'ear_avg', 'ear_right', 'gaze_offset', 'perclos',
572
  ]
573
 
574
+ def __init__(self, model_path=None, threshold=0.5):
575
  from xgboost import XGBClassifier
576
 
577
  if model_path is None:
578
+ model_path = os.path.join(_PROJECT_ROOT, "models", "xgboost", "checkpoints", "face_orientation_best.json")
579
+ if not os.path.isfile(model_path):
580
+ # Fallback to legacy path
581
+ model_path = os.path.join(_PROJECT_ROOT, "checkpoints", "xgboost_face_orientation_best.json")
582
  if not os.path.isfile(model_path):
583
  raise FileNotFoundError(f"No XGBoost checkpoint at {model_path}")
584