Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| from __future__ import annotations | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Any | |
| import base64 | |
| import cv2 | |
| import numpy as np | |
| import aiosqlite | |
| import json | |
| from datetime import datetime, timedelta | |
| import math | |
| import os | |
| from pathlib import Path | |
| from typing import Callable | |
| from contextlib import asynccontextmanager | |
| import asyncio | |
| import concurrent.futures | |
| import threading | |
| import logging | |
| from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack | |
| logger = logging.getLogger(__name__) | |
| from av import VideoFrame | |
| from mediapipe.tasks.python.vision import FaceLandmarksConnections | |
| from ui.pipeline import ( | |
| FaceMeshPipeline, MLPPipeline, HybridFocusPipeline, XGBoostPipeline, | |
| L2CSPipeline, is_l2cs_weights_available, | |
| ) | |
| from models.face_mesh import FaceMeshDetector | |
# ================ FACE MESH DRAWING (server-side, for WebRTC) ================
_FONT = cv2.FONT_HERSHEY_SIMPLEX
# Colors are (B, G, R) tuples, the channel order used by the cv2 drawing calls
# on the bgr24 frames handled in this module.
_CYAN = (255, 255, 0)
_GREEN = (0, 255, 0)
_MAGENTA = (255, 0, 255)
_ORANGE = (0, 165, 255)
_RED = (0, 0, 255)
_WHITE = (255, 255, 255)
_LIGHT_GREEN = (144, 238, 144)
# (start, end) landmark-index pairs taken from MediaPipe's face landmark
# connection sets; drawn as individual line segments.
_TESSELATION_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_TESSELATION]
_CONTOUR_CONNS = [(c.start, c.end) for c in FaceLandmarksConnections.FACE_LANDMARKS_CONTOURS]
# Landmark index sequences drawn as open polylines (consecutive indices joined).
_LEFT_EYEBROW = [70, 63, 105, 66, 107, 55, 65, 52, 53, 46]
_RIGHT_EYEBROW = [300, 293, 334, 296, 336, 285, 295, 282, 283, 276]
_NOSE_BRIDGE = [6, 197, 195, 5, 4, 1, 19, 94, 2]
# The lip sequences repeat their first index at the end so the outline closes.
_LIPS_OUTER = [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185, 61]
_LIPS_INNER = [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191, 78]
# Six landmarks per eye drawn as dots ("EAR key points"; presumably the points
# of the eye-aspect-ratio computation -- TODO confirm against FaceMeshDetector).
_LEFT_EAR_POINTS = [33, 160, 158, 133, 153, 145]
_RIGHT_EAR_POINTS = [362, 385, 387, 263, 373, 380]
| def _lm_px(lm, idx, w, h): | |
| return (int(lm[idx, 0] * w), int(lm[idx, 1] * h)) | |
def _draw_polyline(frame, lm, indices, w, h, color, thickness):
    """Draw an open anti-aliased polyline through the landmarks listed in ``indices``.

    Each consecutive pair of indices becomes one ``cv2.line`` segment on
    ``frame``. Fewer than two indices draws nothing.
    """
    for start_idx, end_idx in zip(indices, indices[1:]):
        p0 = _lm_px(lm, start_idx, w, h)
        p1 = _lm_px(lm, end_idx, w, h)
        cv2.line(frame, p0, p1, color, thickness, cv2.LINE_AA)
def _draw_face_mesh(frame, lm, w, h, draw_gaze_lines=False):
    """Draw tessellation, contours, eyebrows, nose, lips, eyes, irises, and optionally gaze lines.

    Draws in place on ``frame``. ``lm`` is indexed as ``lm[idx, 0]`` /
    ``lm[idx, 1]`` and scaled by ``w`` / ``h``, so it is presumably an array of
    normalized landmark coordinates (TODO confirm); ``w`` and ``h`` are the
    pixel width and height to scale to. When ``draw_gaze_lines`` is True, a red
    ray is drawn from each iris center, extended to 3x the offset between the
    iris center and the midpoint of the two eye landmarks paired with it.
    """
    # Tessellation (gray triangular grid, semi-transparent)
    overlay = frame.copy()
    for s, e in _TESSELATION_CONNS:
        cv2.line(overlay, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), (200, 200, 200), 1, cv2.LINE_AA)
    # Blend 30% grid with 70% original; the last argument is the destination,
    # so the blended image is written back into ``frame``.
    cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
    # Contours
    for s, e in _CONTOUR_CONNS:
        cv2.line(frame, _lm_px(lm, s, w, h), _lm_px(lm, e, w, h), _CYAN, 1, cv2.LINE_AA)
    # Eyebrows
    _draw_polyline(frame, lm, _LEFT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    _draw_polyline(frame, lm, _RIGHT_EYEBROW, w, h, _LIGHT_GREEN, 2)
    # Nose
    _draw_polyline(frame, lm, _NOSE_BRIDGE, w, h, _ORANGE, 1)
    # Lips
    _draw_polyline(frame, lm, _LIPS_OUTER, w, h, _MAGENTA, 1)
    _draw_polyline(frame, lm, _LIPS_INNER, w, h, (200, 0, 200), 1)
    # Eyes (closed outlines from FaceMeshDetector's eye index lists)
    left_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.LEFT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [left_pts], True, _GREEN, 2, cv2.LINE_AA)
    right_pts = np.array([_lm_px(lm, i, w, h) for i in FaceMeshDetector.RIGHT_EYE_INDICES], dtype=np.int32)
    cv2.polylines(frame, [right_pts], True, _GREEN, 2, cv2.LINE_AA)
    # EAR key points
    for indices in [_LEFT_EAR_POINTS, _RIGHT_EAR_POINTS]:
        for idx in indices:
            cv2.circle(frame, _lm_px(lm, idx, w, h), 3, (0, 255, 255), -1, cv2.LINE_AA)
    # Irises (always draw) + gaze lines (only when eye gaze is enabled)
    for iris_idx, eye_inner, eye_outer in [
        (FaceMeshDetector.LEFT_IRIS_INDICES, 133, 33),
        (FaceMeshDetector.RIGHT_IRIS_INDICES, 362, 263),
    ]:
        iris_pts = np.array([_lm_px(lm, i, w, h) for i in iris_idx], dtype=np.int32)
        # The first iris landmark is treated as the iris center.
        center = iris_pts[0]
        if len(iris_pts) >= 5:
            # Radius = mean distance from the center to the next four iris
            # landmarks, clamped to at least 2 px.
            radii = [np.linalg.norm(iris_pts[j] - center) for j in range(1, 5)]
            radius = max(int(np.mean(radii)), 2)
            cv2.circle(frame, tuple(center), radius, _MAGENTA, 2, cv2.LINE_AA)
        cv2.circle(frame, tuple(center), 2, _WHITE, -1, cv2.LINE_AA)
        if draw_gaze_lines:
            # Midpoint of the two landmarks labelled inner/outer for this eye.
            eye_cx = int((lm[eye_inner, 0] + lm[eye_outer, 0]) / 2.0 * w)
            eye_cy = int((lm[eye_inner, 1] + lm[eye_outer, 1]) / 2.0 * h)
            dx, dy = center[0] - eye_cx, center[1] - eye_cy
            # Exaggerate the iris offset 3x so small gaze shifts are visible.
            cv2.line(frame, tuple(center), (int(center[0] + dx * 3), int(center[1] + dy * 3)), _RED, 1, cv2.LINE_AA)
