Spaces:
Running
Running
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Any | |
| import base64 | |
| import cv2 | |
| import numpy as np | |
| import aiosqlite | |
| import json | |
| from datetime import datetime, timedelta | |
| import math | |
| import os | |
| from pathlib import Path | |
| from typing import Callable | |
| import asyncio | |
| import concurrent.futures | |
| import threading | |
| from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack | |
| from av import VideoFrame | |
| from mediapipe.tasks.python.vision import FaceLandmarksConnections | |
| from ui.pipeline import ( | |
| FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline, | |
| L2CSPipeline, is_l2cs_weights_available, | |
| ) | |
| from models.face_mesh import FaceMeshDetector | |
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================
# All colours are BGR, the channel order OpenCV drawing functions expect.
_FONT = cv2.FONT_HERSHEY_SIMPLEX
_CYAN = (255, 255, 0)
_GREEN = (0, 255, 0)
_MAGENTA = (255, 0, 255)
_ORANGE = (0, 165, 255)
_RED = (0, 0, 255)
_WHITE = (255, 255, 255)
_LIGHT_GREEN = (144, 238, 144)

# (start, end) landmark index pairs taken from MediaPipe's face-landmark connection sets.
_TESSELATION_CONNS = [
    (conn.start, conn.end) for conn in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION
]
_CONTOUR_CONNS = [
    (conn.start, conn.end) for conn in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS
]

# Landmark index paths drawn as open polylines. The two lip paths end on their
# first index, so the drawn outline closes on itself.
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]

# Six per-eye landmarks drawn as dots ("EAR key points").
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
| def _lm_px(lm, idx, w, h): | |
| return (int(lm[idx, 0] * w), int(lm[idx, 1] * h)) | |
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
    """Draw an open, anti-aliased polyline through the landmarks listed in ``indices``.

    Each adjacent pair of indices becomes one ``cv2.line`` segment on ``frame``
    (in place). The path only closes if ``indices`` repeats its first entry at
    the end.
    """
    for start_idx, end_idx in zip(indices, indices[1:]):
        start_pt = _lm_px(lm, start_idx, w, h)
        end_pt = _lm_px(lm, end_idx, w, h)
        cv2.line(frame, start_pt, end_pt, color, thickness, cv2.LINE_AA)
def _draw_face_mesh(frame, lm, w, h):
    """Render the face-mesh overlay onto ``frame`` in place.

    Layers, in drawing order: a grey tessellation alpha-blended at 30%, cyan
    contours, eyebrows, nose bridge, outer and inner lips, closed eye outlines,
    the EAR key-point dots and, for each eye, the iris (circle plus centre dot)
    with a red line from the iris centre extending three times the offset
    between the eye-corner midpoint and the iris centre.

    Args:
        frame: BGR image, modified in place.
        lm: landmark array read as ``lm[i, 0]`` (x) and ``lm[i, 1]`` (y),
            scaled to pixels by ``w`` and ``h``.
        w, h: frame width and height in pixels.
    """
    def px(i):
        return _lm_px(lm, i, w, h)

    # Tessellation goes on a copy that is blended back, so the grid stays faint.
    mesh_layer = frame.copy()
    for a, b in _TESSELATION_CONNS:
        cv2.line(mesh_layer, px(a), px(b), (200, 200, 200), 1, cv2.LINE_AA)
    cv2.addWeighted(mesh_layer, 0.3, frame, 0.7, 0, frame)

    for a, b in _CONTOUR_CONNS:
        cv2.line(frame, px(a), px(b), _CYAN, 1, cv2.LINE_AA)

    # (landmark path, colour, thickness) for each open polyline, in draw order.
    open_paths = (
        (_LEFT_EYEBROW, _LIGHT_GREEN, 2),
        (_RIGHT_EYEBROW, _LIGHT_GREEN, 2),
        (_NOSE_BRIDGE, _ORANGE, 1),
        (_LIPS_OUTER, _MAGENTA, 1),
        (_LIPS_INNER, (200, 0, 200), 1),
    )
    for path, colour, thickness in open_paths:
        _draw_polyline(frame, lm, path, w, h, colour, thickness)

    for eye_indices in (FaceMeshDetector.LEFT_EYE_INDICES, FaceMeshDetector.RIGHT_EYE_INDICES):
        outline = np.array([px(i) for i in eye_indices], dtype=np.int32)
        cv2.polylines(frame, [outline], True, _GREEN, 2, cv2.LINE_AA)

    for point_set in (_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS):
        for idx in point_set:
            cv2.circle(frame, px(idx), 3, (0, 255, 255), -1, cv2.LINE_AA)

    # (iris indices, eye corner A, eye corner B) per eye.
    eyes = (
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    )
    for iris_indices, corner_a, corner_b in eyes:
        iris = np.array([px(i) for i in iris_indices], dtype=np.int32)
        centre = iris[0]
        if len(iris) >= 5:
            # Radius = mean distance from the centre to the next four iris points, at least 2 px.
            spread = np.mean([np.linalg.norm(iris[k] - centre) for k in range(1, 5)])
            cv2.circle(frame, tuple(centre), max(int(spread), 2), _MAGENTA, 2, cv2.LINE_AA)
        cv2.circle(frame, tuple(centre), 2, _WHITE, -1, cv2.LINE_AA)
        mid_x = int((lm[corner_a, 0] + lm[corner_b, 0]) / 2.0 * w)
        mid_y = int((lm[corner_a, 1] + lm[corner_b, 1]) / 2.0 * h)
        off_x, off_y = centre[0] - mid_x, centre[1] - mid_y
        tip = (int(centre[0] + off_x * 3), int(centre[1] + off_y * 3))
        cv2.line(frame, tuple(centre), tip, _RED, 1, cv2.LINE_AA)
