Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List | |
| import base64 | |
| import cv2 | |
| import numpy as np | |
| import aiosqlite | |
| import json | |
| from datetime import datetime, timedelta | |
| import math | |
| import os | |
| from pathlib import Path | |
# ---- Application instance ------------------------------------------------------
app = FastAPI(title="Focus Guard API")

# CORS: accept every origin, method and header, and allow credentialed requests.
# NOTE(review): a wildcard origin together with allow_credentials=True is very
# permissive; confirm this is intended beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# ---- Module-level state --------------------------------------------------------
# YOLO detector; assigned by load_yolo_model(), stays None if loading fails.
model = None
# SQLite database file, relative to the process working directory.
db_path = "focus_guard.db"
# ================ DATABASE MODELS ================
async def init_database():
    """
    Create the SQLite schema if it is missing and seed the singleton settings row.

    Tables:
      * focus_sessions - one row per session: start/end timestamps, duration,
        focus score and frame counters.
      * focus_events   - one row per processed frame, referencing a session.
      * user_settings  - a single row (CHECK id = 1) of user preferences.

    Also creates an index on focus_events(session_id, timestamp), matching the
    per-session, timestamp-ordered event lookup. Every statement is idempotent
    (IF NOT EXISTS / INSERT OR IGNORE), so this can run on every startup.
    """
    async with aiosqlite.connect(db_path) as db:
        # FocusSessions table
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                start_time TIMESTAMP NOT NULL,
                end_time TIMESTAMP,
                duration_seconds INTEGER DEFAULT 0,
                focus_score REAL DEFAULT 0.0,
                total_frames INTEGER DEFAULT 0,
                focused_frames INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # FocusEvents table
        await db.execute("""
            CREATE TABLE IF NOT EXISTS focus_events (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id INTEGER NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                is_focused BOOLEAN NOT NULL,
                confidence REAL NOT NULL,
                detection_data TEXT,
                FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
            )
        """)
        # Events are written for every processed frame and read back per session
        # ordered by timestamp; without this index that read is a full table scan.
        await db.execute("""
            CREATE INDEX IF NOT EXISTS idx_focus_events_session_ts
            ON focus_events (session_id, timestamp)
        """)
        # UserSettings table
        await db.execute("""
            CREATE TABLE IF NOT EXISTS user_settings (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                sensitivity INTEGER DEFAULT 6,
                notification_enabled BOOLEAN DEFAULT 1,
                notification_threshold INTEGER DEFAULT 30,
                frame_rate INTEGER DEFAULT 30,
                model_name TEXT DEFAULT 'yolov8n.pt'
            )
        """)
        # Insert default settings if not exists
        await db.execute("""
            INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
            VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
        """)
        await db.commit()
# ================ PYDANTIC MODELS ================
# Request body for starting a session; it carries no fields.
class SessionCreate(BaseModel):
    ...


# Request body naming the session to end.
class SessionEnd(BaseModel):
    session_id: int


# Partial settings update: a field left as None is not written.
class SettingsUpdate(BaseModel):
    sensitivity: Optional[int] = None
    notification_enabled: Optional[bool] = None
    notification_threshold: Optional[int] = None
    frame_rate: Optional[int] = None