def _draw_hud(frame, result, model_name):
    """Overlay the focus status bar and per-frame metrics onto ``frame`` in place.

    Layout matches live_demo.py:
      * a black 55 px top bar with FOCUSED (green) / NOT FOCUSED (red) and the
        upper-cased ``model_name`` near the right edge;
      * a detail line with confidence (``mlp_prob``, else ``raw_score``),
        S_face, S_eye and, when present, MAR;
      * head-pose angles when ``result["yaw"]`` is set;
      * a YAWN marker when ``result["is_yawning"]`` is truthy.
    """
    frame_w = frame.shape[1]
    aa = cv2.LINE_AA
    if result["is_focused"]:
        status_text, status_color = "FOCUSED", _GREEN
    else:
        status_text, status_color = "NOT FOCUSED", _RED

    # Top bar
    cv2.rectangle(frame, (0, 0), (frame_w, 55), (0, 0, 0), -1)
    cv2.putText(frame, status_text, (10, 28), _FONT, 0.8, status_color, 2, aa)
    cv2.putText(frame, model_name.upper(), (frame_w - 150, 28), _FONT, 0.45, _WHITE, 1, aa)

    # Detail line
    confidence = result.get("mlp_prob", result.get("raw_score", 0.0))
    mar_text = "" if result.get("mar") is None else f" MAR:{result['mar']:.2f}"
    detail = (
        f"conf:{confidence:.2f} S_face:{result.get('s_face', 0):.2f} "
        f"S_eye:{result.get('s_eye', 0):.2f}{mar_text}"
    )
    cv2.putText(frame, detail, (10, 48), _FONT, 0.4, _WHITE, 1, aa)

    # Head pose (top right)
    if result.get("yaw") is not None:
        pose_text = f"yaw:{result['yaw']:+.0f} pitch:{result['pitch']:+.0f} roll:{result['roll']:+.0f}"
        cv2.putText(frame, pose_text, (frame_w - 280, 48), _FONT, 0.4, (180, 180, 180), 1, aa)

    # Yawn indicator
    if result.get("is_yawning"):
        cv2.putText(frame, "YAWN", (10, 75), _FONT, 0.7, _ORANGE, 2, aa)
| # Landmark indices used for face mesh drawing on client (union of all groups). | |
| # Sending only these instead of all 478 saves ~60% of the landmarks payload. | |
| _MESH_INDICES = sorted( | |
| set( | |
| [ | |
| 10, 338, 297, 332, 284, 251, 389, 356, 454, | |
| 323, 361, 288, 397, 365, 379, 378, 400, 377, | |
| 152, 148, 176, 149, 150, 136, 172, 58, 132, | |
| 93, 234, 127, 162, 21, 54, 103, 67, 109, | |
| ] # face oval | |
| + [33, 7, 163, 144, 145, 153, 154, 155, 133, 173, 157, 158, 159, 160, 161, 246] # left eye | |
| + [362, 382, 381, 380, 374, 373, 390, 249, 263, 466, 388, 387, 386, 385, 384, 398] # right eye | |
| + [468, 469, 470, 471, 472, 473, 474, 475, 476, 477] # irises | |
| + [70, 63, 105, 66, 107, 55, 65, 52, 53, 46] # left eyebrow | |
| + [300, 293, 334, 296, 336, 285, 295, 282, 283, 276] # right eyebrow | |
| + [6, 197, 195, 5, 4, 1, 19, 94, 2] # nose bridge | |
| + [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 409, 270, 269, 267, 0, 37, 39, 40, 185] # lips outer | |
| + [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308, 415, 310, 311, 312, 13, 82, 81, 80, 191] # lips inner | |
| + [33, 160, 158, 133, 153, 145] # left EAR key points | |
| + [362, 385, 387, 263, 373, 380] # right EAR key points | |
| ) | |
| ) | |
| # Build a lookup: original_index -> position in sparse array, so client can reconstruct. | |
| _MESH_INDEX_SET = set(_MESH_INDICES) | |
@asynccontextmanager
async def lifespan(app):
    """FastAPI lifespan: prepare the database and pipelines, then clean up on shutdown.

    Startup:
      * creates the schema and loads the persisted model choice into
        ``_cached_model_name``;
      * eagerly constructs the geometric / MLP / hybrid / XGBoost pipelines
        (a failure is logged and that pipeline stays ``None``);
      * if the persisted model is not loaded, switches to the first loaded
        pipeline and persists that choice. "l2cs" is kept when its weights
        exist, because it is lazy-loaded on first use and so is never present
        in ``pipelines`` at this point.

    Shutdown (also on error): closes still-open WebRTC peer connections and
    shuts down the inference thread pool without waiting.
    """
    global _cached_model_name
    print("Starting Focus Guard API")
    await init_database()
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        row = await cursor.fetchone()
        if row:
            _cached_model_name = row[0]
    print("[OK] Database initialized")

    # (pipeline key, factory, success message, failure message prefix)
    eager_pipelines = (
        ("geometric", FaceMeshPipeline,
         "[OK] FaceMeshPipeline (geometric) loaded", "[WARN] FaceMeshPipeline unavailable"),
        ("mlp", MLPPipeline,
         "[OK] MLPPipeline loaded", "[ERR] Failed to load MLPPipeline"),
        ("hybrid", HybridFocusPipeline,
         "[OK] HybridFocusPipeline loaded", "[WARN] HybridFocusPipeline unavailable"),
        ("xgboost", XGBoostPipeline,
         "[OK] XGBoostPipeline loaded", "[ERR] Failed to load XGBoostPipeline"),
    )
    for key, factory, ok_msg, fail_prefix in eager_pipelines:
        try:
            pipelines[key] = factory()
            print(ok_msg)
        except Exception as e:
            print(f"{fail_prefix}: {e}")

    l2cs_weights_available = is_l2cs_weights_available()
    if _cached_model_name == "l2cs" and l2cs_weights_available:
        # Keep the user's L2CS choice; the pipeline is created lazily on first use.
        resolved_model = "l2cs"
    else:
        resolved_model = _first_available_pipeline_name(_cached_model_name)
    if resolved_model is not None and resolved_model != _cached_model_name:
        _cached_model_name = resolved_model
        async with aiosqlite.connect(db_path) as db:
            await db.execute(
                "UPDATE user_settings SET model_name = ? WHERE id = 1",
                (_cached_model_name,),
            )
            await db.commit()
    if resolved_model is not None:
        print(f"[OK] Active model set to {resolved_model}")
    if l2cs_weights_available:
        print("[OK] L2CS weights found (lazy-loaded on first use)")
    else:
        print("[WARN] L2CS weights not found")

    try:
        yield
    finally:
        # Close peer connections that are still open so their tasks end cleanly.
        if pcs:
            await asyncio.gather(*(pc.close() for pc in list(pcs)), return_exceptions=True)
            pcs.clear()
        _inference_executor.shutdown(wait=False)
        print("Shutting down Focus Guard API")
