Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Any | |
| import base64 | |
| import cv2 | |
| import numpy as np | |
| import aiosqlite | |
| import json | |
| from datetime import datetime, timedelta | |
| import math | |
| import os | |
| from pathlib import Path | |
# Application object; the middleware below is attached to this instance.
app = FastAPI(title="Focus Guard API")

# CORS is configured wide open: every origin, method and header is accepted
# and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Module-level state shared by the functions in this file.
db_path = "focus_guard.db"  # SQLite file opened by every database helper
model = None  # replaced by load_yolo_model(); stays None until a model loads
| # ================ DATABASE MODELS ================ | |
async def init_database():
    """Create the SQLite schema and seed the default settings row.

    Opens ``db_path``, creates the ``focus_sessions``, ``focus_events`` and
    ``user_settings`` tables when they do not exist yet, inserts the single
    settings row (id 1) unless one is already present, and commits.
    """
    schema_statements = (
        # FocusSessions table
        """
        CREATE TABLE IF NOT EXISTS focus_sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            start_time TIMESTAMP NOT NULL,
            end_time TIMESTAMP,
            duration_seconds INTEGER DEFAULT 0,
            focus_score REAL DEFAULT 0.0,
            total_frames INTEGER DEFAULT 0,
            focused_frames INTEGER DEFAULT 0,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # FocusEvents table
        """
        CREATE TABLE IF NOT EXISTS focus_events (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id INTEGER NOT NULL,
            timestamp TIMESTAMP NOT NULL,
            is_focused BOOLEAN NOT NULL,
            confidence REAL NOT NULL,
            detection_data TEXT,
            FOREIGN KEY (session_id) REFERENCES focus_sessions (id)
        )
        """,
        # UserSettings table; the CHECK constraint only admits id 1.
        """
        CREATE TABLE IF NOT EXISTS user_settings (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            sensitivity INTEGER DEFAULT 6,
            notification_enabled BOOLEAN DEFAULT 1,
            notification_threshold INTEGER DEFAULT 30,
            frame_rate INTEGER DEFAULT 30,
            model_name TEXT DEFAULT 'yolov8n.pt'
        )
        """,
        # Insert default settings if not exists
        """
        INSERT OR IGNORE INTO user_settings (id, sensitivity, notification_enabled, notification_threshold, frame_rate, model_name)
        VALUES (1, 6, 1, 30, 30, 'yolov8n.pt')
        """,
    )

    async with aiosqlite.connect(db_path) as conn:
        for statement in schema_statements:
            await conn.execute(statement)
        await conn.commit()
| # ================ PYDANTIC MODELS ================ | |
class SessionCreate(BaseModel):
    """Request model for creating a session; it declares no fields."""
class SessionEnd(BaseModel):
    """Request model naming the session that should be ended."""

    # Id of the focus_sessions row to close.
    session_id: int
class SettingsUpdate(BaseModel):
    """Partial settings update; a field left as None is not changed.

    update_settings() clamps the integer fields to its own bounds before
    writing them, so this model does not validate ranges itself.
    """

    # Detection sensitivity; clamped to 1..10 by update_settings().
    sensitivity: Optional[int] = None
    # Whether notifications are switched on.
    notification_enabled: Optional[bool] = None
    # Clamped to 5..300 by update_settings().
    notification_threshold: Optional[int] = None
    # Clamped to 5..60 by update_settings().
    frame_rate: Optional[int] = None
| # ================ YOLO MODEL LOADING ================ | |
def load_yolo_model():
    """Load the YOLOv8 weights into the module-level ``model`` and warm it up.

    Uses ``models/yolov8n.pt`` when that file exists and otherwise passes the
    bare name ``yolov8n.pt`` to ``YOLO``. Layer fusion is attempted but a
    failure there is only reported. One inference on a blank 416x416 image is
    run as a warm-up.

    Returns:
        True when loading and warm-up completed, False when any step raised
        (the error and its traceback are printed).
    """
    global model
    try:
        # Fix PyTorch 2.6+ weights_only issue
        # NOTE(review): confirm PyTorch actually reads this variable name.
        os.environ['TORCH_LOAD_WEIGHTS_ONLY'] = '0'
        import torch

        # Allow-list the classes needed to unpickle the checkpoint, when the
        # installed torch offers the hook. Problems here are non-fatal.
        if hasattr(torch.serialization, 'add_safe_globals'):
            try:
                from ultralytics.nn.tasks import DetectionModel
                from torch import nn
                trusted_classes = [
                    DetectionModel,
                    nn.modules.container.Sequential,
                ]
                torch.serialization.add_safe_globals(trusted_classes)
            except Exception as e:
                print(f" Safe globals setup: {e}")

        from ultralytics import YOLO

        weights = "models/yolov8n.pt"
        if not os.path.exists(weights):
            print(f"Model file {weights} not found, downloading yolov8n.pt...")
            weights = "yolov8n.pt"
        model = YOLO(weights)

        try:
            model.fuse()
            print("[OK] Model layers fused for optimization")
        except Exception as e:
            print(f" Model fusion skipped: {e}")

        # Warm up with the same inference options the frame handler uses.
        print("Warming up model...")
        blank_frame = np.zeros((416, 416, 3), dtype=np.uint8)
        model(blank_frame, imgsz=416, conf=0.4, iou=0.45, max_det=5, classes=[0], verbose=False)
        print("[OK] YOLOv8 model loaded and warmed up successfully")
    except Exception as e:
        print(f"[ERROR] Failed to load YOLOv8 model: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        return True
| # ================ FOCUS DETECTION ALGORITHM ================ | |
def is_user_focused(detections, frame_shape, sensitivity=6):
    """Decide from person detections whether the user counts as focused.

    Args:
        detections: dicts with ``class``, ``confidence`` and ``bbox``
            (``[x1, y1, x2, y2]``); only entries with class 0 are considered.
        frame_shape: sequence whose first two items are frame height and width.
        sensitivity: shifts the confidence threshold by 0.02 per step around
            0.8 at the default of 6; the threshold is clamped to 0.5..0.95.

    Returns:
        ``(focused, score, metadata)``. ``focused`` requires both that the
        highest-confidence person meets the threshold and that the centre of
        its box lies in the central region of the frame. ``score`` is that
        confidence, multiplied by 0.7 when the box centre is outside the
        region and by 0.9 when more than one person was detected. With no
        person the result is ``(False, 0.0, {'reason': 'no_person', 'count': 0})``.
    """
    people = [det for det in detections if det.get('class') == 0]
    if len(people) == 0:
        return False, 0.0, {'reason': 'no_person', 'count': 0}

    top = max(people, key=lambda det: det.get('confidence', 0))
    box = top['bbox']
    score = top['confidence']

    # Threshold: 0.8 at sensitivity 6, +/-0.02 per step, kept within 0.5..0.95.
    threshold = 0.8 + (sensitivity - 6) * 0.02
    threshold = max(0.5, min(0.95, threshold))
    confident = score >= threshold

    # Box centre as a fraction of the frame size (0.5 if a dimension is 0).
    height = frame_shape[0]
    width = frame_shape[1]
    mid_x = (box[0] + box[2]) / 2
    mid_y = (box[1] + box[3]) / 2
    rel_x = mid_x / width if width > 0 else 0.5
    rel_y = mid_y / height if height > 0 else 0.5
    centered = (0.2 <= rel_x <= 0.8) and (0.15 <= rel_y <= 0.85)

    weighted = score * (1.0 if centered else 0.7)

    head_count = len(people)
    if head_count > 1:
        # Several people in view: damp the score and record how many.
        weighted *= 0.9
        why = f"person_detected_multi_{head_count}"
    elif confident:
        why = "person_detected"
    else:
        why = "low_confidence"

    details = {
        'bbox': box,
        'detection_confidence': round(score, 3),
        'confidence_threshold': round(threshold, 3),
        'center_position': [round(rel_x, 3), round(rel_y, 3)],
        'in_frame': centered,
        'person_count': head_count,
        'reason': why,
    }
    return confident and centered, weighted, details
def parse_yolo_results(results):
    """Flatten the first entry of a YOLO result list into plain dicts.

    Only ``results[0]`` is read. Each box becomes a dict with ``bbox`` (four
    floats taken from ``box.xyxy[0]``), ``confidence``, ``class`` (int) and
    ``class_name`` (looked up in ``result.names`` when the result has that
    attribute, otherwise the class id as a string).

    Returns:
        A list of detection dicts; empty when there are no results or no boxes.
    """
    if not (results and len(results) > 0):
        return []

    first = results[0]
    boxes = first.boxes
    if boxes is None or len(boxes) == 0:
        return []

    has_names = hasattr(first, 'names')
    parsed = []
    for box in boxes:
        corners = box.xyxy[0].cpu().numpy()
        score = float(box.conf[0].cpu().numpy())
        class_id = int(box.cls[0].cpu().numpy())
        parsed.append({
            'bbox': [float(value) for value in corners],
            'confidence': score,
            'class': class_id,
            'class_name': first.names[class_id] if has_names else str(class_id),
        })
    return parsed
| # ================ DATABASE OPERATIONS ================ | |
async def create_session():
    """Insert a new focus session starting now and return its row id.

    The start time is stored as the ISO-8601 text of ``datetime.now()``.
    """
    started_at = datetime.now().isoformat()
    async with aiosqlite.connect(db_path) as conn:
        inserted = await conn.execute(
            "INSERT INTO focus_sessions (start_time) VALUES (?)",
            (started_at,),
        )
        await conn.commit()
        return inserted.lastrowid
async def end_session(session_id: int):
    """Close a session: store its end time, duration and focus score.

    The focus score is ``focused_frames / total_frames`` (0.0 when no frames
    were recorded). The duration is the whole number of seconds between the
    stored start time and now.

    Returns:
        A summary dict for the session, or None when no row has this id.
    """
    async with aiosqlite.connect(db_path) as conn:
        lookup = await conn.execute(
            "SELECT start_time, total_frames, focused_frames FROM focus_sessions WHERE id = ?",
            (session_id,),
        )
        record = await lookup.fetchone()
        if record is None:
            return None

        started_text, frames_seen, frames_focused = record
        started = datetime.fromisoformat(started_text)
        finished = datetime.now()
        elapsed_seconds = int((finished - started).total_seconds())
        finished_text = finished.isoformat()

        score = 0.0
        if frames_seen > 0:
            score = frames_focused / frames_seen

        await conn.execute("""
            UPDATE focus_sessions
            SET end_time = ?, duration_seconds = ?, focus_score = ?
            WHERE id = ?
        """, (finished_text, elapsed_seconds, score, session_id))
        await conn.commit()

    return {
        'session_id': session_id,
        'start_time': started_text,
        'end_time': finished_text,
        'duration_seconds': elapsed_seconds,
        'focus_score': round(score, 3),
        'total_frames': frames_seen,
        'focused_frames': frames_focused,
    }
async def store_focus_event(session_id: int, is_focused: bool, confidence: float, metadata: dict):
    """Record one detection result and bump the session's frame counters.

    Inserts a ``focus_events`` row stamped with the current time, then
    increments ``total_frames`` (always) and ``focused_frames`` (only when
    ``is_focused`` is true) on the owning session, and commits both together.

    Args:
        session_id: id of the focus_sessions row the event belongs to.
        is_focused: detection verdict for this frame.
        confidence: score stored with the event.
        metadata: JSON-serialisable details, stored as text in detection_data.
    """
    # Normalise to a plain bool so the value binds cleanly as 0/1.
    focused = bool(is_focused)
    async with aiosqlite.connect(db_path) as db:
        await db.execute(
            """
            INSERT INTO focus_events (session_id, timestamp, is_focused, confidence, detection_data)
            VALUES (?, ?, ?, ?, ?)
            """,
            (session_id, datetime.now().isoformat(), focused, confidence, json.dumps(metadata)),
        )
        # The increment is bound as a parameter instead of being formatted
        # into the SQL text, so the statement itself never varies.
        await db.execute(
            """
            UPDATE focus_sessions
            SET total_frames = total_frames + 1,
                focused_frames = focused_frames + ?
            WHERE id = ?
            """,
            (int(focused), session_id),
        )
        await db.commit()
| # ================ STARTUP/SHUTDOWN ================ | |
async def startup_event():
    """Prepare the database, then load the detection model.

    The result of load_yolo_model() is not checked here; when loading fails
    the module-level ``model`` simply stays unset.

    NOTE(review): no registration of this hook with ``app`` is visible in
    this view — confirm where it is wired up.
    """
    print(" Starting Focus Guard API...")

    await init_database()
    print("[OK] Database initialized")

    load_yolo_model()
async def shutdown_event():
    """Announce shutdown; no resources are released here."""
    print(" Shutting down Focus Guard API...")
| # ================ WEBSOCKET ================ | |
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client connection: score streamed frames and track a session.

    Incoming JSON messages are dispatched on their ``type`` key:

    * ``frame``: ``image`` holds a base64-encoded picture. Frames arriving
      less than 0.1 s after the previous inference are answered with an
      ``ack`` and skipped. Otherwise the picture is decoded, resized to
      640x480, run through the model (when one is loaded) and answered with a
      ``detection`` message. While a session is open the result is stored too.
    * ``start_session``: opens a session and replies ``session_started``.
    * ``end_session``: closes the open session and replies ``session_ended``.

    Messages that are not JSON objects, or carry another ``type``, are
    ignored. The sensitivity setting is read once when the connection opens.
    Any session still open when the handler exits (client disconnect or
    unexpected error) is ended.

    NOTE(review): no route registration for this handler is visible in this
    view — confirm where it is attached to ``app``.
    """
    from time import monotonic  # monotonic clock: immune to wall-clock jumps

    await websocket.accept()
    session_id = None
    frame_count = 0
    last_inference_time = None  # None until the first inference has run
    min_inference_interval = 0.1  # seconds

    try:
        async with aiosqlite.connect(db_path) as db:
            cursor = await db.execute("SELECT sensitivity FROM user_settings WHERE id = 1")
            row = await cursor.fetchone()
        sensitivity = row[0] if row else 6

        while True:
            data = await websocket.receive_json()
            # Tolerate malformed messages instead of dying on a KeyError.
            message_type = data.get('type') if isinstance(data, dict) else None

            if message_type == 'frame':
                now = monotonic()
                if (last_inference_time is not None
                        and now - last_inference_time < min_inference_interval):
                    await websocket.send_json({'type': 'ack', 'frame_count': frame_count})
                    continue
                last_inference_time = now

                try:
                    img_data = base64.b64decode(data['image'])
                    nparr = np.frombuffer(img_data, np.uint8)
                    frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                    if frame is None:
                        # Tell the client instead of dropping the frame silently.
                        await websocket.send_json({'type': 'error', 'message': 'Could not decode image'})
                        continue
                    frame = cv2.resize(frame, (640, 480))

                    if model is not None:
                        # NOTE(review): inference is a synchronous call made
                        # directly inside this coroutine; consider offloading
                        # it if it stalls other connections.
                        results = model(frame, imgsz=416, conf=0.4, iou=0.45, max_det=5, classes=[0], verbose=False)
                        detections = parse_yolo_results(results)
                    else:
                        detections = []

                    is_focused, confidence, metadata = is_user_focused(detections, frame.shape, sensitivity)
                    if session_id:
                        await store_focus_event(session_id, is_focused, confidence, metadata)

                    await websocket.send_json({
                        'type': 'detection',
                        'focused': is_focused,
                        'confidence': round(confidence, 3),
                        'detections': detections,
                        'frame_count': frame_count
                    })
                    frame_count += 1
                except Exception as e:
                    print(f"Error processing frame: {e}")
                    await websocket.send_json({'type': 'error', 'message': str(e)})

            elif message_type == 'start_session':
                session_id = await create_session()
                await websocket.send_json({'type': 'session_started', 'session_id': session_id})

            elif message_type == 'end_session':
                if session_id:
                    summary = await end_session(session_id)
                    await websocket.send_json({'type': 'session_ended', 'summary': summary})
                session_id = None
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in ``finally``.
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
        # NOTE(review): 1 is presumably the CONNECTED state value — confirm.
        if websocket.client_state.value == 1:
            await websocket.close()
    finally:
        # Never leave a session open once its connection is gone.
        if session_id:
            try:
                await end_session(session_id)
            except Exception as e:
                print(f"Failed to end session {session_id}: {e}")
| # ================ API ENDPOINTS ================ | |
async def api_start_session():
    """Open a new focus session and return its id as ``{"session_id": ...}``."""
    return {"session_id": await create_session()}
async def api_end_session(data: SessionEnd):
    """End the session named in ``data`` and return its summary.

    Raises:
        HTTPException: 404 when end_session() reports no such session.
    """
    summary = await end_session(data.session_id)
    if summary:
        return summary
    raise HTTPException(status_code=404, detail="Session not found")
async def get_sessions(filter: str = "all", limit: int = 50, offset: int = 0):
    """List sessions, newest first, optionally restricted by start time.

    Args:
        filter: ``"today"`` (since local midnight), ``"week"`` (last 7 days),
            ``"month"`` (last 30 days) or ``"all"``. ``"all"`` returns only
            sessions that have an end time; the dated filters do not add that
            condition. Any other value applies no condition at all.
        limit: page size; ``-1`` disables paging (used for exports).
        offset: rows to skip; ignored when ``limit`` is ``-1``.

    Returns:
        A list of dicts, one per focus_sessions row.
    """
    # Work out the earliest start time to include, if the filter is dated.
    cutoff = None
    if filter == "today":
        cutoff = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    elif filter in ("week", "month"):
        cutoff = datetime.now() - timedelta(days=7 if filter == "week" else 30)

    args = []
    if cutoff is not None:
        condition = " WHERE start_time >= ?"
        args.append(cutoff.isoformat())
    elif filter == "all":
        condition = " WHERE end_time IS NOT NULL"
    else:
        condition = ""

    sql = f"SELECT * FROM focus_sessions{condition} ORDER BY start_time DESC"
    if limit != -1:
        sql += " LIMIT ? OFFSET ?"
        args.extend([limit, offset])

    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        result = await conn.execute(sql, tuple(args))
        records = await result.fetchall()
    return [dict(record) for record in records]
| # --- NEW: Import Endpoint --- | |
async def import_sessions(sessions: List[dict]):
    """Insert previously exported sessions as new rows.

    Missing numeric fields default to zero and a missing ``created_at`` falls
    back to the session's ``start_time``. All rows are committed together
    after the last insert.

    Returns:
        ``{"status": "success", "count": n}`` on success, or
        ``{"status": "error", "message": ...}`` when anything raised.
    """
    insert_sql = """
        INSERT INTO focus_sessions (start_time, end_time, duration_seconds, focus_score, total_frames, focused_frames, created_at)
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    imported = 0
    try:
        async with aiosqlite.connect(db_path) as conn:
            for entry in sessions:
                # .get() tolerates fields missing from older or edited exports.
                started = entry.get('start_time')
                values = (
                    started,
                    entry.get('end_time'),
                    entry.get('duration_seconds', 0),
                    entry.get('focus_score', 0.0),
                    entry.get('total_frames', 0),
                    entry.get('focused_frames', 0),
                    entry.get('created_at', started),
                )
                await conn.execute(insert_sql, values)
                imported += 1
            await conn.commit()
        return {"status": "success", "count": imported}
    except Exception as e:
        print(f"Import Error: {e}")
        return {"status": "error", "message": str(e)}
| # --- NEW: Clear History Endpoint --- | |
async def clear_history():
    """Delete every focus event and every focus session.

    Returns:
        A status dict: ``success`` with a message, or ``error`` carrying the
        exception text when the deletion failed.
    """
    # Events go first because they reference sessions.
    delete_statements = (
        "DELETE FROM focus_events",
        "DELETE FROM focus_sessions",
    )
    try:
        async with aiosqlite.connect(db_path) as conn:
            for statement in delete_statements:
                await conn.execute(statement)
            await conn.commit()
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "History cleared"}
async def get_session(session_id: int):
    """Return one session together with its events, ordered by timestamp.

    Returns:
        The session row as a dict with an extra ``events`` key holding a list
        of event dicts.

    Raises:
        HTTPException: 404 when no session has this id.
    """
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row

        session_cursor = await conn.execute(
            "SELECT * FROM focus_sessions WHERE id = ?", (session_id,)
        )
        session_row = await session_cursor.fetchone()
        if session_row is None:
            raise HTTPException(status_code=404, detail="Session not found")

        events_cursor = await conn.execute(
            "SELECT * FROM focus_events WHERE session_id = ? ORDER BY timestamp", (session_id,)
        )
        event_rows = await events_cursor.fetchall()

    return {**dict(session_row), 'events': [dict(event) for event in event_rows]}
async def get_settings():
    """Return the stored settings row, or built-in defaults when it is absent."""
    defaults = {
        'sensitivity': 6,
        'notification_enabled': True,
        'notification_threshold': 30,
        'frame_rate': 30,
        'model_name': 'yolov8n.pt',
    }
    async with aiosqlite.connect(db_path) as conn:
        conn.row_factory = aiosqlite.Row
        result = await conn.execute("SELECT * FROM user_settings WHERE id = 1")
        stored = await result.fetchone()
    return dict(stored) if stored else defaults
async def update_settings(settings: SettingsUpdate):
    """Apply the provided settings fields to the single settings row.

    Fields left as None are skipped. Integer fields are clamped before being
    written: sensitivity to 1..10, notification_threshold to 5..300 and
    frame_rate to 5..60. The settings row is created first if it is missing.

    Returns:
        ``{"status": "success", "updated": bool}`` where ``updated`` tells
        whether at least one field was written.
    """
    # Column -> (low, high) clamp, or None for columns written unchanged.
    # The order here is the order of the SET clause.
    columns = {
        "sensitivity": (1, 10),
        "notification_enabled": None,
        "notification_threshold": (5, 300),
        "frame_rate": (5, 60),
    }

    async with aiosqlite.connect(db_path) as conn:
        probe = await conn.execute("SELECT id FROM user_settings WHERE id = 1")
        if not await probe.fetchone():
            await conn.execute("INSERT INTO user_settings (id, sensitivity) VALUES (1, 6)")
            await conn.commit()

        assignments = []
        values = []
        for column, bounds in columns.items():
            value = getattr(settings, column)
            if value is None:
                continue
            if bounds is not None:
                low, high = bounds
                value = max(low, min(high, value))
            assignments.append(f"{column} = ?")
            values.append(value)

        if assignments:
            sql = f"UPDATE user_settings SET {', '.join(assignments)} WHERE id = 1"
            await conn.execute(sql, values)
            await conn.commit()

    return {"status": "success", "updated": bool(assignments)}
async def get_stats_summary():
    """Aggregate statistics over all sessions that have an end time.

    Returns:
        A dict with ``total_sessions``, ``total_focus_time`` (sum of
        duration_seconds), ``avg_focus_score`` (rounded to 3 places) and
        ``streak_days``: the number of consecutive calendar days, counting
        back from today, that each have at least one ended session. The
        streak is 0 when there is no such session today.
    """
    async with aiosqlite.connect(db_path) as conn:
        totals_cursor = await conn.execute(
            "SELECT COUNT(*), SUM(duration_seconds), AVG(focus_score) "
            "FROM focus_sessions WHERE end_time IS NOT NULL"
        )
        session_count, duration_sum, score_avg = await totals_cursor.fetchone()

        days_cursor = await conn.execute("SELECT DISTINCT DATE(start_time) as session_date FROM focus_sessions WHERE end_time IS NOT NULL ORDER BY session_date DESC")
        day_texts = [record[0] for record in await days_cursor.fetchall()]

    # SUM/AVG yield NULL when no row matches.
    duration_sum = duration_sum or 0
    score_avg = score_avg or 0.0

    # Walk the distinct days newest-first; the i-th one must be i days ago.
    today = datetime.now().date()
    streak = 0
    for days_back, day_text in enumerate(day_texts):
        if datetime.fromisoformat(day_text).date() != today - timedelta(days=days_back):
            break
        streak += 1

    return {
        'total_sessions': session_count,
        'total_focus_time': int(duration_sum),
        'avg_focus_score': round(score_avg, 3),
        'streak_days': streak,
    }
async def health_check():
    """Report liveness, whether a model is loaded and whether the DB file exists."""
    report = {"status": "healthy"}
    report["model_loaded"] = model is not None
    report["database"] = os.path.exists(db_path)
    return report
| # ================ STATIC FILES (SPA SUPPORT) ================ | |
# 1. Serve the built front-end assets (JS/CSS built by Vite/React) under
#    /assets. The mount is skipped when the directory is absent, so the API
#    still starts without a front-end build.
if os.path.exists("static/assets"):
    app.mount(
        "/assets",
        StaticFiles(directory="static/assets"),
        name="assets",
    )
| # 2. Catch-all route for SPA (React Router) | |
| # This ensures that if you refresh /customise, it serves index.html instead of 404 | |
async def serve_react_app(full_path: str, request: Request):
    """Serve the single-page app's index.html for any non-API path.

    Args:
        full_path: the requested path; anything starting with ``api`` or
            ``ws`` is rejected so that it is not answered with the SPA shell.
        request: the incoming request (not used by this handler).

    Returns:
        ``static/index.html`` as a file response, or a JSON hint when the
        front-end build is missing.

    Raises:
        HTTPException: 404 for paths starting with ``api`` or ``ws``.
    """
    if full_path.startswith(("api", "ws")):
        raise HTTPException(status_code=404, detail="Not Found")

    index_file = "static/index.html"
    if not os.path.exists(index_file):
        return {"message": "React app not found. Please run 'npm run build' and copy dist to static."}
    return FileResponse(index_file)