def _draw_hud(frame, result, model_name):
    """Overlay the status bar and per-frame metrics (layout matches live_demo.py).

    Draws in place on ``frame``:
    * a 55 px black top bar with FOCUSED (green) / NOT FOCUSED (red) and the
      upper-cased ``model_name``;
    * a detail line with confidence (``mlp_prob``, else ``raw_score``),
      S_face, S_eye and, if present, MAR;
    * yaw/pitch/roll when ``result["yaw"]`` is not None;
    * a YAWN tag when ``result["is_yawning"]`` is truthy.
    """
    width = frame.shape[1]
    if result["is_focused"]:
        status, status_colour = "FOCUSED", _GREEN
    else:
        status, status_colour = "NOT FOCUSED", _RED

    cv2.rectangle(frame, (0, 0), (width, 55), (0, 0, 0), -1)
    cv2.putText(frame, status, (10, 28), _FONT, 0.8, status_colour, 2, cv2.LINE_AA)
    cv2.putText(frame, model_name.upper(), (width - 150, 28), _FONT, 0.45, _WHITE, 1, cv2.LINE_AA)

    conf = result.get("mlp_prob", result.get("raw_score", 0.0))
    mar = result.get("mar")
    mar_s = "" if mar is None else f" MAR:{mar:.2f}"
    s_face = result.get("s_face", 0)
    s_eye = result.get("s_eye", 0)
    detail = f"conf:{conf:.2f} S_face:{s_face:.2f} S_eye:{s_eye:.2f}{mar_s}"
    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, cv2.LINE_AA)

    if result.get("yaw") is not None:
        pose = f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}"
        cv2.putText(frame, pose, (width - 280, 48), _FONT, 0.4, (180, 180, 180), 1, cv2.LINE_AA)

    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, cv2.LINE_AA)
# Landmark indices the client needs to draw the face mesh (union of all drawn groups).
# That is roughly 150 of the 478 landmarks, so a payload limited to these indices is
# far smaller than one carrying every landmark.
# Groups that are also drawn server-side reuse the same constants, so the two lists
# cannot drift apart.
# NOTE(review): the websocket handler in this file currently serialises every
# landmark; nothing visible here uses this subset yet.
_MESH_INDICES = sorted(set(
    [10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288, 397, 365, 379, 378, 400, 377,
     152, 148, 176, 149, 150, 136, 172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109]  # face oval
    + [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246]  # left eye
    + [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398]  # right eye
    + [468, 469, 470, 471, 472, 473, 474, 475, 476, 477]  # irises
    + _LEFT_EYEBROW
    + _RIGHT_EYEBROW
    + _NOSE_BRIDGE
    + _LIPS_OUTER  # the closing duplicate index is removed by set()
    + _LIPS_INNER
    + _LEFT_EAR_POINTS
    + _RIGHT_EAR_POINTS
))
# Set view of _MESH_INDICES for O(1) membership tests (original landmark indices).
_MESH_INDEX_SET = set(_MESH_INDICES)
# Initialize FastAPI app
app = FastAPI(title="Focus Guard API")

# CORS: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any site
# make credentialed cross-origin requests (Starlette reflects the request Origin in
# some cases to satisfy browsers). Nothing visible here relies on cookies or auth;
# consider restricting origins or disabling credentials before public deployment.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
| # Global variables | |
| db_path = "focus_guard.db" | |
| pcs = set() | |
| _cached_model_name = "mlp" # in-memory cache, updated via /api/settings | |
| _l2cs_boost_enabled = False # when True, L2CS runs alongside the base model | |
| async def _wait_for_ice_gathering(pc: RTCPeerConnection): | |
| if pc.iceGatheringState == "complete": | |
| return | |
| done = asyncio.Event() | |
| def _on_state_change(): | |
| if pc.iceGatheringState == "complete": | |
| done.set() | |
| await done.wait() | |
# ================ DATABASE MODELS ================
async def init_database():
    """Create the SQLite schema if needed and seed the default settings row.

    Creates ``focus_sessions``, ``focus_events`` (rows reference
    ``focus_sessions.id``) and ``user_settings`` (a single row; ``id = 1`` is
    enforced by a CHECK constraint). It then inserts that settings row with
    defaults unless it already exists. Every statement is idempotent
    (``IF NOT EXISTS`` / ``INSERT OR IGNORE``), so this is safe to run on every startup.
    """
    schema_and_seed = (
        # FocusSessions table
        """
        CREATE TABLE IF NOT EXISTS focus_sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            start_time TIMESTAMP NOT NULL,
            end_time TIMESTAMP,
            duration_seconds INTEGER DEFAULT 0,
            focus_score REAL DEFAULT 0.0,
            total_frames INTEGER DEFAULT 0,
            focused_frames INTEGER DEFAULT 0,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # FocusEvents table
        """
        CREATE TABLE IF NOT EXISTS focus_events (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            timestamp TIMESTAMP NOT NULL,
            is_focused BOOLEAN NOT NULL,
            confidence REAL NOT NULL,
            detection_data TEXT,
            FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
        )
        """,
        # UserSettings table
        """
        CREATE TABLE IF NOT EXISTS user_settings (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            sensitivity INTEGER DEFAULT 6,
            notification_enabled BOOLEAN DEFAULT 1,
            notification_threshold INTEGER DEFAULT 30,
            frame_rate INTEGER DEFAULT 30,
            model_name TEXT DEFAULT 'mlp'
        )
        """,
        # Insert default settings if not exists
        """
        INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
        VALUES (1, 6, 1, 30, 30, 'mlp')
        """,
    )
    async with aiosqlite.connect(db_path) as db:
        for statement in schema_and_seed:
            await db.execute(statement)
        await db.commit()
# ================ PYDANTIC MODELS ================
class SessionCreate(BaseModel):
    """Request body for creating a session; it carries no fields."""


class SessionEnd(BaseModel):
    """Request body naming the session to end."""

    session_id: int


class SettingsUpdate(BaseModel):
    """Partial user-settings update; every field is optional and defaults to ``None``."""

    sensitivity: Optional[int] = None
    notification_enabled: Optional[bool] = None
    notification_threshold: Optional[int] = None
    frame_rate: Optional[int] = None
    model_name: Optional[str] = None
    l2cs_boost: Optional[bool] = None
class VideoTransformTrack(VideoStreamTrack):
    """Outgoing video track that annotates frames from an incoming WebRTC track.

    Every received frame is resized to 640x480. At most once per
    ``min_inference_interval`` seconds the frame is processed as follows:
    * it is run through the pipeline named by ``_cached_model_name``
      (falling back to ``"mlp"``) in the inference thread pool;
    * the face mesh and HUD are drawn onto it;
    * it is stored as a focus event when ``session_id`` is set;
    * a ``"detection"`` JSON message is sent over the data channel returned
      by ``get_channel`` when that channel is open.

    Frames in between are replaced by the last annotated image.
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        super().__init__()
        self.track = track
        self.session_id = session_id
        self.get_channel = get_channel
        # Event-loop clock time (loop.time()) of the last inference.
        self.last_inference_time = 0
        self.min_inference_interval = 1 / 60
        # Most recent annotated image, re-sent between inferences.
        self.last_frame = None

    async def recv(self):
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame
        # Normalize size for inference/drawing
        img = cv2.resize(img, (640, 480))

        loop = asyncio.get_running_loop()
        # loop.time() is monotonic, so throttling is unaffected by wall-clock jumps.
        now = loop.time()
        if (now - self.last_inference_time) >= self.min_inference_interval:
            self.last_inference_time = now
            img = await self._infer_and_annotate(img, loop)
            self.last_frame = img
        elif self.last_frame is not None:
            img = self.last_frame

        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame

    async def _infer_and_annotate(self, img, loop):
        """Run inference on ``img``, draw overlays on it in place, then record and report the result."""
        model_name = _cached_model_name
        if model_name == "l2cs" and pipelines.get("l2cs") is None:
            # Loading L2CS is slow; do it in the pool instead of stalling the event loop.
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if pipelines.get(model_name) is None:
            model_name = "mlp"
        active_pipeline = pipelines.get(model_name)

        if active_pipeline is not None:
            out = await loop.run_in_executor(
                _inference_executor,
                _process_frame_safe,
                active_pipeline,
                img,
                model_name,
            )
            is_focused = out["is_focused"]
            confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
            metadata = {
                "s_face": out.get("s_face", 0.0),
                "s_eye": out.get("s_eye", 0.0),
                "mar": out.get("mar", 0.0),
                "model": model_name,
            }
            # Draw face mesh + HUD on the video frame
            h_f, w_f = img.shape[:2]
            lm = out.get("landmarks")
            if lm is not None:
                _draw_face_mesh(img, lm, w_f, h_f)
            _draw_hud(img, out, model_name)
        else:
            is_focused = False
            confidence = 0.0
            metadata = {"model": model_name}
            cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
            cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)

        if self.session_id:
            await store_focus_event(self.session_id, is_focused, confidence, metadata)

        channel = self.get_channel()
        if channel and channel.readyState == "open":
            # The previous payload referenced an undefined name ("detections"), so every
            # send raised NameError and was silently dropped.
            message = {
                "type": "detection",
                # Coerce so json.dumps also accepts NumPy scalar types.
                "focused": bool(is_focused),
                "confidence": round(float(confidence), 3),
                "model": model_name,
            }
            try:
                channel.send(json.dumps(message))
            except Exception as e:
                print(f"[WebRTC] data channel send failed: {e}")
        return img