app = FastAPI(title="Focus Guard API", lifespan=lifespan)
# Add CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True accept
# credentialed cross-origin requests from any site; consider restricting
# allow_origins for deployments beyond local use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables
db_path = "focus_guard.db"  # SQLite file, relative to the process working directory
pcs = set()  # open RTCPeerConnection objects; webrtc_offer adds and discards them
_cached_model_name = "mlp"  # active model name; replaced from user_settings at startup
_l2cs_boost_enabled = False  # when True, L2CS gaze output is fused with the base model
| async def _wait_for_ice_gathering(pc: RTCPeerConnection): | |
| if pc.iceGatheringState == "complete": | |
| return | |
| done = asyncio.Event() | |
| def _on_state_change(): | |
| if pc.iceGatheringState == "complete": | |
| done.set() | |
| await done.wait() | |
# ================ DATABASE MODELS ================
async def init_database():
    """Create the SQLite schema if needed and seed the default settings row.

    Tables:
      * ``focus_sessions`` -- one row per session with frame counters and score;
      * ``focus_events`` -- individual focus events referencing a session;
      * ``user_settings`` -- a single row (``CHECK (id = 1)``) with the model name.

    Every statement is idempotent (``IF NOT EXISTS`` / ``INSERT OR IGNORE``),
    so this can run on every startup. All statements share one commit.
    """
    statements = (
        # FocusSessions table
        """
CREATE TABLE IF NOT EXISTS focus_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
start_time TIMESTAMP NOT NULL,
end_time TIMESTAMP,
duration_seconds INTEGER DEFAULT 0,
focus_score REAL DEFAULT 0.0,
total_frames INTEGER DEFAULT 0,
focused_frames INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""",
        # FocusEvents table
        """
CREATE TABLE IF NOT EXISTS focus_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL,
timestamp TIMESTAMP NOT NULL,
is_focused BOOLEAN NOT NULL,
confidence REAL NOT NULL,
detection_data TEXT,
FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
)
""",
        # UserSettings table
        """
CREATE TABLE IF NOT EXISTS user_settings (
id INTEGER PRIMARY KEY CHECK (id = 1),
model_name TEXT DEFAULT 'mlp'
)
""",
        # Insert default settings if not exists
        """
INSERT OR IGNORE INTO user_settings (id, model_name)
VALUES (1, 'mlp')
""",
    )
    async with aiosqlite.connect(db_path) as db:
        for statement in statements:
            await db.execute(statement)
        await db.commit()
# ================ PYDANTIC MODELS ================
class SessionCreate(BaseModel):
    """Payload for creating a session; it carries no fields (presumably a request body)."""
    pass


class SessionEnd(BaseModel):
    """Payload naming the session to end (presumably a request body)."""
    # Id of the focus_sessions row to close.
    session_id: int


class SettingsUpdate(BaseModel):
    """Partial settings payload; both fields are optional.

    NOTE(review): presumably a None field means "leave unchanged" -- confirm in
    the settings handler (not visible here).
    """
    # Name of the model/pipeline to activate.
    model_name: Optional[str] = None
    # Whether to fuse L2CS gaze with the base model.
    l2cs_boost: Optional[bool] = None
class VideoTransformTrack(VideoStreamTrack):
    """Outgoing video track that annotates frames pulled from an incoming WebRTC track.

    Every received frame is resized to 640x480. An inference runs when at least
    ``min_inference_interval`` seconds have passed since the previous one. It
    executes the active pipeline in the inference thread pool, draws the face
    mesh and HUD onto the frame, stores the result for ``session_id`` (if set),
    and sends a ``detection`` JSON message on the data channel when it is open.
    Frames between inferences are replaced by the last annotated image.
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        super().__init__()
        self.track = track  # source track whose frames are transformed
        self.session_id = session_id
        # Returns the current data channel (or None); called per inference
        # because the channel may be announced after the track.
        self.get_channel = get_channel
        self.last_inference_time = 0
        self.min_inference_interval = 1 / 60
        self.last_frame = None  # last annotated BGR image

    async def recv(self):
        """Return the next outgoing frame (annotated, or the last annotated one)."""
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame
        # Normalize size for inference/drawing
        img = cv2.resize(img, (640, 480))
        loop = asyncio.get_running_loop()
        # The loop clock is monotonic. Wall-clock time can jump backwards and
        # would then suppress inference until it caught up again.
        now = loop.time()
        if (now - self.last_inference_time) >= self.min_inference_interval:
            self.last_inference_time = now
            await self._infer_and_annotate(img, loop)
            self.last_frame = img
        elif self.last_frame is not None:
            img = self.last_frame
        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame

    async def _infer_and_annotate(self, img, loop):
        """Run the active model on ``img``, draw overlays on it in place, then persist and publish the result."""
        model_name = _cached_model_name
        if model_name == "l2cs" and pipelines.get("l2cs") is None:
            # Loading L2CS is slow; do it in the worker pool, not on the event loop.
            await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if model_name not in pipelines or pipelines.get(model_name) is None:
            model_name = "mlp"
        active_pipeline = pipelines.get(model_name)
        if active_pipeline is not None:
            out = await loop.run_in_executor(
                _inference_executor,
                _process_frame_safe,
                active_pipeline,
                img,
                model_name,
            )
            is_focused = out["is_focused"]
            confidence = out.get("mlp_prob", out.get("raw_score", 0.0))
            metadata = {
                "s_face": out.get("s_face", 0.0),
                "s_eye": out.get("s_eye", 0.0),
                "mar": out.get("mar", 0.0),
                "model": model_name,
            }
            # Draw face mesh + HUD on the video frame
            h_f, w_f = img.shape[:2]
            lm = out.get("landmarks")
            if lm is not None:
                eye_gaze_enabled = _l2cs_boost_enabled or model_name == "l2cs"
                _draw_face_mesh(img, lm, w_f, h_f, draw_gaze_lines=eye_gaze_enabled)
            _draw_hud(img, out, model_name)
        else:
            is_focused = False
            confidence = 0.0
            metadata = {"model": model_name}
            cv2.rectangle(img, (0, 0), (img.shape[1], 55), (0, 0, 0), -1)
            cv2.putText(img, "NO MODEL", (10, 28), _FONT, 0.8, _RED, 2, cv2.LINE_AA)

        if self.session_id:
            try:
                await store_focus_event(self.session_id, is_focused, confidence, metadata)
            except Exception:
                # A DB failure must not terminate the video stream.
                logger.exception("Failed to store focus event for session %s", self.session_id)

        channel = self.get_channel()
        if channel and channel.readyState == "open":
            try:
                channel.send(json.dumps({
                    "type": "detection",
                    "focused": is_focused,
                    "confidence": round(confidence, 3),
                    "detections": [],
                    "model": model_name,
                }))
            except Exception as e:
                # Best effort: the channel may close between the check and the send.
                logger.debug("Detection send on data channel failed: %s", e)
# ================ DATABASE OPERATIONS ================
async def create_session():
    """Insert a focus session that starts now (local ISO timestamp) and return its row id."""
    async with aiosqlite.connect(db_path) as db:
        started_at = datetime.now().isoformat()
        cursor = await db.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (started_at,),
        )
        await db.commit()
        new_id = cursor.lastrowid
        return new_id
async def end_session(session_id: int):
    """Close a focus session and return its summary, or None if the id is unknown.

    Writes ``end_time`` (now), ``duration_seconds`` (whole seconds since
    ``start_time``) and ``focus_score`` (focused_frames / total_frames, or 0.0
    when no frames were recorded). The returned dict carries the same values,
    with the score rounded to 3 decimals.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,),
        )
        row = await cursor.fetchone()
        if not row:
            return None
        start_time_str, total_frames, focused_frames = row
        started = datetime.fromisoformat(start_time_str)
        finished = datetime.now()
        elapsed_seconds = int((finished - started).total_seconds())
        score = 0.0 if not total_frames > 0 else focused_frames / total_frames
        finished_iso = finished.isoformat()
        await db.execute("""
UPDATE focus_sessions
SET end_time = ?, duration_seconds = ?, focus_score = ?
WHERE id = ?
""", (finished_iso, elapsed_seconds, score, session_id))
        await db.commit()
    return {
        'session_id': session_id,
        'start_time': start_time_str,
        'end_time': finished_iso,
        'duration_seconds': elapsed_seconds,
        'focus_score': round(score, 3),
        'total_frames': total_frames,
        'focused_frames': focused_frames,
    }
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Record one focus event and bump the session's frame counters in one commit.

    ``metadata`` is stored JSON-encoded in ``detection_data``; ``focused_frames``
    is incremented only when ``is_focused`` is truthy.
    """
    async with aiosqlite.connect(db_path) as db:
        event_row = (session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata))
        await db.execute("""
INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
VALUES (?, ?, ?, ?, ?)
""", event_row)
        focused_increment = int(bool(is_focused))
        await db.execute("""
UPDATE focus_sessions
SET total_frames = total_frames + 1,
focused_frames = focused_frames + ?
WHERE id = ?
""", (focused_increment, session_id))
        await db.commit()
| class _EventBuffer: | |
| """Buffer focus events in memory and flush to DB in batches to avoid per-frame DB writes.""" | |
| def __init__(self, flush_interval: float = 2.0): | |
| self._buf: list = [] | |
| self._lock = asyncio.Lock() | |
| self._flush_interval = flush_interval | |
| self._task: asyncio.Task | None = None | |
| self._total_frames = 0 | |
| self._focused_frames = 0 | |
| def start(self): | |
| if self._task is None: | |
| self._task = asyncio.create_task(self._flush_loop()) | |
| async def stop(self): | |
| if self._task: | |
| self._task.cancel() | |
| try: | |
| await self._task | |
| except asyncio.CancelledError: | |
| pass | |
| self._task = None | |
| await self._flush() | |
| def add(self, session_id: int, is_focused: bool, confidence: float, metadata: dict): | |
| self._buf.append((session_id, datetime.now().isoformat(), is_focused, confidence, json.dumps(metadata))) | |
| self._total_frames += 1 | |
| if is_focused: | |
| self._focused_frames += 1 | |
| async def _flush_loop(self): | |
| while True: | |
| await asyncio.sleep(self._flush_interval) | |
| await self._flush() | |
| async def _flush(self): | |
| async with self._lock: | |
| if not self._buf: | |
| return | |
| batch = self._buf[:] | |
| total = self._total_frames | |
| focused = self._focused_frames | |
| self._buf.clear() | |
| self._total_frames = 0 | |
| self._focused_frames = 0 | |
| if not batch: | |
| return | |
| session_id = batch[0][0] | |
| try: | |
| async with aiosqlite.connect(db_path) as db: | |
| await db.executemany(""" | |
| INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data) | |
| VALUES (?, ?, ?, ?, ?) | |
| """, batch) | |
| await db.execute(""" | |
| UPDATE focus_sessions | |
| SET total_frames = total_frames + ?, | |
| focused_frames = focused_frames + ? | |
| WHERE id = ? | |
| """, (total, focused, session_id)) | |
| await db.commit() | |
| except Exception as e: | |
| print(f"[DB] Flush error: {e}") | |
# ================ PIPELINE REGISTRY & INFERENCE RESOURCES ================
# Pipelines keyed by model name; None means "not loaded (yet)".
pipelines = {
    name: None
    for name in ("geometric", "mlp", "hybrid", "xgboost", "l2cs")
}
# Worker threads for CPU-bound inference, so the event loop stays responsive.
_inference_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=4,
    thread_name_prefix="inference",
)
# One lock per pipeline (same keys, same order as ``pipelines``), so per-pipeline
# shared state is not touched by two worker threads at once.
_pipeline_locks = {name: threading.Lock() for name in pipelines}
# Serializes the lazy construction of the L2CS pipeline.
_l2cs_load_lock = threading.Lock()
# Reason the last L2CS load attempt failed, or None.
_l2cs_error: str | None = None
def _ensure_l2cs():
    """Construct the L2CS pipeline on first use; return True if it is available.

    Double-checked locking: an unlocked fast path once the pipeline exists,
    then a re-check under ``_l2cs_load_lock`` so concurrent callers build
    ``L2CSPipeline`` at most once. ``_l2cs_error`` is set to the failure reason
    ("Weights not found" or the exception text), or reset to None on success.
    """
    global _l2cs_error
    if pipelines["l2cs"] is not None:
        return True
    with _l2cs_load_lock:
        # Another thread may have finished loading while this one waited.
        if pipelines["l2cs"] is not None:
            return True
        if is_l2cs_weights_available():
            try:
                pipelines["l2cs"] = L2CSPipeline()
                _l2cs_error = None
                print("[OK] L2CSPipeline lazy-loaded")
                return True
            except Exception as exc:
                _l2cs_error = str(exc)
                print(f"[ERR] L2CS lazy-load failed: {exc}")
                return False
        _l2cs_error = "Weights not found"
        return False
def _process_frame_safe(pipeline, frame, model_name):
    """Return ``pipeline.process_frame(frame)``, computed while holding ``model_name``'s lock."""
    lock = _pipeline_locks[model_name]
    with lock:
        result = pipeline.process_frame(frame)
    return result
def _first_available_pipeline_name(preferred: str | None = None) -> str | None:
    """Pick a loaded pipeline name.

    Returns ``preferred`` if it names a loaded pipeline, otherwise the first
    loaded entry of ``pipelines`` in insertion order, otherwise None.
    """
    if preferred and pipelines.get(preferred) is not None:
        return preferred
    loaded = (name for name, pipe in pipelines.items() if pipe is not None)
    return next(loaded, None)
| _BOOST_BASE_W = 0.35 | |
| _BOOST_L2CS_W = 0.65 | |
| _BOOST_VETO = 0.38 # L2CS below this -> forced not-focused | |
| def _process_frame_with_l2cs_boost(base_pipeline, frame, base_model_name): | |
| # run base model | |
| with _pipeline_locks[base_model_name]: | |
| base_out = base_pipeline.process_frame(frame) | |
| l2cs_pipe = pipelines.get("l2cs") | |
| if l2cs_pipe is None: | |
| base_out["boost_active"] = False | |
| return base_out | |
| # run L2CS | |
| with _pipeline_locks["l2cs"]: | |
| l2cs_out = l2cs_pipe.process_frame(frame) | |
| base_score = base_out.get("mlp_prob", base_out.get("raw_score", 0.0)) | |
| l2cs_score = l2cs_out.get("raw_score", 0.0) | |
| # veto: gaze clearly off-screen overrides base model | |
| if l2cs_score < _BOOST_VETO: | |
| fused_score = l2cs_score * 0.8 | |
| is_focused = False | |
| else: | |
| fused_score = _BOOST_BASE_W * base_score + _BOOST_L2CS_W * l2cs_score | |
| is_focused = fused_score >= 0.52 | |
| base_out["raw_score"] = fused_score | |
| base_out["is_focused"] = is_focused | |
| base_out["boost_active"] = True | |
| base_out["base_score"] = round(base_score, 3) | |
| base_out["l2cs_score"] = round(l2cs_score, 3) | |
| if l2cs_out.get("gaze_yaw") is not None: | |
| base_out["gaze_yaw"] = l2cs_out["gaze_yaw"] | |
| base_out["gaze_pitch"] = l2cs_out["gaze_pitch"] | |
| return base_out | |
# ================ WEBRTC SIGNALING ================
async def webrtc_offer(offer: dict):
    """Answer a client's WebRTC offer and start an annotated video session.

    Creates a peer connection and a focus session, then registers handlers:
      * "datachannel" remembers the channel for detection messages;
      * "track" sends each incoming video track back through a
        VideoTransformTrack;
      * "connectionstatechange" ends the session and closes the connection
        once it is failed, closed or disconnected.
    It then returns the local SDP answer after ICE gathering.

    Args:
        offer: mapping with the client's ``"sdp"`` and ``"type"``.

    Returns:
        ``{"sdp": ..., "type": ..., "session_id": ...}``.

    Raises:
        HTTPException: status 500 if any step fails. Resources created so far
        (peer connection, session) are released first.
    """
    pc = None
    session_id = None
    cleaned_up = False

    async def _cleanup():
        # Idempotent: closing the connection fires another state change, and
        # "failed" may be followed by "closed".
        nonlocal cleaned_up
        if cleaned_up:
            return
        cleaned_up = True
        if session_id is not None:
            try:
                await end_session(session_id)
            except Exception as e:
                logger.warning("WebRTC session end failed: %s", e)
        if pc is not None:
            pcs.discard(pc)
            try:
                await pc.close()
            except Exception as e:
                logger.warning("Closing WebRTC peer connection failed: %s", e)

    try:
        pc = RTCPeerConnection()
        pcs.add(pc)
        session_id = await create_session()
        channel_ref = {"channel": None}

        @pc.on("datachannel")
        def on_datachannel(channel):
            channel_ref["channel"] = channel

        @pc.on("track")
        def on_track(track):
            if track.kind == "video":
                local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
                pc.addTrack(local_track)

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            if pc.connectionState in ("failed", "closed", "disconnected"):
                await _cleanup()

        await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        await _wait_for_ice_gathering(pc)
        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
    except Exception as e:
        logger.exception("WebRTC offer failed")
        await _cleanup()
        raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}") from e
