Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Any | |
| import base64 | |
| import cv2 | |
| import numpy as np | |
| import aiosqlite | |
| import json | |
| from datetime import datetime, timedelta | |
| import math | |
| import os | |
| from pathlib import Path | |
| from typing import Callable | |
| import asyncio | |
| from aiortc import RTCPeerConnection, RTCSessionDescription, VideoStreamTrack | |
| from av import VideoFrame | |
| from ui.pipeline import MLPPipeline | |
# ================ APPLICATION SETUP ================
# NOTE(review): no route/event decorators (@app.get, @app.on_event, @pc.on, ...)
# are visible in this copy of the file — confirm the handlers below are
# actually registered.

# The single FastAPI application every endpoint and static mount attaches to.
app = FastAPI(title="Focus Guard API")

# CORS: every origin, method and header is accepted, and credentials are
# allowed. NOTE(review): wildcard origins combined with credentials is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Module-level state shared by the handlers in this file.
db_path = "focus_guard.db"  # SQLite file, resolved relative to the working directory
pcs: set = set()  # live RTCPeerConnection objects created by webrtc_offer
async def _wait_for_ice_gathering(pc: RTCPeerConnection):
    """Suspend until ``pc`` has finished gathering local ICE candidates.

    Returns immediately when ``pc.iceGatheringState`` is already
    ``"complete"``. Otherwise a listener is attached for the
    ``icegatheringstatechange`` event. The coroutine waits until that
    listener observes the ``"complete"`` state, and then detaches it again.

    Args:
        pc: The peer connection whose gathering state is watched.
    """
    if pc.iceGatheringState == "complete":
        return

    done = asyncio.Event()

    def _on_state_change():
        if pc.iceGatheringState == "complete":
            done.set()

    # The callback must be subscribed explicitly; merely defining it leaves
    # the wait below with nothing that could ever release it.
    pc.on("icegatheringstatechange", _on_state_change)
    try:
        await done.wait()
    finally:
        # Don't keep a stale listener on a long-lived connection.
        pc.remove_listener("icegatheringstatechange", _on_state_change)