# ================ DATABASE OPERATIONS ================
async def create_session():
    """Insert a new focus session that starts now and return its row id."""
    insert_sql = "INSERT INTO focus_sessions (start_time) VALUES (?)"
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(insert_sql, (datetime.now().isoformat(),))
        await db.commit()
        new_id = cursor.lastrowid
        return new_id
async def end_session(session_id: int):
    """Close a session by stamping its end time, duration and focus score.

    The focus score is ``focused_frames / total_frames`` (0.0 when no frames
    were recorded); the duration is measured from the stored ``start_time`` to now.

    Returns:
        A summary dict (session id, ISO start/end times, whole-second duration,
        score rounded to 3 places, frame counts), or ``None`` if no session
        with ``session_id`` exists.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,),
        )
        row = await cursor.fetchone()
        if row is None:
            return None

        started_iso, total, focused = row
        finished = datetime.now()
        elapsed_s = int((finished - datetime.fromisoformat(started_iso)).total_seconds())
        score = 0.0 if not total > 0 else focused / total
        finished_iso = finished.isoformat()

        await db.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (finished_iso, elapsed_s, score, session_id))
        await db.commit()

        return {
            'session_id': session_id,
            'start_time': started_iso,
            'end_time': finished_iso,
            'duration_seconds': elapsed_s,
            'focus_score': round(score, 3),
            'total_frames': total,
            'focused_frames': focused,
        }
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Write one focus event immediately and bump the session's frame counters.

    Inserts a ``focus_events`` row (timestamped now, ``metadata`` stored as
    JSON). It then increments ``total_frames`` by one, and ``focused_frames``
    by one when ``is_focused`` is truthy. Both writes go through a fresh
    connection and a single commit.
    """
    insert_sql = """
        INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
        VALUES (?, ?, ?, ?, ?)
    """
    update_sql = """
        UPDATE focus_sessions
        SET total_frames = total_frames + 1,
            focused_frames = focused_frames + ?
        WHERE id = ?
    """
    async with aiosqlite.connect(db_path) as db:
        event_row = (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata))
        await db.execute(insert_sql, event_row)
        focused_increment = int(bool(is_focused))
        await db.execute(update_sql, (focused_increment, session_id))
        await db.commit()
