Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| from ultralytics import YOLO | |
| try: | |
| import mediapipe as mp | |
| except Exception: # pragma: no cover | |
| mp = None | |
def find_weights(project_root: Path) -> Path | None:
    """Return the first existing YOLO checkpoint under *project_root*, else None.

    Search order: ``weights/best.pt``, then ``best.pt`` and ``last.pt`` of the
    ``eye_open_closed_cpu`` run under ``runs/classify/runs_cls``, then the same
    two files under a top-level ``runs_cls``. Only regular files count.
    """
    run_weight_dirs = (
        project_root / "runs" / "classify" / "runs_cls" / "eye_open_closed_cpu" / "weights",
        project_root / "runs_cls" / "eye_open_closed_cpu" / "weights",
    )
    search_order = [project_root / "weights" / "best.pt"]
    for run_dir in run_weight_dirs:
        search_order.extend(run_dir / ckpt for ckpt in ("best.pt", "last.pt"))
    for candidate in search_order:
        if candidate.is_file():
            return candidate
    return None
def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """Estimate the pupil centre in a single-channel eye image.

    Pipeline: CLAHE contrast equalisation, 7x7 Gaussian blur, then an
    inverted Otsu threshold (dark regions become foreground) on the central
    window spanning +/-30% of the width/height around the image centre,
    followed by morphological open/close clean-up. Among the external
    contours with area >= 15 px and circularity >= 0.3, the one maximising
    ``circularity - distance_from_centre / max(w, h)`` is chosen.

    Args:
        gray: 2-D image (presumably uint8, as CLAHE/Otsu expect), assumed
            to be an eye crop.

    Returns:
        ``(x, y)`` pixel coordinates in *gray*, or ``None`` when no blob
        qualifies or the image is too small to contain a search window.
    """
    h, w = gray.shape
    cx, cy = w // 2, h // 2
    rx, ry = int(w * 0.3), int(h * 0.3)
    x0, x1 = max(cx - rx, 0), min(cx + rx, w)
    y0, y1 = max(cy - ry, 0), min(cy + ry, h)
    if x1 <= x0 or y1 <= y0:
        # Images under 4 px in either dimension give an empty window, and
        # cv2.threshold rejects empty input; report "no pupil" instead.
        return None

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    blur = cv2.GaussianBlur(clahe.apply(gray), (7, 7), 0)
    roi = blur[y0:y1, x0:x1]

    _, thresh = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Opening drops small speckles; closing fills small holes in the blobs.
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    best = None
    # circularity >= 0.3 and dist < 1 keep every accepted score above -1.0.
    best_score = -1.0
    for c in contours:
        area = cv2.contourArea(c)
        if area < 15:
            continue
        perimeter = cv2.arcLength(c, True)
        if perimeter <= 0:
            continue
        # 4*pi*A/P^2: 1.0 for a perfect circle, smaller for irregular shapes.
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:
            continue
        m = cv2.moments(c)
        if m["m00"] == 0:
            continue
        # Centroid, shifted from ROI coordinates back to full-image coordinates.
        px = int(m["m10"] / m["m00"]) + x0
        py = int(m["m01"] / m["m00"]) + y0
        # Favour round blobs that sit close to the image centre.
        dist = np.hypot(px - cx, py - cy) / max(w, h)
        score = circularity - dist
        if score > best_score:
            best_score = score
            best = (px, py)
    return best