# ================ DATABASE MODELS ================
async def init_database():
    """Create the SQLite schema if needed and seed the default settings row.

    This is safe to run on every start-up. Tables use
    ``CREATE TABLE IF NOT EXISTS`` and the seed row uses
    ``INSERT OR IGNORE``, so existing data and saved settings are kept.
    All statements are committed together at the end.
    """
    statements = (
        # FocusSessions table: one row per tracking session.
        """
        CREATE TABLE IF NOT EXISTS focus_sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            start_time TIMESTAMP NOT NULL,
            end_time TIMESTAMP,
            duration_seconds INTEGER DEFAULT 0,
            focus_score REAL DEFAULT 0.0,
            total_frames INTEGER DEFAULT 0,
            focused_frames INTEGER DEFAULT 0,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # FocusEvents table: one row per stored inference result.
        """
        CREATE TABLE IF NOT EXISTS focus_events (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            timestamp TIMESTAMP NOT NULL,
            is_focused BOOLEAN NOT NULL,
            confidence REAL NOT NULL,
            detection_data TEXT,
            FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
        )
        """,
        # UserSettings table: the CHECK constraint pins it to the single row id = 1.
        """
        CREATE TABLE IF NOT EXISTS user_settings (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            sensitivity INTEGER DEFAULT 6,
            notification_enabled BOOLEAN DEFAULT 1,
            notification_threshold INTEGER DEFAULT 30,
            frame_rate INTEGER DEFAULT 30,
            model_name TEXT DEFAULT 'yolov8n.pt'
        )
        """,
        # Seed the default settings row only when it does not exist yet.
        """
        INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
        VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
        """,
    )
    async with aiosqlite.connect(db_path) as conn:
        for statement in statements:
            await conn.execute(statement)
        await conn.commit()
# ================ PYDANTIC MODELS ================
class SessionCreate(BaseModel):
    """Empty payload model for session creation (it declares no fields).

    NOTE(review): not referenced anywhere else in this file.
    """
class SessionEnd(BaseModel):
    """Request body naming the session to close (used by ``api_end_session``)."""

    session_id: int  # primary key of the focus_sessions row to end
class SettingsUpdate(BaseModel):
    """Partial update of the user settings row.

    Every field is optional, and ``None`` means "leave this setting
    unchanged". Range clamping is done by the update handler, not by this
    model.
    """

    sensitivity: Optional[int] = None  # clamped to 1..10 when applied
    notification_enabled: Optional[bool] = None
    notification_threshold: Optional[int] = None  # clamped to 5..300 when applied
    frame_rate: Optional[int] = None  # clamped to 5..60 when applied
class VideoTransformTrack(VideoStreamTrack):
    """Outgoing video track that annotates incoming frames with focus results.

    Each frame read from ``track`` is resized to 640x480. Inference through
    the module-level ``mlp_pipeline`` runs at most once per
    ``min_inference_interval`` seconds. Its result is:

    * drawn onto the frame,
    * stored via ``store_focus_event`` when a session id is set, and
    * pushed as JSON over the data channel returned by ``get_channel``
      when that channel is open.

    Between inferences the last annotated frame is sent again instead of
    the fresh one.
    """

    def __init__(self, track, session_id: int, get_channel: Callable[[], Any]):
        """
        Args:
            track: Incoming media track whose frames are processed.
            session_id: Session that results are stored under; a falsy
                value disables persistence.
            get_channel: Zero-argument callable returning the current data
                channel (or None). It is called on every inference, so a
                channel that opens later is still picked up.
        """
        super().__init__()
        self.track = track
        self.session_id = session_id
        self.get_channel = get_channel
        self.last_inference_time = 0
        self.min_inference_interval = 1 / 60  # seconds between inferences (~60 Hz cap)
        self.last_frame = None  # last annotated frame, re-sent between inferences

    async def recv(self):
        """Return the next (possibly annotated) frame for the outgoing stream."""
        frame = await self.track.recv()
        img = frame.to_ndarray(format="bgr24")
        if img is None:
            return frame

        # Normalize size for inference/drawing
        img = cv2.resize(img, (640, 480))

        now = datetime.now().timestamp()
        do_infer = (now - self.last_inference_time) >= self.min_inference_interval

        if do_infer and mlp_pipeline is not None:
            self.last_inference_time = now
            out = mlp_pipeline.process_frame(img)
            # Coerce to builtin types: numpy scalars (np.bool_, np.float32, ...)
            # are rejected by json.dumps and by the sqlite3 parameter binder.
            is_focused = bool(out["is_focused"])
            confidence = float(out["mlp_prob"])
            metadata = {"s_face": out["s_face"], "s_eye": out["s_eye"], "mar": out["mar"]}
            detections = []

            status_text = "FOCUSED" if is_focused else "NOT FOCUSED"
            color = (0, 255, 0) if is_focused else (0, 0, 255)  # BGR: green / red
            cv2.putText(img, status_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
            cv2.putText(img, f"Confidence: {confidence * 100:.1f}%", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)

            if self.session_id:
                try:
                    await store_focus_event(self.session_id, is_focused, confidence, metadata)
                except Exception as exc:
                    # Persistence is best-effort. An exception escaping recv()
                    # presumably ends the outgoing media stream, so log it and
                    # keep sending frames.
                    print(f"Error storing focus event: {exc}")

            channel = self.get_channel()
            if channel and channel.readyState == "open":
                try:
                    channel.send(json.dumps({"type": "detection", "focused": is_focused, "confidence": round(confidence, 3), "detections": detections}))
                except Exception:
                    # Best-effort notification: a failed send must not drop the frame.
                    pass
            self.last_frame = img
        elif self.last_frame is not None:
            # Between inferences, repeat the last annotated frame.
            img = self.last_frame

        new_frame = VideoFrame.from_ndarray(img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
# ================ DATABASE OPERATIONS ================
async def create_session():
    """Insert a new, still-open focus session and return its row id.

    Only ``start_time`` is written, as a naive local-time ISO-8601 string.
    The remaining columns keep their schema defaults until frames are
    recorded and the session is ended.
    """
    async with aiosqlite.connect(db_path) as conn:
        started_at = datetime.now().isoformat()
        insert_cursor = await conn.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (started_at,),
        )
        await conn.commit()
        new_id = insert_cursor.lastrowid
        return new_id
async def end_session(session_id: int):
    """Close a session and return its summary, or ``None`` if it does not exist.

    The duration runs from the stored ``start_time`` to now and is truncated
    to whole seconds. The focus score is ``focused_frames / total_frames``,
    or 0.0 when no frames were recorded. The summary rounds it to 3 places,
    while the stored value is unrounded. Calling this again on an already
    ended session overwrites ``end_time``, ``duration_seconds`` and
    ``focus_score``.
    """
    async with aiosqlite.connect(db_path) as conn:
        lookup = await conn.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,)
        )
        record = await lookup.fetchone()
        if record is None:
            return None

        started_raw, frame_total, frame_focused = record
        started = datetime.fromisoformat(started_raw)
        finished = datetime.now()
        elapsed = (finished - started).total_seconds()
        if frame_total > 0:
            ratio = frame_focused / frame_total
        else:
            ratio = 0.0

        finished_iso = finished.isoformat()
        whole_seconds = int(elapsed)
        await conn.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (finished_iso, whole_seconds, ratio, session_id))
        await conn.commit()

        return dict(
            session_id=session_id,
            start_time=started_raw,
            end_time=finished_iso,
            duration_seconds=whole_seconds,
            focus_score=round(ratio, 3),
            total_frames=frame_total,
            focused_frames=frame_focused,
        )
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Persist one inference result and bump the session's frame counters.

    Inserts a ``focus_events`` row timestamped with the current local time,
    with ``metadata`` stored as JSON. It then increments
    ``focus_sessions.total_frames`` and, for a focused frame,
    ``focused_frames``. Both writes are committed together.

    Args:
        session_id: Row id of the owning focus session.
        is_focused: Whether the frame was classified as focused; coerced to
            ``bool`` so numpy booleans bind correctly.
        confidence: Model probability; coerced to ``float``.
        metadata: JSON-serialisable extra data stored in ``detection_data``.
    """
    focused = bool(is_focused)
    async with aiosqlite.connect(db_path) as db:
        await db.execute(
            """
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
            """,
            (session_id, datetime.now().isoformat(), focused, float(confidence), json.dumps(metadata)),
        )
        # Bind the increment as a parameter rather than formatting it into the SQL text.
        await db.execute(
            """
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
            """,
            (1 if focused else 0, session_id),
        )
        await db.commit()
# ================ STARTUP/SHUTDOWN ================
# Focus-detection pipeline; stays None until startup_event() has loaded it,
# and the frame handlers check for that.
mlp_pipeline = None


async def startup_event():
    """Initialise the database schema, then load the focus-detection pipeline."""
    global mlp_pipeline
    print(" Starting Focus Guard API...")

    await init_database()
    print("[OK] Database initialized")

    loaded = MLPPipeline()
    mlp_pipeline = loaded
    print("[OK] MLPPipeline loaded")
async def shutdown_event():
    """Log shutdown and close every WebRTC peer connection that is still open.

    The connections in ``pcs`` are closed concurrently. Failures are
    collected rather than raised, so one bad connection cannot block the
    others. The set is cleared afterwards.
    """
    print(" Shutting down Focus Guard API...")
    # Snapshot the set: closing a connection can trigger handlers that
    # discard it from ``pcs`` while we are still iterating.
    open_connections = list(pcs)
    if open_connections:
        await asyncio.gather(*(pc.close() for pc in open_connections), return_exceptions=True)
    pcs.clear()
# ================ WEBRTC SIGNALING ================
async def webrtc_offer(offer: dict):
    """Answer a browser's WebRTC offer and start a server-side focus session.

    A new RTCPeerConnection is created and tracked in ``pcs``, and a focus
    session is opened. Event handlers are then subscribed so that:

    * an incoming data channel becomes the target for detection messages;
    * every incoming video track is answered with a ``VideoTransformTrack``;
    * when the connection state becomes ``failed``, ``closed`` or
      ``disconnected``, the session is ended and the connection is closed,
      exactly once.

    Args:
        offer: Mapping holding the remote description's ``"sdp"`` and ``"type"``.

    Returns:
        ``{"sdp": ..., "type": ..., "session_id": ...}`` for the local answer.

    Raises:
        HTTPException: 500 if any negotiation step fails. Before raising, the
            half-built peer connection is closed and dropped from ``pcs``.
    """
    pc = None
    try:
        print("Received WebRTC offer")
        pc = RTCPeerConnection()
        pcs.add(pc)
        session_id = await create_session()
        print(f"Created session: {session_id}")

        channel_ref = {"channel": None}
        cleaned_up = False

        def on_datachannel(channel):
            print("Data channel opened")
            channel_ref["channel"] = channel

        def on_track(track):
            print(f"Received track: {track.kind}")
            if track.kind == "video":
                # The lambda reads channel_ref lazily, so a channel opened
                # after the track arrives is still used.
                local_track = VideoTransformTrack(track, session_id, lambda: channel_ref["channel"])
                pc.addTrack(local_track)
                print("Video track added")

            async def on_ended():
                print("Track ended")

            track.on("ended", on_ended)

        async def on_connectionstatechange():
            nonlocal cleaned_up
            print(f"Connection state changed: {pc.connectionState}")
            if pc.connectionState not in ("failed", "closed", "disconnected") or cleaned_up:
                return
            # Several terminal states can fire in sequence (e.g. disconnected,
            # then closed); end the session and close the connection only once.
            cleaned_up = True
            try:
                await end_session(session_id)
            except Exception as e:
                print(f"⚠Error ending session: {e}")
            pcs.discard(pc)
            await pc.close()

        # Subscribe the handlers: defining the functions alone registers nothing.
        pc.on("datachannel", on_datachannel)
        pc.on("track", on_track)
        pc.on("connectionstatechange", on_connectionstatechange)

        await pc.setRemoteDescription(RTCSessionDescription(sdp=offer["sdp"], type=offer["type"]))
        print("Remote description set")
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)
        print("Answer created")
        await _wait_for_ice_gathering(pc)
        print("ICE gathering complete")
        return {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type, "session_id": session_id}
    except Exception as e:
        print(f"WebRTC offer error: {e}")
        import traceback
        traceback.print_exc()
        # Don't leak a half-negotiated connection in ``pcs``.
        if pc is not None:
            pcs.discard(pc)
            try:
                await pc.close()
            except Exception as close_err:
                print(f"Error closing peer connection: {close_err}")
        raise HTTPException(status_code=500, detail=f"WebRTC error: {str(e)}") from e