| class _EventBuffer: | |
| """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes.""" | |
| def __init__(self, flush_interval: float = 2.0): | |
| self._buf: list = [] | |
| self._lock = asyncio.Lock() | |
| self._flush_interval = flush_interval | |
| self._task: asyncio.Task | None = None | |
| self._total_frames = 0 | |
| self._focused_frames = 0 | |
| def start(self): | |
| if self._task is None: | |
| self._task = asyncio.create_task(self._flush_loop()) | |
| async def stop(self): | |
| if self._task: | |
| self._task.cancel() | |
| try: | |
| await self._task | |
| except asyncio.CancelledError: | |
| pass | |
| self._task = None | |
| await self._flush() | |
| def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict): | |
| self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata))) | |
| self._total_frames += 1 | |
| if is_focused: | |
| self._focused_frames += 1 | |
| async def _flush_loop(self): | |
| while True: | |
| await asyncio.sleep(self._flush_interval) | |
| await self._flush() | |
| async def _flush(self): | |
| async with self._lock: | |
| if not self._buf: | |
| return | |
| batch = self._buf[:] | |
| total = self._total_frames | |
| focused = self._focused_frames | |
| self._buf.clear() | |
| self._total_frames = 0 | |
| self._focused_frames = 0 | |
| if not batch: | |
| return | |
| session_id = batch[0][0] | |
| try: | |
| async with aiosqlite.connect(db_path) as db: | |
| await db.executemany(""" | |
| INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data) | |
| VALUES (?, ?, ?, ?, ?) | |
| """, batch) | |
| await db.execute(""" | |
| UPDATE focus_sessions | |
| SET total_frames = total_frames + ?, | |
| focused_frames = focused_frames + ? | |
| WHERE id = ? | |
| """, (total, focused, session_id)) | |
| await db.commit() | |
| except Exception as e: | |
| print(f"[DB] Flush error: {e}") | |
| # ================ STARTUP/SHUTDOWN ================ | |
| pipelines = { | |
| "geometric": None, | |
| "mlp": None, | |
| "hybrid": None, | |
| "xgboost": None, | |
| "l2cs": None, | |
| } | |
| # Thread pool for CPU-bound inference so the event loop stays responsive. | |
| _inference_executor = concurrent.futures.ThreadPoolExecutor( | |
| max_workers=4, | |
| thread_name_prefix="inference", | |
| ) | |
| # One lock per pipeline so shared state (TemporalTracker, etc.) is not corrupted when | |
| # multiple frames are processed in parallel by the thread pool. | |
| _pipeline_locks = {name: threading.Lock() for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")} | |
| _l2cs_load_lock = threading.Lock() | |
| _l2cs_error: str | None = None | |
def _ensure_l2cs():
    """Construct the L2CS pipeline on first use; return True once it is available.

    A lock-free check handles the common already-loaded case. The check is
    then repeated under ``_l2cs_load_lock``, so concurrent callers build the
    pipeline at most once. Returns False when the weights are missing or
    construction raises; the reason is stored in ``_l2cs_error``.
    """
    global _l2cs_error
    if pipelines["l2cs"] is not None:
        return True
    with _l2cs_load_lock:
        if pipelines["l2cs"] is not None:  # another thread finished loading first
            return True
        if not is_l2cs_weights_available():
            _l2cs_error = "Weights not found"
            return False
        try:
            pipelines["l2cs"] = L2CSPipeline()
            _l2cs_error = None
            print("[OK] L2CSPipeline lazy-loaded")
            return True
        except Exception as exc:
            _l2cs_error = str(exc)
            print(f"[ERR] L2CS lazy-load failed: {exc}")
            return False
def _process_frame_safe(pipeline, frame, model_name):
    """Return ``pipeline.process_frame(frame)``, computed while holding ``model_name``'s lock.

    Intended to run on the inference thread pool. The per-model lock stops
    two frames from running through the same pipeline instance at once.
    """
    model_lock = _pipeline_locks[model_name]
    with model_lock:
        result = pipeline.process_frame(frame)
    return result
# Weights of the linear fusion of the base-model score and the L2CS gaze score.
_BOOST_BASE_W = 0.35
_BOOST_L2CS_W = 0.65
_BOOST_VETO = 0.38  # L2CS below this -> forced not-focused
# Multiplier applied to the L2CS score to form the reported score when the veto fires.
_BOOST_VETO_SCALE = 0.8
# Fused score at or above this counts as focused.
_BOOST_FOCUS_THRESHOLD = 0.52


def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name):
    """Run the base model and, if loaded, L2CS on ``frame``, then fuse their scores.

    If L2CS is not loaded, the base output is returned with ``boost_active=False``.
    Otherwise:
    * an L2CS score below ``_BOOST_VETO`` forces the frame to not-focused;
    * else the weighted sum of the base and L2CS scores is compared with
      ``_BOOST_FOCUS_THRESHOLD``.

    The fused score replaces ``raw_score``, and ``mlp_prob`` too when present,
    in the returned dict. The unfused inputs are kept as ``base_score`` /
    ``l2cs_score``, and the L2CS gaze angles are copied over. Each pipeline
    call holds that model's lock from ``_pipeline_locks``.
    """
    # run base model
    with _pipeline_locks[base_model_name]:
        base_out = base_pipeline.process_frame(frame)

    l2cs_pipe = pipelines.get("l2cs")
    if l2cs_pipe is None:
        base_out["boost_active"] = False
        return base_out

    # run L2CS
    with _pipeline_locks["l2cs"]:
        l2cs_out = l2cs_pipe.process_frame(frame)

    base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0))
    l2cs_score = l2cs_out.get("raw_score", 0.0)

    # veto: gaze clearly off-screen overrides base model
    if l2cs_score < _BOOST_VETO:
        fused_score = l2cs_score * _BOOST_VETO_SCALE
        is_focused = False
    else:
        fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score
        is_focused = fused_score >= _BOOST_FOCUS_THRESHOLD

    base_out["raw_score"] = fused_score
    # Callers read confidence as out.get("mlp_prob", out.get("raw_score")). Without this,
    # MLP output would report the unfused probability next to the fused decision
    # (e.g. high confidence on a vetoed frame). The unfused value stays in "base_score".
    if "mlp_prob" in base_out:
        base_out["mlp_prob"] = fused_score
    base_out["is_focused"] = is_focused
    base_out["boost_active"] = True
    base_out["base_score"] = round(base_score, 3)
    base_out["l2cs_score"] = round(l2cs_score, 3)
    if l2cs_out.get("gaze_yaw") is not None:
        base_out["gaze_yaw"] = l2cs_out["gaze_yaw"]
        base_out["gaze_pitch"] = l2cs_out["gaze_pitch"]
    return base_out
async def startup_event():
    """Initialise the database, restore the saved model choice and load the pipelines.

    Each pipeline is constructed independently, so one failing to load does
    not stop the others. L2CS is not constructed here: only the presence of
    its weights is reported, and it is loaded on first use.
    """
    global _cached_model_name
    print(" Starting Focus Guard API...")
    await init_database()

    # Load cached model name from DB
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        if row:
            _cached_model_name = row[0]
    print("[OK] Database initialized")

    # (pipelines key, constructor, success message, failure message prefix)
    loaders = (
        ("geometric", FaceMeshPipeline, "[OK] FaceMeshPipeline (geometric) loaded", "[WARN] FaceMeshPipeline unavailable"),
        ("mlp", MLPPipeline, "[OK] MLPPipeline loaded", "[ERR] Failed to load MLPPipeline"),
        ("hybrid", HybridFocusPipeline, "[OK] HybridFocusPipeline loaded", "[WARN] HybridFocusPipeline unavailable"),
        ("xgboost", XGBoostPipeline, "[OK] XGBoostPipeline loaded", "[ERR] Failed to load XGBoostPipeline"),
    )
    for key, factory, ok_message, failure_prefix in loaders:
        try:
            pipelines[key] = factory()
            print(ok_message)
        except Exception as exc:
            print(f"{failure_prefix}: {exc}")

    if is_l2cs_weights_available():
        print("[OK] L2CS weights found — pipeline will be lazy-loaded on first use")
    else:
        print("[WARN] L2CS weights not found — l2cs model unavailable")
async def shutdown_event():
    """Release process-wide resources when the application shuts down.

    Closes every WebRTC peer connection still tracked in ``pcs``; previously
    these were left open. It then shuts down the inference thread pool
    without waiting for in-flight work to finish.
    """
    print(" Shutting down Focus Guard API...")
    open_pcs = list(pcs)  # copy: closing may trigger handlers that discard from pcs
    if open_pcs:
        results = await asyncio.gather(*(pc.close() for pc in open_pcs), return_exceptions=True)
        for res in results:
            if isinstance(res, Exception):
                print(f"[WARN] Error closing peer connection: {res}")
        pcs.clear()
    _inference_executor.shutdown(wait=False)