| # ================ WEBSOCKET ================ | |
| async def websocket_endpoint(websocket: WebSocket): | |
| from models.gaze_calibration import GazeCalibration | |
| from models.gaze_eye_fusion import GazeEyeFusion | |
| await websocket.accept() | |
| session_id = None | |
| frame_count = 0 | |
| running = True | |
| event_buffer = _EventBuffer(flush_interval=2.0) | |
| # Calibration state (per-connection) | |
| # verifying: after fit, show a verification target and check gaze accuracy | |
| _cal: dict = {"cal": None, "collecting": False, "fusion": None, | |
| "verifying": False, "verify_target": None, "verify_samples": []} | |
| # Latest frame slot — only the most recent frame is kept, older ones are dropped. | |
| _slot = {"frame": None} | |
| _frame_ready = asyncio.Event() | |
| async def _receive_loop(): | |
| """Receive messages as fast as possible. Binary = frame, text = control.""" | |
| nonlocal session_id, running | |
| try: | |
| while running: | |
| msg = await websocket.receive() | |
| msg_type = msg.get("type", "") | |
| if msg_type == "websocket.disconnect": | |
| running = False | |
| _frame_ready.set() | |
| return | |
| # Binary message → JPEG frame (fast path, no base64) | |
| raw_bytes = msg.get("bytes") | |
| if raw_bytes is not None and len(raw_bytes) > 0: | |
| _slot["frame"] = raw_bytes | |
| _frame_ready.set() | |
| continue | |
| # Text message → JSON control command (or legacy base64 frame) | |
| text = msg.get("text") | |
| if not text: | |
| continue | |
| data = json.loads(text) | |
| if data["type"] == "frame": | |
| _slot["frame"] = base64.b64decode(data["image"]) | |
| _frame_ready.set() | |
| elif data["type"] == "start_session": | |
| session_id = await create_session() | |
| event_buffer.start() | |
| for p in pipelines.values(): | |
| if p is not None and hasattr(p, "reset_session"): | |
| p.reset_session() | |
| await websocket.send_json({"type": "session_started", "session_id": session_id}) | |
| elif data["type"] == "end_session": | |
| if session_id: | |
| await event_buffer.stop() | |
| summary = await end_session(session_id) | |
| if summary: | |
| await websocket.send_json({"type": "session_ended", "summary": summary}) | |
| session_id = None | |
| # ---- Calibration commands ---- | |
| elif data["type"] == "calibration_start": | |
| loop = asyncio.get_event_loop() | |
| await loop.run_in_executor(_inference_executor, _ensure_l2cs) | |
| _cal["cal"] = GazeCalibration() | |
| _cal["collecting"] = True | |
| _cal["fusion"] = None | |
| # Tell L2CS pipeline to run every frame during calibration | |
| l2cs_pipe = pipelines.get("l2cs") | |
| if l2cs_pipe is not None and hasattr(l2cs_pipe, '_calibrating'): | |
| l2cs_pipe._calibrating = True | |
| cal = _cal["cal"] | |
| await websocket.send_json({ | |
| "type": "calibration_started", | |
| "num_points": cal.num_points, | |
| "target": list(cal.current_target), | |
| "index": cal.current_index, | |
| }) | |
| elif data["type"] == "calibration_next": | |
| cal = _cal.get("cal") | |
| if _cal.get("verifying"): | |
| # Verification phase complete — user clicked next | |
| _cal["verifying"] = False | |
| _cal["collecting"] = False | |
| # Re-enable frame skipping | |
| l2cs_pipe = pipelines.get("l2cs") | |
| if l2cs_pipe is not None and hasattr(l2cs_pipe, '_calibrating'): | |
| l2cs_pipe._calibrating = False | |
| # Check verification samples | |
| v_samples = _cal.get("verify_samples", []) | |
| vt = _cal.get("verify_target", [0.5, 0.5]) | |
| if len(v_samples) >= 3: | |
| med_yaw = float(np.median([s[0] for s in v_samples])) | |
| med_pitch = float(np.median([s[1] for s in v_samples])) | |
| px, py, err, passed = cal.verify(med_yaw, med_pitch, vt[0], vt[1]) | |
| print(f"[CAL] Verification: target=({vt[0]:.2f},{vt[1]:.2f}) " | |
| f"predicted=({px:.3f},{py:.3f}) error={err:.3f} passed={passed}") | |
| else: | |
| passed = True # not enough samples, trust the fit | |
| _cal["fusion"] = GazeEyeFusion(cal) | |
| await websocket.send_json({ | |
| "type": "calibration_done", | |
| "success": True, | |
| "verified": passed, | |
| }) | |
| elif cal is not None: | |
| more = cal.advance() | |
| if more: | |
| await websocket.send_json({ | |
| "type": "calibration_point", | |
| "target": list(cal.current_target), | |
| "index": cal.current_index, | |
| }) | |
| else: | |
| # All 9 points collected — try to fit | |
| _cal["collecting"] = False | |
| ok = cal.fit() | |
| if ok: | |
| # Enter verification phase: show center target | |
| _cal["verifying"] = True | |
| _cal["verify_target"] = [0.5, 0.5] | |
| _cal["verify_samples"] = [] | |
| await websocket.send_json({ | |
| "type": "calibration_verify", | |
| "target": [0.5, 0.5], | |
| "message": "Look at the dot to verify calibration", | |
| }) | |
| else: | |
| # Re-enable frame skipping | |
| l2cs_pipe = pipelines.get("l2cs") | |
| if l2cs_pipe is not None and hasattr(l2cs_pipe, '_calibrating'): | |
| l2cs_pipe._calibrating = False | |
| await websocket.send_json( | |
| { | |
| "type": "calibration_done", | |
| "success": False, | |
| "error": "Not enough samples", | |
| } | |
| ) | |
| elif data["type"] == "calibration_cancel": | |
| _cal["cal"] = None | |
| _cal["collecting"] = False | |
| _cal["fusion"] = None | |
| l2cs_pipe = pipelines.get("l2cs") | |
| if l2cs_pipe is not None and hasattr(l2cs_pipe, '_calibrating'): | |
| l2cs_pipe._calibrating = False | |
| await websocket.send_json({"type": "calibration_cancelled"}) | |
| except WebSocketDisconnect: | |
| running = False | |
| _frame_ready.set() | |
| except Exception as e: | |
| print(f"[WS] receive error: {e}") | |
| running = False | |
| _frame_ready.set() | |
| async def _process_loop(): | |
| """Process only the latest frame, dropping stale ones.""" | |
| nonlocal frame_count, running | |
| loop = asyncio.get_event_loop() | |
| while running: | |
| await _frame_ready.wait() | |
| _frame_ready.clear() | |
| if not running: | |
| return | |
| raw = _slot["frame"] | |
| _slot["frame"] = None | |
| if raw is None: | |
| continue | |
| try: | |
| nparr = np.frombuffer(raw, np.uint8) | |
| frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR) | |
| if frame is None: | |
| continue | |
| frame = cv2.resize(frame, (640, 480)) | |
| # During calibration collection, always use L2CS | |
| collecting = _cal.get("collecting", False) | |
| if collecting: | |
| if pipelines.get("l2cs") is None: | |
| await loop.run_in_executor(_inference_executor, _ensure_l2cs) | |
| use_model = "l2cs" if pipelines.get("l2cs") is not None else _cached_model_name | |
| else: | |
| use_model = _cached_model_name | |
| model_name = use_model | |
| if model_name == "l2cs" and pipelines.get("l2cs") is None: | |
| await loop.run_in_executor(_inference_executor, _ensure_l2cs) | |
| if model_name not in pipelines or pipelines.get(model_name) is None: | |
| model_name = "mlp" | |
| active_pipeline = pipelines.get(model_name) | |
| # L2CS boost: run L2CS alongside base model | |
| use_boost = ( | |
| _l2cs_boost_enabled | |
| and model_name != "l2cs" | |
| and pipelines.get("l2cs") is not None | |
| and not collecting | |
| ) | |
| landmarks_list = None | |
| out = None | |
| if active_pipeline is not None: | |
| if use_boost: | |
| out = await loop.run_in_executor( | |
| _inference_executor, | |
| _process_frame_with_l2cs_boost, | |
| active_pipeline, | |
| frame, | |
| model_name, | |
| ) | |
| else: | |
| out = await loop.run_in_executor( | |
| _inference_executor, | |
| _process_frame_safe, | |
| active_pipeline, | |
| frame, | |
| model_name, | |
| ) | |
| is_focused = out["is_focused"] | |
| confidence = out.get("mlp_prob", out.get("raw_score", 0.0)) | |
| lm = out.get("landmarks") | |
| if lm is not None: | |
| landmarks_list = [ | |
| [round(float(lm[i, 0]), 3), round(float(lm[i, 1]), 3)] | |
| for i in range(lm.shape[0]) | |
| ] | |
| # Calibration sample collection (L2CS gaze angles) | |
| if collecting and _cal.get("cal") is not None: | |
| pipe_yaw = out.get("gaze_yaw") | |
| pipe_pitch = out.get("gaze_pitch") | |
| if pipe_yaw is not None and pipe_pitch is not None: | |
| _cal["cal"].collect_sample(pipe_yaw, pipe_pitch) | |
| # Verification sample collection | |
| if _cal.get("verifying") and out.get("gaze_yaw") is not None: | |
| _cal["verify_samples"].append( | |
| (out["gaze_yaw"], out["gaze_pitch"]) | |
| ) | |
| # Gaze fusion (single call — applied before event logging | |
| # and response to avoid double-EMA smoothing) | |
| fusion = _cal.get("fusion") | |
| has_gaze = out.get("gaze_yaw") is not None | |
| fuse = None | |
| if fusion is not None and has_gaze and (model_name == "l2cs" or use_boost): | |
| fuse = fusion.update(out["gaze_yaw"], out["gaze_pitch"], lm) | |
| if model_name == "l2cs": | |
| # L2CS standalone: fusion fully controls focus decision | |
| is_focused = fuse["focused"] | |
| confidence = fuse["focus_score"] | |
| elif use_boost and not fuse["on_screen"]: | |
| # Boost mode: if gaze is clearly off-screen, override to unfocused | |
| is_focused = False | |
| confidence = min(confidence, 0.1) | |
| if session_id: | |
| metadata = { | |
| "s_face": out.get("s_face", 0.0), | |
| "s_eye": out.get("s_eye", 0.0), | |
| "mar": out.get("mar", 0.0), | |
| "model": model_name, | |
| } | |
| event_buffer.add(session_id, is_focused, confidence, metadata) | |
| else: | |
| is_focused = False | |
| confidence = 0.0 | |
| resp = { | |
| "type": "detection", | |
| "focused": is_focused, | |
| "confidence": round(confidence, 3), | |
| "detections": [], | |
| "model": model_name, | |
| "fc": frame_count, | |
| "frame_count": frame_count, | |
| "eye_gaze_enabled": _l2cs_boost_enabled or model_name == "l2cs", | |
| } | |
| if out is not None: | |
| if out.get("yaw") is not None: | |
| resp["yaw"] = round(out["yaw"], 1) | |
| resp["pitch"] = round(out["pitch"], 1) | |
| resp["roll"] = round(out["roll"], 1) | |
| if out.get("mar") is not None: | |
| resp["mar"] = round(out["mar"], 3) | |
| resp["sf"] = round(out.get("s_face", 0), 3) | |
| resp["se"] = round(out.get("s_eye", 0), 3) | |
| # Attach gaze fusion fields + raw gaze angles for visualization | |
| if fuse is not None: | |
| resp["gaze_x"] = fuse["gaze_x"] | |
| resp["gaze_y"] = fuse["gaze_y"] | |
| resp["on_screen"] = fuse["on_screen"] | |
| if model_name == "l2cs": | |
| resp["focused"] = fuse["focused"] | |
| resp["confidence"] = round(fuse["focus_score"], 3) | |
| elif use_boost and not fuse["on_screen"]: | |
| resp["focused"] = False | |
| resp["confidence"] = min(resp["confidence"], 0.1) | |
| if has_gaze: | |
| resp["gaze_yaw"] = round(out["gaze_yaw"], 4) | |
| resp["gaze_pitch"] = round(out["gaze_pitch"], 4) | |
| if out.get("boost_active"): | |
| resp["boost"] = True | |
| resp["base_score"] = out.get("base_score", 0) | |
| resp["l2cs_score"] = out.get("l2cs_score", 0) | |
| if landmarks_list is not None: | |
| resp["lm"] = landmarks_list | |
| await websocket.send_json(resp) | |
| frame_count += 1 | |
| except Exception as e: | |
| print(f"[WS] process error: {e}") | |
| try: | |
| await asyncio.gather(_receive_loop(), _process_loop()) | |
| except Exception: | |
| pass | |
| finally: | |
| running = False | |
| if session_id: | |
| await event_buffer.stop() | |
| await end_session(session_id) | |
| # ================ API ENDPOINTS ================ | |
async def api_start_session():
    """Open a new focus session and hand its id back to the client.

    Returns:
        ``{"session_id": <id returned by create_session()>}``.
    """
    new_id = await create_session()
    payload = {"session_id": new_id}
    return payload