# ================ YOLO MODEL LOADING ================
def load_yolo_model() -> bool:
    """
    Load the YOLOv8 detector into the module-level ``model`` and warm it up.

    Steps:
      1. Apply PyTorch 2.6+ weights-only loading workarounds (an environment
         variable plus, when available, ``add_safe_globals``).
      2. Load ``models/yolov8n.pt``. If that file is absent, fall back to the
         bare name ``yolov8n.pt``, which per the log message is downloaded.
      3. Try to fuse Conv2d + BatchNorm layers; a failure is logged and skipped.
      4. Run one dummy inference with the same arguments used for live frames.

    Returns:
        True on success. On any exception the error and traceback are printed
        and False is returned, so the app keeps running without detection.
        If ``YOLO(...)`` itself succeeded but a later step (e.g. warm-up)
        raised, the global ``model`` remains assigned even though False is
        returned.
    """
    global model
    try:
        # Fix PyTorch 2.6+ weights_only issue.
        # Set an environment variable intended to allow loading YOLO weights.
        # NOTE(review): PyTorch documents TORCH_FORCE_WEIGHTS_ONLY_LOAD /
        # TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD as its torch.load overrides;
        # TORCH_LOAD_WEIGHTS_ONLY may not be read at all - verify. Note this
        # also mutates os.environ for the whole process.
        os.environ['TORCH_LOAD_WEIGHTS_ONLY'] = '0'
        import torch
        if hasattr(torch.serialization, 'add_safe_globals'):
            # PyTorch 2.6+ compatibility: allow-list the classes the checkpoint
            # references so a weights-only load may unpickle them.
            try:
                from ultralytics.nn.tasks import DetectionModel
                import torch.nn as nn
                torch.serialization.add_safe_globals([
                    DetectionModel,
                    nn.modules.container.Sequential,
                ])
            except Exception as e:
                # Best effort: a failure here is reported but loading continues.
                print(f" Safe globals setup: {e}")
        from ultralytics import YOLO
        model_path = "models/yolov8n.pt"
        # Check if model file exists, if not use yolov8n (will download)
        if not os.path.exists(model_path):
            print(f"Model file {model_path} not found, downloading yolov8n.pt...")
            model_path = "yolov8n.pt"  # This will trigger auto-download
        # Load model (ultralytics handles weights_only internally in newer versions)
        model = YOLO(model_path)
        # Optimize for CPU; fusion is optional, so errors are logged and ignored.
        try:
            model.fuse()  # Fuse Conv2d + BatchNorm layers
            print("[OK] Model layers fused for optimization")
        except Exception as e:
            print(f" Model fusion skipped: {e}")
        # Warm up with one dummy inference using the same arguments as the
        # live-frame call in the WebSocket handler (416 px, person class only).
        print("Warming up model...")
        dummy_img = np.zeros((416, 416, 3), dtype=np.uint8)
        model(dummy_img, imgsz=416, conf=0.4, iou=0.45, max_det=5, classes=[0], verbose=False)
        print("[OK] YOLOv8 model loaded and warmed up successfully")
        return True
    except Exception as e:
        print(f"[ERROR] Failed to load YOLOv8 model: {e}")
        print(" The app will run without detection features")
        import traceback
        traceback.print_exc()
        return False
# ================ FOCUS DETECTION ALGORITHM ================
def is_user_focused(detections, frame_shape, sensitivity=6):
    """
    Decide whether the user is focused from a list of YOLO detections.

    The highest-confidence person detection (COCO class 0) must reach a
    confidence threshold derived from ``sensitivity``, and its bounding-box
    centre must lie inside a loose central region of the frame.

    Threshold: ``0.8 + (sensitivity - 6) * 0.02``, clamped to [0.5, 0.95].
    Over the usual 1-10 range this is 0.70 at sensitivity 1, 0.80 at the
    default 6 and 0.88 at 10. The clamp only matters for out-of-range values.

    Args:
        detections: List of detection dicts with ``'class'``, ``'confidence'``
            and ``'bbox'`` (``[x1, y1, x2, y2]`` in pixels), as produced by
            ``parse_yolo_results``. A missing ``'confidence'`` counts as 0.
        frame_shape: Frame shape; only ``frame_shape[0]`` (height) and
            ``frame_shape[1]`` (width) are read.
        sensitivity: Higher values require a higher detection confidence.

    Returns:
        Tuple ``(is_focused, score, metadata)``.
          * ``is_focused`` is True only for a confident, in-frame detection.
          * ``score`` is the detection confidence, scaled by 0.7 when the
            person is off-centre and by 0.9 when several people are detected.
          * ``metadata`` describes the decision. With no person detected it is
            ``{'reason': 'no_person', 'count': 0}``.
    """
    # Filter person detections (class 0 in COCO dataset)
    persons = [d for d in detections if d.get('class') == 0]
    if not persons:
        return False, 0.0, {'reason': 'no_person', 'count': 0}

    # Pick the most confident person. Read the confidence with the same
    # default used for ranking so a detection lacking the key cannot raise.
    best_person = max(persons, key=lambda d: d.get('confidence', 0))
    bbox = best_person['bbox']  # [x1, y1, x2, y2]
    conf = best_person.get('confidence', 0.0)

    # Sensitivity shifts the threshold by 0.02 per step around the default (6).
    base_threshold = 0.8
    sensitivity_adjustment = (sensitivity - 6) * 0.02
    confidence_threshold = base_threshold + sensitivity_adjustment
    confidence_threshold = max(0.5, min(0.95, confidence_threshold))  # Clamp to 0.5-0.95

    is_focused = conf >= confidence_threshold

    # Position check: the bbox centre, normalised to 0-1, must fall in a loose
    # central window (20%-80% horizontally, 15%-85% vertically).
    h, w = frame_shape[0], frame_shape[1]
    bbox_center_x = (bbox[0] + bbox[2]) / 2
    bbox_center_y = (bbox[1] + bbox[3]) / 2
    center_x_norm = bbox_center_x / w if w > 0 else 0.5
    center_y_norm = bbox_center_y / h if h > 0 else 0.5
    in_frame = (0.2 <= center_x_norm <= 0.8) and (0.15 <= center_y_norm <= 0.85)

    # Penalise the score for an edge position and for multiple people.
    position_factor = 1.0 if in_frame else 0.7
    final_score = conf * position_factor
    if len(persons) > 1:
        final_score *= 0.9
        reason = f"person_detected_multi_{len(persons)}"
    else:
        reason = "person_detected" if is_focused else "low_confidence"

    metadata = {
        'bbox': bbox,
        'detection_confidence': round(conf, 3),
        'confidence_threshold': round(confidence_threshold, 3),
        'center_position': [round(center_x_norm, 3), round(center_y_norm, 3)],
        'in_frame': in_frame,
        'person_count': len(persons),
        'reason': reason
    }
    return is_focused and in_frame, final_score, metadata