# ================ WEBRTC SIGNALING ================
async def webrtc_offer(offer: dict):
    """Handle a client's WebRTC SDP offer and return the SDP answer.

    Creates a peer connection and a new focus session, then:
    * wires the incoming video track to a ``VideoTransformTrack``, which
      sends annotated frames back and reports detections over the client's
      data channel;
    * waits for ICE gathering so the returned SDP is complete.

    When the connection fails, closes or disconnects, the session is ended
    and the peer connection closed, exactly once.

    Args:
        offer: dict with ``"sdp"`` and ``"type"`` keys from the client.

    Returns:
        ``{"sdp", "type", "session_id"}`` describing the local answer.

    Raises:
        HTTPException: 500 on any signaling failure. The half-built peer
            connection and its session are cleaned up first.
    """
    pc = None
    session_id = None
    torn_down = False

    async def _teardown():
        """End the session and close the peer connection, at most once."""
        nonlocal torn_down
        if torn_down:
            return
        torn_down = True
        if session_id is not None:
            try:
                await end_session(session_id)
            except Exception as e:
                print(f"⚠Error ending session: {e}")
        if pc is not None:
            pcs.discard(pc)
            await pc.close()  # fires a "closed" state change; the torn_down flag stops re-entry

    try:
        print("Received WebRTC offer")
        pc = RTCPeerConnection()
        pcs.add(pc)
        session_id = await create_session()
        print(f"Created session: {session_id}")
        channel_ref = {"channel": None}

        def on_datachannel(channel):
            print("Data channel opened")
            channel_ref["channel"] = channel

        def on_track(track):
            print(f"Received track: {track.kind}")
            if track.kind == "video":
                local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
                pc.addTrack(local_track)
                print("Video track added")

            async def on_ended():
                print("Track ended")

            track.on("ended", on_ended)

        async def on_connectionstatechange():
            print(f"Connection state changed: {pc.connectionState}")
            if pc.connectionState in ("failed", "closed", "disconnected"):
                await _teardown()

        # Handlers must be attached to the emitter; merely defining them does nothing.
        pc.on("datachannel", on_datachannel)
        pc.on("track", on_track)
        pc.on("connectionstatechange", on_connectionstatechange)

        await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
        print("Remote description set")
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        print("Answer created")
        await _wait_for_ice_gathering(pc)
        print("ICE gathering complete")
        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
    except Exception as e:
        print(f"WebRTC offer error: {e}")
        import traceback
        traceback.print_exc()
        # Don't leak the peer connection or leave the session open on a failed offer.
        try:
            await _teardown()
        except Exception as cleanup_err:
            print(f"WebRTC cleanup error: {cleanup_err}")
        raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}") from e