def is_focused(
    pupil_center: tuple[int, int],
    img_shape: tuple[int, ...],
    max_offset: float = 0.12,
) -> bool:
    """Return True when the pupil is horizontally close to the image centre.

    Args:
        pupil_center: ``(x, y)`` pixel position; only ``x`` is used.
        img_shape: array shape ``(h, w)`` or ``(h, w, channels)``; only the
            width (index 1) is used, so colour shapes are accepted too.
        max_offset: maximum allowed ``|x - w//2| / w`` (default 0.12).

    Returns:
        True if the normalised horizontal offset is strictly below
        *max_offset*.
    """
    w = img_shape[1]
    px = pupil_center[0]
    # max(w, 1) avoids division by zero for degenerate zero-width shapes.
    dx = abs(px - w // 2) / max(w, 1)
    return dx < max_offset
def classify_frame(
    model: YOLO,
    frame: np.ndarray,
    imgsz: int = 224,
    device: str = "cpu",
) -> tuple[str, float]:
    """Run the YOLO classifier on one frame and return its top-1 prediction.

    The frame is passed to the model unchanged, so it is assumed to already
    be an eye crop.

    Args:
        model: a YOLO classification model.
        frame: image to classify.
        imgsz: inference size forwarded to ``model.predict`` (default 224).
        device: inference device forwarded to ``model.predict`` (default "cpu").

    Returns:
        ``(label, confidence)``, where *label* is looked up in ``model.names``.

    Raises:
        ValueError: if the result carries no class probabilities (``probs``
            is None, as for non-classification models).
    """
    result = model.predict(frame, imgsz=imgsz, device=device, verbose=False)[0]
    probs = result.probs
    if probs is None:
        raise ValueError(
            "Model returned no class probabilities; expected a YOLO classification model."
        )
    top_idx = int(probs.top1)
    return model.names[top_idx], float(probs.top1conf)
def annotate_frame(frame: np.ndarray, label: str, focused: bool, conf: float, time_sec: float):
    """Return a copy of *frame* with a one-line status overlay.

    The overlay (green, top-left) shows the label, the focus flag as 0/1,
    the confidence and the timestamp. The input frame is left untouched.
    """
    status = " | ".join(
        (label, f"focused={int(focused)}", f"conf={conf:.2f}", f"t={time_sec:.2f}s")
    )
    origin = (10, 30)
    green = (0, 255, 0)
    canvas = frame.copy()
    cv2.putText(canvas, status, origin, cv2.FONT_HERSHEY_SIMPLEX, 0.8, green, 2)
    return canvas
def write_segments(path: Path, segments: list[tuple[float, float, str]]):
    """Overwrite *path* with one ``start,end,label`` line per segment.

    Start/end times are written with two decimals; an empty *segments*
    list leaves an empty file.
    """
    with path.open("w") as handle:
        handle.writelines(
            f"{begin:.2f},{finish:.2f},{name}\n" for begin, finish, name in segments
        )
# MediaPipe FaceMesh landmark indices used by the eye-state heuristic.
# Each eye tuple is ordered p1..p6 as consumed by `_eye_aspect_ratio`
# (p1/p4 are the horizontal eye corners).
_LEFT_EYE = (33, 160, 158, 133, 153, 144)
_RIGHT_EYE = (362, 385, 387, 263, 373, 380)
# Iris landmarks; FaceMesh below is created with refine_landmarks=True to get them.
_LEFT_IRIS = (468, 469, 470, 471)
_RIGHT_IRIS = (473, 474, 475, 476)
# Average eye-aspect-ratio above which both eyes count as open.
_EAR_OPEN_THRESHOLD = 0.22
# Max horizontal iris offset from the eye-corner midpoint, as a fraction of
# eye width, for an eye to count as focused.
_IRIS_OFFSET_LIMIT = 0.18
# Frame rate assumed when the capture reports an unusable value.
_DEFAULT_FPS = 30.0


def _landmark_points(landmarks, idxs, w: int, h: int) -> np.ndarray:
    """Convert normalised FaceMesh landmarks to an (N, 2) int32 pixel array.

    int32 is deliberate: cv2.polylines only accepts 32-bit integer point
    arrays, whereas np.array of Python ints defaults to int64 on most
    64-bit platforms.
    """
    return np.array(
        [(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in idxs],
        dtype=np.int32,
    )


def _eye_aspect_ratio(eye_pts: np.ndarray) -> float:
    """Return the 6-point EAR: sum of two vertical spans over twice the width."""
    p1, p2, p3, p4, p5, p6 = eye_pts
    vertical = np.linalg.norm(p2 - p6) + np.linalg.norm(p3 - p5)
    horizontal = np.linalg.norm(p1 - p4)
    return vertical / (2.0 * horizontal + 1e-6)


def _mediapipe_eye_state(face_mesh, frame: np.ndarray) -> tuple[str, bool]:
    """Return ``(pred_label, focused)`` from FaceMesh landmarks.

    Draws the eye outlines and iris centres onto *frame* in place. When no
    face is found the frame is reported as ``("closed", False)``.
    """
    res = face_mesh.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if not res.multi_face_landmarks:
        return "closed", False
    lm = res.multi_face_landmarks[0].landmark
    h, w = frame.shape[:2]

    left = _landmark_points(lm, _LEFT_EYE, w, h)
    right = _landmark_points(lm, _RIGHT_EYE, w, h)
    ear_avg = (_eye_aspect_ratio(left) + _eye_aspect_ratio(right)) / 2.0
    pred_label = "open" if ear_avg > _EAR_OPEN_THRESHOLD else "closed"

    # Iris centre = mean of the iris landmarks; eye centre = corner midpoint.
    left_iris = _landmark_points(lm, _LEFT_IRIS, w, h).mean(axis=0).astype(int)
    right_iris = _landmark_points(lm, _RIGHT_IRIS, w, h).mean(axis=0).astype(int)
    left_mid = ((left[0] + left[3]) / 2).astype(int)
    right_mid = ((right[0] + right[3]) / 2).astype(int)

    # Horizontal iris offset normalised by eye width (width floored at 1 px).
    left_dx = abs(left_iris[0] - left_mid[0]) / max(np.linalg.norm(left[0] - left[3]), 1)
    right_dx = abs(right_iris[0] - right_mid[0]) / max(np.linalg.norm(right[0] - right[3]), 1)
    focused = bool(
        pred_label == "open"
        and left_dx < _IRIS_OFFSET_LIMIT
        and right_dx < _IRIS_OFFSET_LIMIT
    )

    cv2.polylines(frame, [left], True, (0, 255, 255), 1)
    cv2.polylines(frame, [right], True, (0, 255, 255), 1)
    # Pass plain Python ints: some OpenCV builds reject NumPy scalars as points.
    cv2.circle(frame, (int(left_iris[0]), int(left_iris[1])), 3, (0, 0, 255), -1)
    cv2.circle(frame, (int(right_iris[0]), int(right_iris[1])), 3, (0, 0, 255), -1)
    return pred_label, focused


def _classifier_eye_state(model: YOLO | None, frame: np.ndarray) -> tuple[str, bool, float]:
    """Return ``(pred_label, focused, conf)`` from the classifier + pupil heuristic.

    Without a model every frame is treated as "open" with confidence 0.0.
    The pupil search only runs when the prediction is "open".
    """
    pred_label, conf = "open", 0.0
    if model is not None:
        pred_label, conf = classify_frame(model, frame)
    if pred_label.lower() != "open":
        return pred_label, False, conf
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    pupil_center = detect_pupil_center(gray)
    focused = pupil_center is not None and is_focused(pupil_center, gray.shape)
    return pred_label, focused, conf


def _compose_label(pred_label: str, focused: bool) -> str:
    """Merge the open/closed prediction and focus flag into one output label."""
    if pred_label.lower() != "open":
        return "closed_not_focused"
    return "open_focused" if focused else "open_not_focused"


def process_video(video_path: Path, model: YOLO | None):
    """Label every frame of *video_path* and write annotated outputs beside it.

    Files written next to the input:
      * ``<stem>_pred.mp4``        - frames with a status overlay (mp4v).
      * ``<stem>_predictions.csv`` - ``time_sec,label,focused,conf`` per frame.
      * ``<stem>_segments.txt``    - runs of identical labels as ``start,end,label``.

    When mediapipe is importable, eye state comes from FaceMesh landmarks and
    ``conf`` stays 0.0. Otherwise the optional YOLO *model* plus the pupil
    heuristic are used. Labels are ``open_focused``, ``open_not_focused`` or
    ``closed_not_focused``. If the video cannot be opened, a message is
    printed and nothing is written.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Failed to open {video_path}")
        return

    fps = cap.get(cv2.CAP_PROP_FPS)
    # Guard against an unusable reported rate (0, negative or NaN).
    if not (fps > 0):
        fps = _DEFAULT_FPS
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = video_path.with_name(video_path.stem + "_pred.mp4")
    csv_path = video_path.with_name(video_path.stem + "_predictions.csv")
    seg_path = video_path.with_name(video_path.stem + "_segments.txt")
    writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    if not writer.isOpened():
        print(f"Warning: could not open video writer for {out_path}")

    frame_idx = 0
    last_label: str | None = None
    seg_start = 0.0
    segments: list[tuple[float, float, str]] = []
    face_mesh = None
    try:
        with csv_path.open("w") as fcsv:
            fcsv.write("time_sec,label,focused,conf\n")
            if mp is None:
                print("mediapipe is not installed. Falling back to classifier-only mode.")
            else:
                face_mesh = mp.solutions.face_mesh.FaceMesh(
                    static_image_mode=False,
                    max_num_faces=1,
                    refine_landmarks=True,  # needed for the iris landmarks
                    min_detection_confidence=0.5,
                    min_tracking_confidence=0.5,
                )
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                time_sec = frame_idx / fps
                if face_mesh is not None:
                    pred_label, focused = _mediapipe_eye_state(face_mesh, frame)
                    conf = 0.0
                else:
                    pred_label, focused, conf = _classifier_eye_state(model, frame)
                label = _compose_label(pred_label, focused)
                fcsv.write(f"{time_sec:.2f},{label},{int(focused)},{conf:.4f}\n")

                # Close the running segment whenever the label changes.
                if last_label is None:
                    last_label, seg_start = label, time_sec
                elif label != last_label:
                    segments.append((seg_start, time_sec, last_label))
                    last_label, seg_start = label, time_sec

                writer.write(annotate_frame(frame, label, focused, conf, time_sec))
                frame_idx += 1

        if last_label is not None:
            # The final segment ends at the timestamp just past the last frame.
            segments.append((seg_start, frame_idx / fps, last_label))
        write_segments(seg_path, segments)
    finally:
        # Release native resources even if processing fails part-way.
        if face_mesh is not None:
            face_mesh.close()
        cap.release()
        writer.release()

    print(f"Saved: {out_path}")
    print(f"CSV: {csv_path}")
    print(f"Segments: {seg_path}")
def main():
    """Process the default videos plus any listed in the ``VIDEOS`` env var.

    The project root is the parent of this file's directory. ``1.mp4`` and
    ``2.mp4`` in the project root are picked up when present. ``VIDEOS`` may
    add a comma-separated list of paths; relative paths are resolved against
    the project root. Each video is processed at most once, and ``VIDEOS``
    entries that do not exist are reported rather than silently dropped.
    """
    project_root = Path(__file__).resolve().parent.parent

    videos: list[Path] = []
    seen: set[Path] = set()

    def add_video(p: Path) -> None:
        # Deduplicate on the resolved path so "1.mp4" given twice runs once.
        key = p.resolve()
        if key not in seen:
            seen.add(key)
            videos.append(p)

    # Default to 1.mp4 and 2.mp4 in the project root.
    for name in ("1.mp4", "2.mp4"):
        p = project_root / name
        if p.exists():
            add_video(p)

    # Also allow passing paths via env var.
    for raw in os.getenv("VIDEOS", "").split(","):
        raw = raw.strip()
        if not raw:
            continue
        vp = Path(raw)
        if not vp.is_absolute():
            vp = project_root / vp
        if vp.exists():
            add_video(vp)
        else:
            print(f"Skipping missing video from VIDEOS: {vp}")

    if not videos:
        print("No videos found. Expected 1.mp4 / 2.mp4 in project root.")
        return

    # Load the optional classifier only once there is something to process.
    weights = find_weights(project_root)
    model = YOLO(str(weights)) if weights is not None else None
    for v in videos:
        process_video(v, model)
if __name__ == "__main__":
    # Run the batch only when executed as a script, not when imported.
    main()