def parse_yolo_results(results):
    """
    Flatten the first YOLOv8 result into a list of plain detection dicts.

    Only ``results[0]`` is inspected. Each box becomes a dict with:
      * ``'bbox'``       - four floats in xyxy order,
      * ``'confidence'`` - float,
      * ``'class'``      - int class id,
      * ``'class_name'`` - ``result.names[class]`` when the result has a
        ``names`` attribute, otherwise the class id as a string.

    Returns an empty list when ``results`` is empty/falsy or the first result
    carries no boxes.
    """
    if not results or len(results) == 0:
        return []

    first = results[0]
    boxes = first.boxes
    if boxes is None or len(boxes) == 0:
        return []

    parsed = []
    for item in boxes:
        coords = item.xyxy[0].cpu().numpy()
        score = float(item.conf[0].cpu().numpy())
        class_id = int(item.cls[0].cpu().numpy())
        label = first.names[class_id] if hasattr(first, 'names') else str(class_id)
        parsed.append({
            'bbox': list(map(float, coords)),
            'confidence': score,
            'class': class_id,
            'class_name': label,
        })
    return parsed
# ================ DATABASE OPERATIONS ================
async def create_session():
    """Insert a focus_sessions row whose start_time is the current local time; return its row id."""
    async with aiosqlite.connect(db_path) as conn:
        started_at = datetime.now().isoformat()
        cur = await conn.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (started_at,),
        )
        await conn.commit()
        new_id = cur.lastrowid
        return new_id
async def end_session(session_id: int):
    """
    Close a focus session and return a summary of it.

    Reads the session's start time and frame counters, sets ``end_time`` to the
    current local time, and stores the whole-second duration plus
    ``focused_frames / total_frames`` as ``focus_score`` (0.0 when no frames
    were recorded). There is no "already ended" guard, so calling it again
    recomputes and overwrites these columns.

    Returns:
        A summary dict, or None when no session has ``session_id``.
    """
    async with aiosqlite.connect(db_path) as conn:
        cur = await conn.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,)
        )
        record = await cur.fetchone()
        if record is None:
            return None

        started_raw, frames_total, frames_focused = record
        started = datetime.fromisoformat(started_raw)
        finished = datetime.now()
        elapsed_seconds = int((finished - started).total_seconds())
        if frames_total > 0:
            ratio = frames_focused / frames_total
        else:
            ratio = 0.0
        finished_iso = finished.isoformat()

        await conn.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (finished_iso, elapsed_seconds, ratio, session_id))
        await conn.commit()

    return {
        'session_id': session_id,
        'start_time': started_raw,
        'end_time': finished_iso,
        'duration_seconds': elapsed_seconds,
        'focus_score': round(ratio, 3),
        'total_frames': frames_total,
        'focused_frames': frames_focused,
    }
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """
    Record one focus detection and bump the owning session's frame counters.

    Inserts a focus_events row (timestamp = current local time, ``metadata``
    serialised with ``json.dumps``). It then increments
    ``focus_sessions.total_frames``, and ``focused_frames`` when ``is_focused``
    is truthy. Both writes are committed together.
    """
    # Coerce to builtin types: sqlite3 cannot bind numpy scalars such as
    # numpy.bool_ or numpy.float32.
    focused_flag = bool(is_focused)
    focused_increment = 1 if focused_flag else 0
    async with aiosqlite.connect(db_path) as db:
        await db.execute("""
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
        """, (session_id, datetime.now().isoformat(), focused_flag, float(confidence), json.dumps(metadata)))
        # Bind the increment as a parameter instead of formatting it into the
        # SQL text: the statement stays constant and no SQL is string-built.
        await db.execute("""
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
        """, (focused_increment, session_id))
        await db.commit()