# ================ WEBSOCKET ================
async def websocket_endpoint(websocket: WebSocket):
    """Per-client WebSocket loop for JPEG frame streaming and control commands.

    Two coroutines run concurrently:

    * ``_receive_loop`` reads messages as fast as they arrive. A binary
      message (or a legacy ``{"type": "frame", "image": <base64>}`` JSON
      message) replaces the single pending-frame slot. Other JSON messages are
      session and gaze-calibration commands, and malformed ones are ignored.
    * ``_process_loop`` handles only the newest pending frame (older frames
      are dropped). It runs the selected pipeline in the inference pool and
      replies with a ``"detection"`` JSON message.

    When the connection ends, any active session is flushed and ended.
    """
    from models.gaze_calibration import GazeCalibration
    from models.gaze_eye_fusion import GazeEyeFusion

    await websocket.accept()
    session_id = None
    frame_count = 0
    running = True
    event_buffer = _EventBuffer(flush_interval=2.0)
    # Calibration state (per-connection)
    _cal: dict = {"cal": None, "collecting": False, "fusion": None}
    # Latest frame slot — only the most recent frame is kept, older ones are dropped.
    _slot = {"frame": None}
    _frame_ready = asyncio.Event()

    async def _close_active_session():
        """Flush buffered events and end the active session; return its summary (or None)."""
        nonlocal session_id
        if not session_id:
            return None
        ending_id = session_id
        session_id = None
        await event_buffer.stop()
        return await end_session(ending_id)

    async def _handle_command(data: dict):
        """Dispatch one JSON control message."""
        nonlocal session_id
        cmd = data.get("type")
        if cmd == "frame":
            _slot["frame"] = base64.b64decode(data["image"])
            _frame_ready.set()
        elif cmd == "start_session":
            # End a session the client never closed instead of orphaning it.
            await _close_active_session()
            session_id = await create_session()
            event_buffer.start()
            for p in pipelines.values():
                if p is not None and hasattr(p, "reset_session"):
                    p.reset_session()
            await websocket.send_json({"type": "session_started", "session_id": session_id})
        elif cmd == "end_session":
            summary = await _close_active_session()
            if summary:
                await websocket.send_json({"type": "session_ended", "summary": summary})
        # ---- Calibration commands ----
        elif cmd == "calibration_start":
            await asyncio.get_running_loop().run_in_executor(_inference_executor, _ensure_l2cs)
            _cal["cal"] = GazeCalibration()
            _cal["collecting"] = True
            _cal["fusion"] = None
            cal = _cal["cal"]
            await websocket.send_json({
                "type": "calibration_started",
                "num_points": cal.num_points,
                "target": list(cal.current_target),
                "index": cal.current_index,
            })
        elif cmd == "calibration_next":
            cal = _cal.get("cal")
            if cal is not None:
                more = cal.advance()
                if more:
                    await websocket.send_json({
                        "type": "calibration_point",
                        "target": list(cal.current_target),
                        "index": cal.current_index,
                    })
                else:
                    _cal["collecting"] = False
                    ok = cal.fit()
                    if ok:
                        _cal["fusion"] = GazeEyeFusion(cal)
                        await websocket.send_json({"type": "calibration_done", "success": True})
                    else:
                        await websocket.send_json({"type": "calibration_done", "success": False, "error": "Not enough samples"})
        elif cmd == "calibration_cancel":
            _cal["cal"] = None
            _cal["collecting"] = False
            _cal["fusion"] = None
            await websocket.send_json({"type": "calibration_cancelled"})

    async def _receive_loop():
        """Receive messages as fast as possible. Binary = frame, text = control."""
        nonlocal running
        try:
            while running:
                msg = await websocket.receive()
                msg_type = msg.get("type", "")
                if msg_type == "websocket.disconnect":
                    running = False
                    _frame_ready.set()
                    return
                # Binary message → JPEG frame (fast path, no base64)
                raw_bytes = msg.get("bytes")
                if raw_bytes is not None and len(raw_bytes) > 0:
                    _slot["frame"] = raw_bytes
                    _frame_ready.set()
                    continue
                # Text message → JSON control command (or legacy base64 frame)
                text = msg.get("text")
                if not text:
                    continue
                try:
                    data = json.loads(text)
                except json.JSONDecodeError as e:
                    # One malformed message must not tear down the whole connection.
                    print(f"[WS] ignoring malformed message: {e}")
                    continue
                if isinstance(data, dict):
                    await _handle_command(data)
        except WebSocketDisconnect:
            running = False
            _frame_ready.set()
        except Exception as e:
            print(f"[WS] receive error: {e}")
            running = False
            _frame_ready.set()

    async def _process_frame(loop, raw):
        """Decode and run inference on one encoded frame; return the detection reply, or None if undecodable."""
        frame = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if frame is None:
            return None
        frame = cv2.resize(frame, (640, 480))

        # During calibration collection, always use L2CS
        collecting = _cal.get("collecting", False)
        if collecting:
            if pipelines.get("l2cs") is None:
                await loop.run_in_executor(_inference_executor, _ensure_l2cs)
            model_name = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name
        else:
            model_name = _cached_model_name
        if model_name == "l2cs" and pipelines.get("l2cs") is None:
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if pipelines.get(model_name) is None:
            model_name = "mlp"
        active_pipeline = pipelines.get(model_name)

        # L2CS boost: run L2CS alongside base model
        use_boost = (
            _l2cs_boost_enabled
            and model_name != "l2cs"
            and pipelines.get("l2cs") is not None
            and not collecting
        )

        if active_pipeline is None:
            return {
                "type": "detection",
                "focused": False,
                "confidence": round(0.0, 3),
                "model": model_name,
                "fc": frame_count,
            }

        worker = _process_frame_with_l2cs_boost if use_boost else _process_frame_safe
        out = await loop.run_in_executor(_inference_executor, worker, active_pipeline, frame, model_name)
        is_focused = out["is_focused"]
        confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
        lm = out.get("landmarks")
        landmarks_list = None
        if lm is not None:
            landmarks_list = [
                [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)]
                for i in range(lm.shape[0])
            ]

        # Calibration sample collection (L2CS gaze angles)
        if collecting and _cal.get("cal") is not None:
            pipe_yaw = out.get("gaze_yaw")
            pipe_pitch = out.get("gaze_pitch")
            if pipe_yaw is not None and pipe_pitch is not None:
                _cal["cal"].collect_sample(pipe_yaw, pipe_pitch)

        # Gaze fusion (L2CS standalone or boost mode). The fusion is stepped exactly once
        # per frame; it used to be updated twice for l2cs, advancing its state twice.
        fusion = _cal.get("fusion")
        fuse = None
        if (
            fusion is not None
            and out.get("gaze_yaw") is not None
            and (model_name == "l2cs" or use_boost)
        ):
            fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], lm)
            if model_name == "l2cs":
                is_focused = fuse["focused"]
                confidence = fuse["focus_score"]

        if session_id:
            metadata = {
                "s_face": out.get("s_face", 0.0),
                "s_eye": out.get("s_eye", 0.0),
                "mar": out.get("mar", 0.0),
                "model": model_name,
            }
            event_buffer.add(session_id, is_focused, confidence, metadata)

        resp = {
            "type": "detection",
            "focused": is_focused,
            "confidence": round(confidence, 3),
            "model": model_name,
            "fc": frame_count,
        }
        if out.get("yaw") is not None:
            resp["yaw"] = round(out["yaw"], 1)
            resp["pitch"] = round(out["pitch"], 1)
            resp["roll"] = round(out["roll"], 1)
        if out.get("mar") is not None:
            resp["mar"] = round(out["mar"], 3)
        resp["sf"] = round(out.get("s_face", 0), 3)
        resp["se"] = round(out.get("s_eye", 0), 3)
        if fuse is not None:
            resp["gaze_x"] = fuse["gaze_x"]
            resp["gaze_y"] = fuse["gaze_y"]
            resp["on_screen"] = fuse["on_screen"]
        if out.get("boost_active"):
            resp["boost"] = True
            resp["base_score"] = out.get("base_score", 0)
            resp["l2cs_score"] = out.get("l2cs_score", 0)
        if landmarks_list is not None:
            resp["lm"] = landmarks_list
        return resp

    async def _process_loop():
        """Process only the latest frame, dropping stale ones."""
        nonlocal frame_count
        loop = asyncio.get_running_loop()
        while running:
            await _frame_ready.wait()
            _frame_ready.clear()
            if not running:
                return
            raw = _slot["frame"]
            _slot["frame"] = None
            if raw is None:
                continue
            try:
                resp = await _process_frame(loop, raw)
                if resp is None:
                    continue
                await websocket.send_json(resp)
                frame_count += 1
            except Exception as e:
                print(f"[WS] process error: {e}")

    try:
        await asyncio.gather(_receive_loop(), _process_loop())
    except Exception as e:
        print(f"[WS] connection error: {e}")
    finally:
        running = False
        await _close_active_session()
# ================ API ENDPOINTS ================
async def api_start_session():
    """Create a new focus session and return ``{"session_id": <new id>}``."""
    new_id = await create_session()
    return {"session_id": new_id}