# ================ WEBSOCKET ================
async def websocket_endpoint(websocket: WebSocket):
    """Run frame-by-frame focus detection over a WebSocket.

    Client messages are JSON objects with a ``type`` field:

    * ``"frame"``: ``image`` holds a base64-encoded image. A frame that
      arrives sooner than ``min_inference_interval`` after the last
      inference is answered with ``{"type": "ack"}``. Otherwise the frame is
      decoded, resized to 640x480 and scored, then answered with
      ``{"type": "detection"}``. The result is stored while a session is
      active. Processing errors are reported with ``{"type": "error"}``.
    * ``"start_session"``: ends any session this socket already has open,
      opens a new one and replies ``session_started``.
    * ``"end_session"``: ends the active session and replies
      ``session_ended`` with its summary.

    Messages without a recognised ``type`` are ignored. Any session still
    open when the loop exits, whether by disconnect or by error, is ended.
    """
    from time import time

    await websocket.accept()
    session_id = None
    frame_count = 0
    last_inference_time = 0
    min_inference_interval = 1 / 60  # seconds between inferences (~60 Hz cap)

    try:
        while True:
            data = await websocket.receive_json()
            # Tolerate malformed messages instead of tearing down the socket.
            msg_type = data.get('type') if isinstance(data, dict) else None

            if msg_type == 'frame':
                current_time = time()
                if current_time - last_inference_time < min_inference_interval:
                    # Too soon since the last inference: acknowledge only.
                    await websocket.send_json({'type': 'ack', 'frame_count': frame_count})
                    continue
                last_inference_time = current_time
                try:
                    img_data = base64.b64decode(data['image'])
                    nparr = np.frombuffer(img_data, np.uint8)
                    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                    if frame is None:
                        continue  # undecodable image: dropped without a reply
                    frame = cv2.resize(frame, (640, 480))
                    if mlp_pipeline is not None:
                        out = mlp_pipeline.process_frame(frame)
                        # Builtin types keep JSON encoding and SQLite binding
                        # working if the pipeline returns numpy scalars.
                        is_focused = bool(out["is_focused"])
                        confidence = float(out["mlp_prob"])
                        metadata = {
                            "s_face": out["s_face"],
                            "s_eye": out["s_eye"],
                            "mar": out["mar"],
                        }
                    else:
                        is_focused = False
                        confidence = 0.0
                        metadata = {}
                    detections = []
                    if session_id:
                        await store_focus_event(session_id, is_focused, confidence, metadata)
                    await websocket.send_json({
                        'type': 'detection',
                        'focused': is_focused,
                        'confidence': round(confidence, 3),
                        'detections': detections,
                        'frame_count': frame_count,
                    })
                    frame_count += 1
                except Exception as e:
                    print(f"Error processing frame: {e}")
                    await websocket.send_json({'type': 'error', 'message': str(e)})

            elif msg_type == 'start_session':
                if session_id:
                    # Close the session this socket already had open; otherwise
                    # it would never receive an end_time.
                    await end_session(session_id)
                session_id = await create_session()
                await websocket.send_json({'type': 'session_started', 'session_id': session_id})

            elif msg_type == 'end_session':
                if session_id:
                    print(f"Ending session {session_id}...")
                    summary = await end_session(session_id)
                    print(f"Session summary: {summary}")
                    if summary:
                        await websocket.send_json({'type': 'session_ended', 'summary': summary})
                        print("Session ended message sent")
                    else:
                        print("Warning: No summary returned")
                    session_id = None
                else:
                    print("Warning: end_session called but no active session_id")
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
        if websocket.client_state.value == 1:  # WebSocketState.CONNECTED
            await websocket.close()
    finally:
        # End whatever session is still open, however the loop exited.
        if session_id:
            try:
                await end_session(session_id)
            except Exception as e:
                print(f"Error ending session {session_id}: {e}")