# ================ STARTUP/SHUTDOWN EVENTS ================
async def startup_event():
    """
    Prepare the service: ensure the database schema exists, then load the detector.

    The boolean returned by load_yolo_model() is deliberately ignored. That
    function logs its own failure, and the app keeps serving without detection.
    """
    print(" Starting Focus Guard API...")
    await init_database()
    print("[OK] Database initialized")
    # Runs synchronously on the event-loop thread while the app starts.
    load_yolo_model()
async def shutdown_event():
    """Log shutdown; nothing needs closing because each DB helper opens its own connection."""
    shutdown_banner = " Shutting down Focus Guard API..."
    print(shutdown_banner)
# ================ STATIC FILES ================
# Serve the ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")


async def read_index():
    """Return the front-end entry page, static/index.html."""
    index_file = "static/index.html"
    return FileResponse(index_file)
# ================ WEBSOCKET ENDPOINT ================
def _detect_focus(image_b64, sensitivity):
    """
    Decode a base64 image, run person detection and score focus.

    Returns None when the bytes cannot be decoded as an image. Otherwise
    returns ``(detections, is_focused, confidence, metadata)``. If the YOLO
    model is not loaded, detection is skipped and an empty list is used.
    """
    img_data = base64.b64decode(image_b64)
    frame = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
    if frame is None:
        return None
    # Resize for faster inference (fixed 640x480; aspect ratio is not preserved).
    frame = cv2.resize(frame, (640, 480))
    if model is not None:
        # NOTE(review): inference is synchronous and blocks the event loop for
        # its duration; consider offloading it (with a lock if the model is not
        # thread-safe) if several clients stream at once.
        results = model(
            frame,
            imgsz=416,
            conf=0.4,
            iou=0.45,
            max_det=5,
            classes=[0],  # Only person class
            verbose=False
        )
        detections = parse_yolo_results(results)
    else:
        # Fallback if model not loaded
        detections = []
    is_focused, confidence, metadata = is_user_focused(detections, frame.shape, sensitivity)
    return detections, is_focused, confidence, metadata