async def api_end_session(data: SessionEnd):
    """End the session named in the request body and return its summary.

    Raises:
        HTTPException: 404 if no session with ``data.session_id`` exists.
    """
    summary = await end_session(data.session_id)
    if summary:
        return summary
    raise HTTPException(status_code=404, detail="Session not found")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List focus sessions, newest ``start_time`` first.

    Args:
        filter: ``"today"`` (since local midnight), ``"week"`` (last 7 days)
            or ``"month"`` (last 30 days) match on ``start_time`` and include
            sessions that have not ended. ``"all"`` returns only ended
            sessions. Any other value applies no filter.
        limit: page size; ``-1`` disables paging entirely (used for exports).
        offset: number of rows to skip; ignored when ``limit == -1``.

    Returns:
        A list of row dicts with every ``focus_sessions`` column.
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row

        if filter == "all":
            # Only completed sessions
            where_clause, params = " WHERE end_time IS NOT NULL", []
        elif filter in ("today", "week", "month"):
            now = datetime.now()
            if filter == "today":
                cutoff = now.replace(hour=0, minute=0, second=0, microsecond=0)
            else:
                cutoff = now - timedelta(days=7 if filter == "week" else 30)
            where_clause, params = " WHERE start_time >= ?", [cutoff.isoformat()]
        else:
            where_clause, params = "", []

        query = f"SELECT * FROM focus_sessions{where_clause} ORDER BY start_time DESC"
        if limit != -1:
            query += " LIMIT ? OFFSET ?"
            params += [limit, offset]

        cursor = await db.execute(query, tuple(params))
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]
# --- Import endpoint ---
async def import_sessions(sessions: List[dict]):
    """Bulk-insert previously exported focus sessions.

    Each item is converted to a row by ``_session_to_row``. The INSERT omits
    ``id``, so ids from the payload are not reused. All rows go through one
    ``executemany`` call and a single commit; if anything fails, no commit is
    issued.

    Returns:
        ``{"status": "success", "count": n}`` on success, otherwise
        ``{"status": "error", "message": str(exc)}``. Errors are reported in
        the response body rather than raised.
    """
    try:
        rows = [_session_to_row(session) for session in sessions]
        async with aiosqlite.connect(db_path) as db:
            await db.executemany("""
                INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, rows)
            await db.commit()
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
    return {"status": "success", "count": len(rows)}


def _session_to_row(session: dict) -> tuple:
    """Map one exported session dict to the INSERT parameter tuple.

    ``.get()`` defaults tolerate fields missing from older exports. A missing
    *or null* ``created_at`` falls back to ``start_time``.
    """
    start_time = session.get('start_time')
    return (
        start_time,
        session.get('end_time'),
        session.get('duration_seconds', 0),
        session.get('focus_score', 0.0),
        session.get('total_frames', 0),
        session.get('focused_frames', 0),
        # `or` (not a .get default) so an explicit None is also replaced.
        session.get('created_at') or start_time,
    )
# --- Clear-history endpoint ---
async def clear_history():
    """Delete every row from ``focus_events`` and then ``focus_sessions``.

    Both deletes are committed together. Events are removed before their
    parent sessions so that a foreign-key constraint, if one is enforced,
    is never violated.

    Returns:
        ``{"status": "success", "message": "History cleared"}`` on success,
        otherwise ``{"status": "error", "message": str(exc)}``.
    """
    wipe_statements = ("DELETE FROM focus_events", "DELETE FROM focus_sessions")
    try:
        async with aiosqlite.connect(db_path) as conn:
            for statement in wipe_statements:
                await conn.execute(statement)
            await conn.commit()
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "message": "History cleared"}
async def get_session(session_id: int):
    """Fetch one session row plus its events.

    Returns:
        The ``focus_sessions`` row as a dict, with an extra ``"events"`` key
        holding that session's ``focus_events`` rows ordered by ``timestamp``.

    Raises:
        HTTPException(404): no session with ``session_id`` exists.
    """
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row

        header_cursor = await conn.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
        header_row = await header_cursor.fetchone()
        if not header_row:
            raise HTTPException(status_code=404, detail="Session not found")

        event_cursor = await conn.execute("SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,))
        event_rows = await event_cursor.fetchall()

        result = dict(header_row)
        result['events'] = list(map(dict, event_rows))
        return result
async def get_settings():
    """Return the stored user settings merged with the live L2CS-boost flag.

    Reads row ``id = 1`` of ``user_settings``. If that row does not exist,
    hard-coded defaults are returned instead. ``l2cs_boost`` always comes from
    the in-memory ``_l2cs_boost_enabled`` flag, not from the database.
    """
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cursor = await conn.execute("SELECT * FROM user_settings WHERE id = 1")
        stored = await cursor.fetchone()
        if stored:
            settings = dict(stored)
        else:
            settings = {
                'sensitivity': 6,
                'notification_enabled': True,
                'notification_threshold': 30,
                'frame_rate': 30,
                'model_name': 'mlp',
            }
        settings['l2cs_boost'] = _l2cs_boost_enabled
        return settings
async def update_settings(settings: SettingsUpdate):
    """Validate and apply a partial settings update.

    Only fields of *settings* that are not ``None`` are applied.

    - ``sensitivity`` is clamped to [1, 10].
    - ``notification_threshold`` is clamped to [5, 300].
    - ``frame_rate`` is clamped to [5, 60].
    - ``model_name`` values not present in ``pipelines`` are ignored.
    - Selecting ``"l2cs"``, or enabling ``l2cs_boost``, first loads the L2CS
      pipeline via ``_ensure_l2cs`` on ``_inference_executor``.

    All validation and model loading happen before anything is written. A
    rejected request therefore leaves the database and the in-memory caches
    (``_cached_model_name``, ``_l2cs_boost_enabled``) untouched. Previously
    the cached model name could change even when a later check raised.

    ``l2cs_boost`` is kept only in memory; it is not written to the database.

    Returns:
        ``{"status": "success", "updated": bool}``. ``updated`` is true when
        at least one database column was written.

    Raises:
        HTTPException(400): the requested model, or L2CS boost, is unavailable.
    """
    global _cached_model_name, _l2cs_boost_enabled

    loop = asyncio.get_running_loop()
    updates = []
    params = []

    if settings.sensitivity is not None:
        updates.append("sensitivity = ?")
        params.append(max(1, min(10, settings.sensitivity)))
    if settings.notification_enabled is not None:
        updates.append("notification_enabled = ?")
        params.append(settings.notification_enabled)
    if settings.notification_threshold is not None:
        updates.append("notification_threshold = ?")
        params.append(max(5, min(300, settings.notification_threshold)))
    if settings.frame_rate is not None:
        updates.append("frame_rate = ?")
        params.append(max(5, min(60, settings.frame_rate)))

    new_model_name = None
    if settings.model_name is not None and settings.model_name in pipelines:
        if settings.model_name == "l2cs":
            # L2CS is loaded lazily; do the (slow) load off the event loop.
            loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
            if not loaded:
                raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
        elif pipelines[settings.model_name] is None:
            raise HTTPException(status_code=400, detail=f"Model '{settings.model_name}' not loaded")
        updates.append("model_name = ?")
        params.append(settings.model_name)
        new_model_name = settings.model_name

    if settings.l2cs_boost:
        loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if not loaded:
            raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")

    # Everything validated: persist first, then publish to the in-memory caches.
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
        if not await cursor.fetchone():
            await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
        if updates:
            # Column fragments come from the fixed literals above, never from
            # request data; values are bound as parameters.
            query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
            await db.execute(query, params)
        await db.commit()

    if new_model_name is not None:
        _cached_model_name = new_model_name
    if settings.l2cs_boost is not None:
        _l2cs_boost_enabled = settings.l2cs_boost

    return {"status": "success", "updated": len(updates) > 0}
async def get_stats_summary():
    """Aggregate statistics over completed sessions (``end_time IS NOT NULL``).

    Returns:
        ``total_sessions``: count of completed sessions.
        ``total_focus_time``: ``SUM(duration_seconds)`` as an int. This is
            total session duration, not only focused time.
        ``avg_focus_score``: mean ``focus_score``, rounded to 3 decimals
            (0.0 when there are no sessions).
        ``streak_days``: see ``_compute_streak``.
    """
    async with aiosqlite.connect(db_path) as db:
        # One pass for all three aggregates instead of three separate scans.
        cursor = await db.execute(
            "SELECT COUNT(*), SUM(duration_seconds), AVG(focus_score) "
            "FROM focus_sessions WHERE end_time IS NOT NULL"
        )
        total_sessions, total_focus_time, avg_focus_score = await cursor.fetchone()
        cursor = await db.execute(
            "SELECT DISTINCT DATE(start_time) FROM focus_sessions WHERE end_time IS NOT NULL"
        )
        session_dates = [row[0] for row in await cursor.fetchall()]

    return {
        'total_sessions': total_sessions,
        'total_focus_time': int(total_focus_time or 0),
        'avg_focus_score': round(avg_focus_score or 0.0, 3),
        'streak_days': _compute_streak(session_dates, datetime.now().date()),
    }


def _compute_streak(session_dates, today):
    """Count consecutive days with at least one session, ending at *today*.

    A day with no session yet does not break the streak: if nothing is
    recorded for *today*, counting starts from yesterday instead. Without
    this, the streak would read 0 every morning until the first session.

    *session_dates* holds ``YYYY-MM-DD`` strings, in any order. ``None`` or
    unparseable entries (e.g. ``DATE()`` returning NULL) are skipped.
    NOTE(review): *today* is the server's local date; this assumes start_time
    is stored in local time as well — confirm.
    """
    active_days = set()
    for value in session_dates:
        if not value:
            continue
        try:
            active_days.add(datetime.fromisoformat(value).date())
        except (TypeError, ValueError):
            continue

    one_day = timedelta(days=1)
    day = today if today in active_days else today - one_day
    streak = 0
    while day in active_days:
        streak += 1
        day -= one_day
    return streak
async def get_available_models():
    """Return model names, statuses, and which is currently active.

    Status per pipeline:

    - ``"ready"``: the pipeline object is loaded.
    - ``"lazy"`` (l2cs only): not loaded yet, but weights are on disk.
    - ``"error"`` (l2cs only): not loadable; ``_l2cs_error`` is reported under
      ``errors``.
    - ``"unavailable"``: anything else.

    ``"ready"`` and ``"lazy"`` models are listed in ``available``.
    ``current`` is the stored ``model_name`` (default ``"mlp"``). It falls
    back to the first available model when the stored one is not available.
    """
    statuses = {}
    errors = {}
    available = []

    for model, pipeline in pipelines.items():
        if pipeline is not None:
            state = "ready"
        elif model == "l2cs" and is_l2cs_weights_available():
            state = "lazy"
        elif model == "l2cs" and _l2cs_error:
            state = "error"
            errors[model] = _l2cs_error
        else:
            state = "unavailable"
        statuses[model] = state
        if state in ("ready", "lazy"):
            available.append(model)

    async with aiosqlite.connect(db_path) as conn:
        cursor = await conn.execute("SELECT model_name FROM user_settings WHERE id = 1")
        stored = await cursor.fetchone()

    current = stored[0] if stored else "mlp"
    if available and current not in available:
        current = available[0]

    # Boost is only offered when L2CS can run and is not already the main model.
    boost_possible = current != "l2cs" and statuses.get("l2cs") in ("ready", "lazy")

    return {
        "available": available,
        "current": current,
        "statuses": statuses,
        "errors": errors,
        "l2cs_boost": _l2cs_boost_enabled,
        "l2cs_boost_available": boost_possible,
    }
async def l2cs_status():
    """Report L2CS state.

    Returns whether the weights are on disk, whether the pipeline is loaded,
    and the last load error (if any).
    """
    pipeline_loaded = pipelines.get("l2cs") is not None
    status = {"weights_available": is_l2cs_weights_available()}
    status["loaded"] = pipeline_loaded
    status["error"] = _l2cs_error
    return status
async def get_mesh_topology():
    """Expose the face-mesh tessellation edges for client-side drawing.

    The response is ``{"tessellation": [(start, end), ...]}``. Clients are
    expected to fetch it once and cache it.
    """
    edges = _TESSELATION_CONNS
    return dict(tessellation=edges)
async def health_check():
    """Liveness probe.

    Lists the loaded (non-``None``) pipelines and reports whether the SQLite
    file exists.
    """
    loaded_models = [model for model, pipe in pipelines.items() if pipe is not None]
    db_present = os.path.exists(db_path)
    return {"status": "healthy", "models_loaded": loaded_models, "database": db_present}
# ================ STATIC FILES (SPA SUPPORT) ================
# Anchor the build directory to this module's location rather than the
# process's working directory, so it resolves the same way however the
# server is launched.
_STATIC_DIR = Path(__file__).resolve().parent.joinpath("static")
_ASSETS_DIR = _STATIC_DIR.joinpath("assets")

# 1. Mount the hashed JS/CSS bundle directory before the SPA catch-all route
#    is declared, so /assets/* is served from disk and never swallowed by it.
#    Skipped entirely when no frontend build is present.
if _ASSETS_DIR.is_dir():
    assets_app = StaticFiles(directory=str(_ASSETS_DIR))
    app.mount("/assets", assets_app, name="assets")
# 2. Catch-all for SPA: serve index.html for client-side routes. Paths whose
#    first segment is "api", "ws" or "assets" get a 404 instead. That way a
#    mistyped API call fails loudly, and a missing bundle file is never
#    answered with HTML (which would break module-script loading with a
#    MIME-type error).
async def serve_react_app(full_path: str, request: Request):
    """Serve the SPA shell (``static/index.html``) for any non-reserved path.

    Raises:
        HTTPException(404): the first path segment is ``api``, ``ws`` or
        ``assets``.

    Returns:
        The index file. If no frontend build is present, a JSON hint instead.
    """
    # Match whole path segments, not raw string prefixes. "api/..." is
    # reserved, but client routes such as "apiary" or "wsettings" are not.
    first_segment = full_path.split("/", 1)[0]
    if first_segment in ("api", "ws", "assets"):
        raise HTTPException(status_code=404, detail="Not Found")
    index_path = _STATIC_DIR / "index.html"
    if index_path.is_file():
        return FileResponse(str(index_path))
    return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}