async def api_end_session(data: SessionEnd):
    """Close the session named by ``data.session_id`` and return its summary.

    Raises:
        HTTPException: 404 when ``end_session`` yields a falsy summary
            (no such session).
    """
    result = await end_session(data.session_id)
    if result:
        return result
    raise HTTPException(status_code=404, detail="Session not found")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List completed focus sessions, newest first.

    Args:
        filter: Time window. ``"today"`` means since local midnight, ``"week"``
            the last 7 days, ``"month"`` the last 30 days, and ``"all"`` has
            no time bound. Any other value behaves like ``"all"``. The name
            shadows the builtin but is the public query-parameter name, so it
            is kept.
        limit: Page size. ``-1`` returns every matching row (used for export).
        offset: Number of rows to skip when paginating.

    Returns:
        List of ``focus_sessions`` rows as plain dicts.
    """
    now = datetime.now()
    window_start = {
        "today": now.replace(hour=0, minute=0, second=0, microsecond=0),
        "week": now - timedelta(days=7),
        "month": now - timedelta(days=30),
    }
    # History only shows finished sessions, whatever the window. Previously
    # only "all" applied this, so in-progress or abandoned sessions (no
    # end_time) showed up under today/week/month but not under "all".
    conditions = ["end_time IS NOT NULL"]
    params: list = []
    cutoff = window_start.get(filter)
    if cutoff is not None:
        conditions.append("start_time >= ?")
        params.append(cutoff.isoformat())

    query = (
        "SELECT * FROM focus_sessions"
        f" WHERE {' AND '.join(conditions)}"
        " ORDER BY start_time DESC"
    )
    if limit != -1:  # -1 = export everything, no pagination
        query += " LIMIT ? OFFSET ?"
        params.extend([limit, offset])

    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute(query, tuple(params))
        rows = await cursor.fetchall()
        return [dict(row) for row in rows]
async def import_sessions(sessions: List[dict]):
    """Insert previously exported sessions into ``focus_sessions``.

    Each record is read with ``dict.get`` so that exports from older versions
    (or hand-edited files) with missing fields still import. Counters and the
    score default to 0, and ``created_at`` falls back to ``start_time``.

    The commit runs only after every row has been inserted, so a failure part
    way through leaves nothing committed. Errors are reported in the response
    payload instead of being raised.

    Returns:
        ``{"status": "success", "count": <rows inserted>}`` on success, or
        ``{"status": "error", "message": <exception text>}`` on failure.
    """
    insert_sql = """
        INSERT INTO focus_sessions (
            start_time, end_time, duration_seconds, focus_score,
            total_frames, focused_frames, created_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    try:
        async with aiosqlite.connect(db_path) as db:
            rows = [
                (
                    record.get('start_time'),
                    record.get('end_time'),
                    record.get('duration_seconds', 0),
                    record.get('focus_score', 0.0),
                    record.get('total_frames', 0),
                    record.get('focused_frames', 0),
                    record.get('created_at', record.get('start_time')),
                )
                for record in sessions
            ]
            await db.executemany(insert_sql, rows)
            await db.commit()
        return {"status": "success", "count": len(rows)}
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
async def clear_history():
    """Delete every focus event and every focus session.

    Failures are reported in the payload rather than raised.

    Returns:
        ``{"status": "success", "message": "History cleared"}`` or
        ``{"status": "error", "message": <exception text>}``.
    """
    try:
        async with aiosqlite.connect(db_path) as db:
            # Child rows (events) go before their parent sessions.
            for statement in ("DELETE FROM focus_events", "DELETE FROM focus_sessions"):
                await db.execute(statement)
            await db.commit()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "History cleared"}