async def websocket_endpoint(websocket: WebSocket):
    """
    Receive frames from a client, detect the user, and report focus status.

    Client -> server JSON messages, dispatched on ``"type"``:
      * ``frame``: ``image`` holds a base64-encoded image. A frame arriving
        less than ``min_inference_interval`` seconds after the last processed
        one gets an ``ack``. Otherwise the reply is a ``detection``, or an
        ``error`` if processing raised. Undecodable images get no reply.
      * ``start_session``: ends any session this connection already had open,
        starts a new one and replies ``session_started``.
      * ``end_session``: ends the active session and replies
        ``session_ended`` with its summary.
    Messages with a missing or unknown ``type`` are ignored.

    A session still open when the connection finishes is ended, whether the
    client disconnected or an error occurred, so it is not left with a NULL
    end_time. The sensitivity setting is read once, when the connection opens.
    """
    from time import monotonic

    from fastapi.websockets import WebSocketState

    await websocket.accept()
    session_id = None
    frame_count = 0
    # monotonic() is immune to wall-clock jumps; -inf guarantees the first
    # frame is always processed.
    last_inference_time = float('-inf')
    min_inference_interval = 0.1  # Max 10 FPS server-side
    try:
        # Get user settings
        async with aiosqlite.connect(db_path) as db:
            cursor = await db.execute("SELECT sensitivity FROM user_settings WHERE id = 1")
            row = await cursor.fetchone()
            sensitivity = row[0] if row else 6

        while True:
            data = await websocket.receive_json()
            # .get() so a message without "type" is ignored instead of raising
            # KeyError and tearing down the whole connection.
            msg_type = data.get('type') if isinstance(data, dict) else None

            if msg_type == 'frame':
                current_time = monotonic()
                # Rate limiting: acknowledge without running inference.
                if current_time - last_inference_time < min_inference_interval:
                    await websocket.send_json({
                        'type': 'ack',
                        'frame_count': frame_count
                    })
                    continue
                last_inference_time = current_time
                try:
                    outcome = _detect_focus(data['image'], sensitivity)
                    if outcome is None:
                        continue
                    detections, is_focused, confidence, metadata = outcome
                    # Store event in database if session active
                    if session_id:
                        await store_focus_event(session_id, is_focused, confidence, metadata)
                    await websocket.send_json({
                        'type': 'detection',
                        'focused': is_focused,
                        'confidence': round(confidence, 3),
                        'detections': detections,
                        'frame_count': frame_count
                    })
                    frame_count += 1
                except WebSocketDisconnect:
                    # Let disconnects reach the outer handler instead of trying
                    # to send an error message on a closed socket.
                    raise
                except Exception as e:
                    print(f"Error processing frame: {e}")
                    await websocket.send_json({
                        'type': 'error',
                        'message': str(e)
                    })

            elif msg_type == 'start_session':
                if session_id:
                    # End the previous session first; otherwise it would be
                    # orphaned with end_time NULL.
                    await end_session(session_id)
                session_id = await create_session()
                await websocket.send_json({
                    'type': 'session_started',
                    'session_id': session_id
                })

            elif msg_type == 'end_session':
                if session_id:
                    summary = await end_session(session_id)
                    # Clear before sending so a failed send does not make the
                    # cleanup below end the same session a second time.
                    session_id = None
                    await websocket.send_json({
                        'type': 'session_ended',
                        'summary': summary
                    })

    except WebSocketDisconnect:
        print(f"WebSocket disconnected (session: {session_id})")
    except Exception as e:
        print(f"WebSocket error: {e}")
        if websocket.client_state == WebSocketState.CONNECTED:
            await websocket.close()
    finally:
        # Runs on disconnect, on error, and on cancellation alike.
        if session_id:
            try:
                await end_session(session_id)
            except Exception as e:
                print(f"Failed to end session {session_id}: {e}")
# ================ REST API ENDPOINTS ================
async def api_start_session():
    """Open a new focus session and return ``{"session_id": <new id>}``."""
    return {"session_id": await create_session()}
async def api_end_session(data: SessionEnd):
    """End the session named in the body and return its summary; respond 404 if it does not exist."""
    result = await end_session(data.session_id)
    if result:
        return result
    raise HTTPException(status_code=404, detail="Session not found")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """
    List focus sessions as dicts, newest ``start_time`` first, paged by LIMIT/OFFSET.

    ``filter`` selects the rows:
      * ``"today"`` - started at or after local midnight today,
      * ``"week"``  - started within the last 7 days,
      * ``"month"`` - started within the last 30 days,
      * anything else - every session whose ``end_time`` is set.
    The time-window variants do not require ``end_time`` to be set.

    NOTE(review): SQLite treats a negative LIMIT as "no limit"; ``limit`` is
    passed through unchecked.
    """
    window_starts = {
        "today": lambda: datetime.now().replace(hour=0, minute=0, second=0, microsecond=0),
        "week": lambda: datetime.now() - timedelta(days=7),
        "month": lambda: datetime.now() - timedelta(days=30),
    }
    compute_start = window_starts.get(filter)

    if compute_start is None:
        sql = "SELECT * FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY start_time DESC LIMIT ? OFFSET ?"
        args = (limit, offset)
    else:
        sql = "SELECT * FROM focus_sessions WHERE start_time >= ? ORDER BY start_time DESC LIMIT ? OFFSET ?"
        args = (compute_start().isoformat(), limit, offset)

    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute(sql, args)
        fetched = await cur.fetchall()
        return list(map(dict, fetched))
async def get_session(session_id: int):
    """
    Return one session row as a dict, with its events under ``'events'``.

    Events are that session's focus_events rows ordered by timestamp.
    Responds 404 when the session does not exist.
    """
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute("SELECT * FROM focus_sessions WHERE id = ?", (session_id,))
        header = await cur.fetchone()
        if header is None:
            raise HTTPException(status_code=404, detail="Session not found")

        details = dict(header)
        cur = await conn.execute(
            "SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp",
            (session_id,)
        )
        details['events'] = list(map(dict, await cur.fetchall()))
        return details