# ================ API ENDPOINTS ================
async def api_start_session():
    """Open a new focus session and return ``{"session_id": <new id>}``."""
    return {"session_id": await create_session()}
async def api_end_session(data: SessionEnd):
    """End the session named in ``data`` and return its summary.

    Raises:
        HTTPException: 404 when no session with that id exists.
    """
    summary = await end_session(data.session_id)
    if summary:
        return summary
    raise HTTPException(status_code=404, detail="Session not found")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List focus sessions, newest first, optionally restricted by time window.

    Args:
        filter: One of the following. Any other value applies no WHERE
            clause at all.

            * ``"today"``: sessions started since local midnight.
            * ``"week"``: sessions started in the last 7 days.
            * ``"month"``: sessions started in the last 30 days.
            * ``"all"``: only sessions that have an ``end_time``.

            The three time-window filters include sessions that are still
            open.
        limit: Page size. ``-1`` disables paging entirely (used for
            exports), in which case ``offset`` is ignored.
        offset: Number of rows to skip when paging.

    Returns:
        A list of ``focus_sessions`` rows as plain dicts.
    """
    sql = "SELECT * FROM focus_sessions"
    args = []

    if filter in ("today", "week", "month"):
        current = datetime.now()
        if filter == "today":
            cutoff = current.replace(hour=0, minute=0, second=0, microsecond=0)
        else:
            cutoff = current - timedelta(days=7 if filter == "week" else 30)
        sql += " WHERE start_time >= ?"
        args.append(cutoff.isoformat())
    elif filter == "all":
        sql += " WHERE end_time IS NOT NULL"

    sql += " ORDER BY start_time DESC"

    # limit == -1 means "export everything": no paging clause at all.
    if limit != -1:
        sql += " LIMIT ? OFFSET ?"
        args += [limit, offset]

    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        result_cursor = await conn.execute(sql, tuple(args))
        return list(map(dict, await result_cursor.fetchall()))
# --- Import endpoint ---
async def import_sessions(sessions: List[dict]):
    """Bulk-insert previously exported session records.

    Missing numeric fields default to 0 / 0.0. A missing ``created_at``
    falls back to ``start_time``. The commit happens only after every record
    has been inserted. Failures are reported in the payload instead of being
    raised.

    Returns:
        Either ``{"status": "success", "count": n}`` or
        ``{"status": "error", "message": "..."}``.
    """
    inserted = 0
    try:
        async with aiosqlite.connect(db_path) as conn:
            for record in sessions:
                # .get() tolerates fields missing from older or hand-edited exports.
                values = (
                    record.get('start_time'),
                    record.get('end_time'),
                    record.get('duration_seconds', 0),
                    record.get('focus_score', 0.0),
                    record.get('total_frames', 0),
                    record.get('focused_frames', 0),
                    record.get('created_at', record.get('start_time')),
                )
                await conn.execute("""
                    INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                """, values)
                inserted += 1
            await conn.commit()
            return {"status": "success", "count": inserted}
    except Exception as exc:
        print(f"Import Error: {exc}")
        return {"status": "error", "message": str(exc)}
# --- Clear-history endpoint ---
async def clear_history():
    """Delete every focus event and every focus session.

    Returns a status payload rather than raising on failure.
    """
    try:
        async with aiosqlite.connect(db_path) as conn:
            # Child rows (focus_events) are removed before their parent sessions.
            for statement in ("DELETE FROM focus_events", "DELETE FROM focus_sessions"):
                await conn.execute(statement)
            await conn.commit()
            return {"status": "success", "message": "History cleared"}
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def get_session(session_id: int):
    """Return one session row plus its events, ordered by timestamp, under ``"events"``.

    Raises:
        HTTPException: 404 if no session with ``session_id`` exists.
    """
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        lookup = await conn.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
        found = await lookup.fetchone()
        if found is None:
            raise HTTPException(status_code=404, detail="Session not found")

        result = dict(found)
        event_cursor = await conn.execute(
            "SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp",
            (session_id,),
        )
        result['events'] = list(map(dict, await event_cursor.fetchall()))
        return result
async def get_settings():
    """Return the stored user settings, or built-in defaults when the row is missing."""
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        lookup = await conn.execute("SELECT * FROM user_settings WHERE id = 1")
        stored = await lookup.fetchone()
        if stored is None:
            # Same values as the schema's seed row (minus the id column).
            return {'sensitivity': 6, 'notification_enabled': True, 'notification_threshold': 30, 'frame_rate': 30, 'model_name': 'yolov8n.pt'}
        return dict(stored)
async def update_settings(settings: SettingsUpdate):
    """Apply every non-``None`` field of ``settings`` to the settings row.

    The settings row (id 1) is created first if it is missing. Values are
    clamped before storing: sensitivity to 1..10, notification_threshold to
    5..300 and frame_rate to 5..60. notification_enabled is stored as
    given.

    Returns:
        ``{"status": "success", "updated": bool}``, where ``updated`` is
        True when at least one field was supplied.
    """
    def clamp(value, low, high):
        return max(low, min(high, value))

    async with aiosqlite.connect(db_path) as conn:
        probe = await conn.execute("SELECT id FROM user_settings WHERE id = 1")
        if not await probe.fetchone():
            await conn.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
            await conn.commit()

        # (column, supplied value, normaliser), kept in the fixed column order.
        candidates = [
            ("sensitivity", settings.sensitivity, lambda v: clamp(v, 1, 10)),
            ("notification_enabled", settings.notification_enabled, lambda v: v),
            ("notification_threshold", settings.notification_threshold, lambda v: clamp(v, 5, 300)),
            ("frame_rate", settings.frame_rate, lambda v: clamp(v, 5, 60)),
        ]
        changes = [
            (column, normalise(value))
            for column, value, normalise in candidates
            if value is not None
        ]

        if changes:
            # Column names come from the fixed list above, never from user input.
            set_clause = ", ".join(f"{column} = ?" for column, _ in changes)
            await conn.execute(
                f"UPDATE user_settings SET {set_clause} WHERE id = 1",
                [value for _, value in changes],
            )
            await conn.commit()
    return {"status": "success", "updated": bool(changes)}
| async def get_stats_summary(): | |
| async with aiosqlite.connect(db_path) as db: | |
| cursor = await db.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL") | |
| total_sessions = (await cursor.fetchone())[0] | |
| cursor = await db.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL") | |
| total_focus_time = (await cursor.fetchone())[0] or 0 | |
| cursor = await db.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL") | |
| avg_focus_score = (await cursor.fetchone())[0] or 0.0 | |
| cursor = await db.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC") | |
| dates = [row[0] for row in await cursor.fetchall()] | |
| streak_days = 0 | |
| if dates: | |
| current_date = datetime.now().date() | |
| for i, date_str in enumerate(dates): | |
| session_date = datetime.fromisoformat(date_str).date() | |
| expected_date = current_date - timedelta(days=i) | |
| if session_date == expected_date: streak_days += 1 | |
| else: break | |
| return { | |
| 'total_sessions': total_sessions, | |
| 'total_focus_time': int(total_focus_time), | |
| 'avg_focus_score': round(avg_focus_score, 3), | |
| 'streak_days': streak_days | |
| } | |
async def health_check():
    """Report liveness, whether the ML pipeline is loaded, and whether the DB file exists."""
    pipeline_ready = mlp_pipeline is not None
    db_present = os.path.exists(db_path)
    return {"status": "healthy", "model_loaded": pipeline_ready, "database": db_present}
# ================ STATIC FILES (SPA SUPPORT) ================
# Prefer a production build in ./dist (when it contains index.html);
# otherwise fall back to ./static.
if os.path.exists("dist/index.html"):
    FRONTEND_DIR = "dist"
else:
    FRONTEND_DIR = "static"

# Serve the build's assets folder directly at /assets when it exists.
assets_path = os.path.join(FRONTEND_DIR, "assets")
if os.path.exists(assets_path):
    app.mount("/assets", StaticFiles(directory=assets_path), name="assets")
async def serve_react_app(full_path: str, request: Request):
    """Serve a file from the frontend build, falling back to ``index.html``.

    Paths starting with ``api`` or ``ws`` get a 404, so unknown API routes are
    not answered with the SPA shell. An existing file inside
    ``FRONTEND_DIR`` is returned as-is. Everything else receives
    ``index.html``, presumably so client-side routing can handle it. If no
    ``index.html`` exists either, a JSON hint to build the frontend is
    returned.

    Args:
        full_path: Request path relative to the site root. It is untrusted
            and may contain ``..`` segments or look absolute.
        request: The incoming request (not used).

    Raises:
        HTTPException: 404 for ``api``/``ws``-prefixed paths.
    """
    if full_path.startswith(("api", "ws")):
        raise HTTPException(status_code=404, detail="Not Found")

    # Resolve the requested file and refuse anything that escapes the
    # frontend directory: "../" segments, absolute paths, or symlinks
    # pointing outside. A plain os.path.join would happily serve e.g.
    # "../focus_guard.db".
    root = os.path.realpath(FRONTEND_DIR)
    candidate = os.path.realpath(os.path.join(root, full_path))
    try:
        inside_root = os.path.commonpath([root, candidate]) == root
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_root = False
    if inside_root and os.path.isfile(candidate):
        return FileResponse(candidate)

    index_path = os.path.join(FRONTEND_DIR, "index.html")
    if os.path.exists(index_path):
        return FileResponse(index_path)
    return {"message": "React app not found. Please run npm run build."}