async def get_session(session_id: int):
    """Fetch one session row with its events attached under ``"events"``.

    Events are ordered by their ``timestamp`` column.

    Raises:
        HTTPException: 404 if no session has ``session_id``.
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        session_cursor = await db.execute(
            "SELECT * FROM focus_sessions WHERE id = ?", (session_id,)
        )
        session_row = await session_cursor.fetchone()
        if not session_row:
            raise HTTPException(status_code=404, detail="Session not found")
        detail = dict(session_row)
        events_cursor = await db.execute(
            "SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp",
            (session_id,),
        )
        detail['events'] = [dict(event) for event in await events_cursor.fetchall()]
        return detail
async def get_settings():
    """Return the stored user settings plus the live L2CS-boost flag.

    When no ``user_settings`` row with id 1 exists, the stored part falls back
    to ``{"model_name": "mlp"}``. ``l2cs_boost`` always comes from the
    in-process ``_l2cs_boost_enabled`` flag, not from the database.
    """
    async with aiosqlite.connect(db_path) as db:
        db.row_factory = aiosqlite.Row
        cursor = await db.execute("SELECT * FROM user_settings WHERE id = 1")
        stored = await cursor.fetchone()
        if stored:
            settings = dict(stored)
        else:
            settings = {"model_name": "mlp"}
        settings['l2cs_boost'] = _l2cs_boost_enabled
        return settings
async def update_settings(settings: SettingsUpdate):
    """Apply a partial settings update (active model and/or L2CS boost).

    Every requested change is validated before anything is applied. A request
    rejected with HTTP 400 therefore leaves the database and the in-memory
    settings (``_cached_model_name``, ``_l2cs_boost_enabled``) untouched.
    Previously the cached model could switch even though the request failed
    on the boost check.

    Args:
        settings: Partial update; ``None`` fields are left unchanged. A
            ``model_name`` that is not a key of ``pipelines`` is ignored.

    Returns:
        ``{"status": "success", "updated": bool}``. ``updated`` is True only
        when a persisted column (currently just ``model_name``) was written.
        ``l2cs_boost`` is kept in memory only.

    Raises:
        HTTPException: 400 if the requested model is not loaded, or if L2CS
            (for the model or for boost) cannot be loaded.
    """
    global _cached_model_name, _l2cs_boost_enabled

    loop = asyncio.get_running_loop()
    new_model = settings.model_name
    if new_model is not None and new_model not in pipelines:
        new_model = None  # unknown model names are ignored

    # ---- 1. Validate everything; nothing is mutated until this passes. ----
    if new_model == "l2cs":
        # L2CS is loaded lazily; run the blocking load off the event loop.
        loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if not loaded:
            raise HTTPException(status_code=400, detail=f"L2CS model unavailable: {_l2cs_error}")
    elif new_model is not None and pipelines[new_model] is None:
        raise HTTPException(status_code=400, detail=f"Model '{new_model}' not loaded")

    if settings.l2cs_boost:
        loaded = await loop.run_in_executor(_inference_executor, _ensure_l2cs)
        if not loaded:
            raise HTTPException(status_code=400, detail=f"L2CS boost unavailable: {_l2cs_error}")

    # ---- 2. Persist. ----
    updates = []
    params = []
    if new_model is not None:
        updates.append("model_name = ?")
        params.append(new_model)

    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT id FROM user_settings WHERE id = 1")
        if not await cursor.fetchone():
            await db.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
            await db.commit()
        if updates:
            query = f"UPDATE user_settings SET {', '.join(updates)} WHERE id = 1"
            await db.execute(query, tuple(params))
            await db.commit()

    # ---- 3. Publish to in-memory state only after the DB write succeeded. ----
    if new_model is not None:
        _cached_model_name = new_model
    if settings.l2cs_boost is not None:
        _l2cs_boost_enabled = settings.l2cs_boost
    return {"status": "success", "updated": len(updates) > 0}
async def get_system_stats():
    """Return server CPU and memory usage for UI display.

    ``psutil`` is optional: if it is not installed, every field is ``None``.

    ``psutil.cpu_percent(interval=0.1)`` blocks for the whole sampling
    interval. Calling it directly inside this coroutine stalled the event
    loop, and with it every WebSocket session, for 100 ms per request. The
    sample is therefore taken in the loop's default thread pool.

    Returns:
        dict with ``cpu_percent``, ``memory_percent`` (rounded to 0.1) and
        ``memory_used_mb`` / ``memory_total_mb`` (MiB, rounded to 0 places).
    """
    try:
        import psutil
    except ImportError:
        return {
            "cpu_percent": None,
            "memory_percent": None,
            "memory_used_mb": None,
            "memory_total_mb": None,
        }

    def _sample():
        # Runs in a worker thread; the sleep inside cpu_percent no longer
        # blocks the event loop.
        return psutil.cpu_percent(interval=0.1), psutil.virtual_memory()

    cpu, mem = await asyncio.get_running_loop().run_in_executor(None, _sample)
    return {
        "cpu_percent": round(cpu, 1),
        "memory_percent": round(mem.percent, 1),
        "memory_used_mb": round(mem.used / (1024 * 1024), 0),
        "memory_total_mb": round(mem.total / (1024 * 1024), 0),
    }
async def get_stats_summary():
    """Aggregate statistics over completed sessions (``end_time`` set).

    Returns:
        dict with:
          * ``total_sessions``: number of completed sessions.
          * ``total_focus_time``: sum of ``duration_seconds``, as int
            (0 if none).
          * ``avg_focus_score``: mean ``focus_score``, rounded to 3 places
            (0.0 if none).
          * ``streak_days``: consecutive calendar days, counting back from
            today, that each contain at least one completed session. It is 0
            when there is no session today.
    """
    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute(
            """
            SELECT COUNT(*), SUM(duration_seconds), AVG(focus_score)
            FROM focus_sessions
            WHERE end_time IS NOT NULL
            """
        )
        total_sessions, total_focus_time, avg_focus_score = await cursor.fetchone()
        # SUM/AVG over zero rows yield NULL.
        total_focus_time = total_focus_time or 0
        avg_focus_score = avg_focus_score or 0.0

        # DATE() is NULL for a missing or unparseable start_time (e.g. an
        # imported record without one). Exclude those rows instead of
        # crashing on fromisoformat(None) below.
        cursor = await db.execute(
            """
            SELECT DISTINCT DATE(start_time) as session_date
            FROM focus_sessions
            WHERE end_time IS NOT NULL AND DATE(start_time) IS NOT NULL
            ORDER BY session_date DESC
            """
        )
        dates = [row[0] for row in await cursor.fetchall()]

    # Walk the distinct dates newest-first while they match today, yesterday, ...
    streak_days = 0
    expected_date = datetime.now().date()
    for date_str in dates:
        if datetime.fromisoformat(date_str).date() != expected_date:
            break
        streak_days += 1
        expected_date -= timedelta(days=1)

    return {
        'total_sessions': total_sessions,
        'total_focus_time': int(total_focus_time),
        'avg_focus_score': round(avg_focus_score, 3),
        'streak_days': streak_days
    }
async def get_available_models():
    """Report every configured model pipeline, its status, and the active one.

    Status values:
      * ``"ready"``: the pipeline object exists.
      * ``"lazy"`` (L2CS only): weights are present but not yet loaded.
      * ``"error"`` (L2CS only): not loadable and ``_l2cs_error`` is set.
      * ``"unavailable"``: otherwise.

    Models that are ready or lazy are listed in ``available``.

    ``current`` is the model stored in ``user_settings`` (``"mlp"`` if no row
    exists). It is replaced by the first available model when the stored one
    is not available. L2CS boost is offered only when L2CS is ready or lazy
    and is not itself the current model.
    """
    statuses = {}
    errors = {}
    available = []
    for name, pipeline in pipelines.items():
        if pipeline is not None:
            status = "ready"
        elif name == "l2cs" and is_l2cs_weights_available():
            status = "lazy"
        elif name == "l2cs" and _l2cs_error:
            status = "error"
            errors[name] = _l2cs_error
        else:
            status = "unavailable"
        statuses[name] = status
        if status in ("ready", "lazy"):
            available.append(name)

    async with aiosqlite.connect(db_path) as db:
        cursor = await db.execute("SELECT model_name FROM user_settings WHERE id = 1")
        settings_row = await cursor.fetchone()
    current = settings_row[0] if settings_row else "mlp"
    if available and current not in available:
        current = available[0]

    boost_possible = current != "l2cs" and statuses.get("l2cs") in ("ready", "lazy")
    return {
        "available": available,
        "current": current,
        "statuses": statuses,
        "errors": errors,
        "l2cs_boost": _l2cs_boost_enabled,
        "l2cs_boost_available": boost_possible,
    }
async def l2cs_status():
    """Report L2CS state: whether the pipeline is loaded, whether weights are present, and the last load error.

    ``error`` is the module-level ``_l2cs_error`` value.
    """
    is_loaded = pipelines.get("l2cs") is not None
    status = {"weights_available": is_l2cs_weights_available()}
    status["loaded"] = is_loaded
    status["error"] = _l2cs_error
    return status
async def get_mesh_topology():
    """Give the client the face-mesh tessellation so it can draw the mesh itself.

    The value is ``_TESSELATION_CONNS``: a list of ``(start, end)``
    landmark-index pairs built from MediaPipe's tessellation connections.
    """
    edges = _TESSELATION_CONNS
    return {"tessellation": edges}
async def health_check():
    """Liveness probe.

    Returns:
        The names of loaded (non-``None``) pipelines and whether the database
        file exists on disk.
    """
    loaded_models = []
    for model_name, pipeline in pipelines.items():
        if pipeline is not None:
            loaded_models.append(model_name)
    db_present = os.path.exists(db_path)
    return {"status": "healthy", "models_loaded": loaded_models, "database": db_present}
# ================ STATIC FILES (SPA SUPPORT) ================
# Locate the frontend relative to this module rather than the process cwd,
# so the server behaves the same however it is launched. A production build
# in `dist/` wins when it contains an index.html; otherwise `static/` is used.
_BASE_DIR = Path(__file__).resolve().parent
_DIST_DIR = _BASE_DIR / "dist"
_STATIC_DIR = _BASE_DIR / "static"
if (_DIST_DIR / "index.html").is_file():
    _FRONTEND_DIR = _DIST_DIR
else:
    _FRONTEND_DIR = _STATIC_DIR
_ASSETS_DIR = _FRONTEND_DIR / "assets"

# Mount the JS/CSS bundle directory before the SPA catch-all route is
# registered, so /assets/* is handled by StaticFiles and never answered with
# index.html.
if _ASSETS_DIR.is_dir():
    app.mount("/assets", StaticFiles(directory=str(_ASSETS_DIR)), name="assets")
# 2. Catch-all for SPA: serve index.html for app routes, never for /assets (would break JS MIME type)
async def serve_react_app(full_path: str, request: Request):
    """Serve the single-page frontend for any path not claimed by another route.

    Routing rules, applied in order:
      * Paths starting with ``api`` or ``ws`` get a 404, so a mistyped
        backend URL is not masked by the SPA.
      * Paths starting with ``assets`` get a 404 rather than index.html,
        because a module script answered with HTML fails the browser's MIME
        check.
      * A path naming an existing file inside ``_FRONTEND_DIR`` is served
        as-is.
      * Anything else falls back to ``index.html`` so client-side routing
        works.

    ``full_path`` comes straight from the request URL, so a candidate file is
    served only if it resolves to a location inside ``_FRONTEND_DIR``.
    Without this check, ``..`` segments or an absolute path could read any
    file on the host, e.g. this server's source. pathlib discards the left
    operand when joining an absolute path.
    """
    if full_path.startswith(("api", "ws")):
        raise HTTPException(status_code=404, detail="Not Found")
    # Never answer asset URLs with HTML (would break module script loading).
    if full_path.startswith("assets"):
        raise HTTPException(status_code=404, detail="Not Found")

    if full_path:
        frontend_root = _FRONTEND_DIR.resolve()
        try:
            candidate = (frontend_root / full_path).resolve()
        except (OSError, RuntimeError, ValueError):
            # e.g. an embedded NUL byte or a symlink loop: treat it like any
            # unknown client route.
            candidate = None
        # Path-traversal guard: only files strictly under the frontend root.
        if candidate is not None and frontend_root in candidate.parents and candidate.is_file():
            return FileResponse(str(candidate))

    index_path = _FRONTEND_DIR / "index.html"
    if index_path.is_file():
        return FileResponse(str(index_path))
    return {"message": "React app not found. Please run 'npm run build' and copy dist to static if needed."}