async def get_settings():
    """Return the singleton user_settings row (id = 1) as a dict, or hard-coded defaults if it is absent."""
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        cur = await conn.execute("SELECT * FROM user_settings WHERE id = 1")
        settings_row = await cur.fetchone()

    if settings_row is None:
        return {
            'sensitivity': 6,
            'notification_enabled': True,
            'notification_threshold': 30,
            'frame_rate': 30,
            'model_name': 'yolov8n.pt'
        }
    return dict(settings_row)
async def update_settings(settings: SettingsUpdate):
    """
    Apply a partial update to the singleton user_settings row.

    If the row is missing it is first created with default values. Each
    provided (non-None) field is then written after clamping:
      * sensitivity to 1-10,
      * notification_threshold to 5-300,
      * frame_rate to 5-60,
      * notification_enabled is written as given.

    Returns:
        ``{"status": "success", "updated": <True if any column was written>}``.
    """
    def clamp(low, high):
        return lambda value: max(low, min(high, value))

    # (model field, SET fragment, value transform), in a fixed order so the SET
    # clause and the parameter list stay aligned.
    rules = (
        ("sensitivity", "sensitivity = ?", clamp(1, 10)),
        ("notification_enabled", "notification_enabled = ?", lambda value: value),
        ("notification_threshold", "notification_threshold = ?", clamp(5, 300)),
        ("frame_rate", "frame_rate = ?", clamp(5, 60)),
    )

    async with aiosqlite.connect(db_path) as conn:
        cur = await conn.execute("SELECT id FROM user_settings WHERE id = 1")
        if await cur.fetchone() is None:
            await conn.execute("""
                INSERT INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
                VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
            """)
            await conn.commit()
            print("[OK] Created default user_settings record")

        assignments = []
        values = []
        for field_name, assignment, transform in rules:
            raw = getattr(settings, field_name)
            if raw is not None:
                assignments.append(assignment)
                values.append(transform(raw))

        if assignments:
            statement = f"UPDATE user_settings SET {', '.join(assignments)} WHERE id = 1"
            await conn.execute(statement, values)
            await conn.commit()
            print(f"[OK] Settings updated: {settings.model_dump(exclude_none=True)}")

    return {"status": "success", "updated": bool(assignments)}
async def get_stats_summary():
    """
    Summarise completed sessions, i.e. those with ``end_time`` set.

    Returns a dict with:
      * ``total_sessions``   - count of completed sessions,
      * ``total_focus_time`` - summed ``duration_seconds``,
      * ``avg_focus_score``  - mean ``focus_score`` rounded to 3 places,
      * ``streak_days``      - consecutive calendar days, counting back from
        today, whose date part of ``start_time`` has a completed session.
        The streak is 0 if today has none yet.
    """
    async with aiosqlite.connect(db_path) as conn:
        cur = await conn.execute("SELECT COUNT(*) FROM focus_sessions WHERE end_time IS NOT NULL")
        (session_total,) = await cur.fetchone()

        cur = await conn.execute("SELECT SUM(duration_seconds) FROM focus_sessions WHERE end_time IS NOT NULL")
        (seconds_total,) = await cur.fetchone()

        cur = await conn.execute("SELECT AVG(focus_score) FROM focus_sessions WHERE end_time IS NOT NULL")
        (mean_score,) = await cur.fetchone()

        cur = await conn.execute("""
            SELECT DISTINCT DATE(start_time) as session_date
            FROM focus_sessions
            WHERE end_time IS NOT NULL
            ORDER BY session_date DESC
        """)
        day_strings = [record[0] for record in await cur.fetchall()]

    # Walk the distinct dates newest-first; each must be exactly one day
    # earlier than the previous, starting from today.
    streak = 0
    today = datetime.now().date()
    for days_back, day_str in enumerate(day_strings):
        if datetime.fromisoformat(day_str).date() != today - timedelta(days=days_back):
            break
        streak += 1

    return {
        'total_sessions': session_total,
        'total_focus_time': int(seconds_total or 0),
        'avg_focus_score': round(mean_score or 0.0, 3),
        'streak_days': streak
    }
# ================ HEALTH CHECK ================
async def health_check():
    """Report liveness, whether the YOLO model is loaded, and whether the database file exists."""
    model_ready = model is not None
    db_present = os.path.exists(db_path)
    return {"status": "healthy", "model_loaded": model_ready, "database": db_present}