"""Webcam data-collection tool.

Captures camera frames, extracts a fixed-length feature vector per frame
(see ``FEATURE_NAMES``), labels each recorded frame as focused (1) or not
focused (0) according to the key the operator last pressed, and saves the
session to an ``.npz`` file.
"""

import argparse
import collections
import math
import os
import sys
import time

import cv2
import numpy as np

_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from models.face_mesh import FaceMeshDetector
from models.head_pose import HeadPoseEstimator
from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar

FONT = cv2.FONT_HERSHEY_SIMPLEX
GREEN = (0, 255, 0)
RED = (0, 0, 255)
WHITE = (255, 255, 255)
YELLOW = (0, 255, 255)
ORANGE = (0, 165, 255)
GRAY = (120, 120, 120)

# Order of the entries in every feature vector produced by extract_features().
FEATURE_NAMES = [
    "ear_left", "ear_right", "ear_avg",
    "h_gaze", "v_gaze",
    "mar",
    "yaw", "pitch", "roll",
    "s_face", "s_eye",
    "gaze_offset", "head_deviation",
    "perclos", "blink_rate", "closure_duration", "yawn_duration",
]
NUM_FEATURES = len(FEATURE_NAMES)
assert NUM_FEATURES == 17


class TemporalTracker:
    """Accumulate per-frame EAR/MAR values into time-based features.

    Eyes count as closed while ``ear_avg`` is below ``EAR_BLINK_THRESH``; a
    yawn is in progress while ``mar`` is above ``MAR_YAWN_THRESH``.
    """

    EAR_BLINK_THRESH = 0.21
    MAR_YAWN_THRESH = 0.55
    PERCLOS_WINDOW = 60        # number of most recent updates used for PERCLOS
    BLINK_WINDOW_SEC = 30.0    # time window used for the blink rate

    def __init__(self):
        self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
        self.blink_timestamps = collections.deque()
        self._eyes_closed = False
        self._closure_start = None
        self._yawn_start = None

    def reset(self):
        """Forget all accumulated state.

        Call this when the stream of ``update()`` calls is interrupted (e.g.
        recording is paused); otherwise an eye closure or yawn that was in
        progress before the gap would be measured across the whole gap.
        """
        self.ear_history.clear()
        self.blink_timestamps.clear()
        self._eyes_closed = False
        self._closure_start = None
        self._yawn_start = None

    def update(self, ear_avg, mar, now=None):
        """Feed one frame and return the current temporal features.

        Args:
            ear_avg: Average eye-aspect-ratio for the frame.
            mar: Mouth-aspect-ratio for the frame.
            now: Timestamp in seconds; defaults to ``time.time()``.

        Returns:
            Tuple ``(perclos, blink_rate, closure_dur, yawn_dur)``:
            fraction of the last ``PERCLOS_WINDOW`` updates with eyes closed,
            blinks per minute over the last ``BLINK_WINDOW_SEC`` seconds,
            seconds the eyes have been continuously closed, and seconds the
            current yawn has lasted.
        """
        if now is None:
            now = time.time()

        closed = ear_avg < self.EAR_BLINK_THRESH
        self.ear_history.append(1.0 if closed else 0.0)
        # The history is never empty here because of the append above.
        perclos = sum(self.ear_history) / len(self.ear_history)

        # A blink is registered on the closed -> open transition.
        if self._eyes_closed and not closed:
            self.blink_timestamps.append(now)
        self._eyes_closed = closed

        cutoff = now - self.BLINK_WINDOW_SEC
        while self.blink_timestamps and self.blink_timestamps[0] < cutoff:
            self.blink_timestamps.popleft()
        blink_rate = len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)

        if closed:
            if self._closure_start is None:
                self._closure_start = now
            closure_dur = now - self._closure_start
        else:
            self._closure_start = None
            closure_dur = 0.0

        yawning = mar > self.MAR_YAWN_THRESH
        if yawning:
            if self._yawn_start is None:
                self._yawn_start = now
            yawn_dur = now - self._yawn_start
        else:
            self._yawn_start = None
            yawn_dur = 0.0

        return perclos, blink_rate, closure_dur, yawn_dur


def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal, *, _pre=None):
    """Build the feature vector for one frame.

    Args:
        landmarks: Face landmarks as returned by the face-mesh detector.
        w: Frame width, passed through to ``head_pose``.
        h: Frame height, passed through to ``head_pose``.
        head_pose: Object providing ``estimate(landmarks, w, h)`` and
            ``score(landmarks, w, h)``.
        eye_scorer: Object providing ``score(landmarks)``.
        temporal: ``TemporalTracker`` (or compatible) whose ``update`` is
            called exactly once per invocation.
        _pre: Optional dict of already-computed values. Recognised keys:
            ``ear_left``, ``ear_right``, ``h_gaze`` + ``v_gaze`` (both
            required to be used), ``mar``, ``angles``, ``s_face``, ``s_eye``.
            A value present here is used as-is and the corresponding
            computation is skipped.

    Returns:
        ``np.float32`` array of length ``NUM_FEATURES`` ordered as
        ``FEATURE_NAMES``.
    """
    from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear

    pre = _pre or {}

    # NOTE: dict.get(key, default) evaluates `default` eagerly, which would
    # recompute every value even when it was supplied; hence the explicit
    # membership tests.
    if "ear_left" in pre:
        ear_left = pre["ear_left"]
    else:
        ear_left = compute_ear(landmarks, _LEFT_EYE_EAR)
    if "ear_right" in pre:
        ear_right = pre["ear_right"]
    else:
        ear_right = compute_ear(landmarks, _RIGHT_EYE_EAR)
    ear_avg = (ear_left + ear_right) / 2.0

    if "h_gaze" in pre and "v_gaze" in pre:
        h_gaze, v_gaze = pre["h_gaze"], pre["v_gaze"]
    else:
        h_gaze, v_gaze = compute_gaze_ratio(landmarks)

    mar = pre["mar"] if "mar" in pre else compute_mar(landmarks)

    angles = pre.get("angles")
    if angles is None:
        angles = head_pose.estimate(landmarks, w, h)
    # Explicit None/length check: a bare truthiness test raises on a NumPy
    # array and an IndexError would follow from a too-short sequence.
    if angles is not None and len(angles) >= 3:
        yaw, pitch, roll = angles[0], angles[1], angles[2]
    else:
        yaw = pitch = roll = 0.0

    if "s_face" in pre:
        s_face = pre["s_face"]
    else:
        s_face = head_pose.score(landmarks, w, h)
    s_eye = pre["s_eye"] if "s_eye" in pre else eye_scorer.score(landmarks)

    # Distance of the gaze ratios from the (0.5, 0.5) centre.
    gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
    head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)  # cleaned downstream

    perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)

    return np.array([
        ear_left, ear_right, ear_avg,
        h_gaze, v_gaze,
        mar,
        yaw, pitch, roll,
        s_face, s_eye,
        gaze_offset, head_deviation,
        perclos, blink_rate, closure_dur, yawn_dur,
    ], dtype=np.float32)


def quality_report(labels, fps=30.0):
    """Print a quality summary for a recorded label sequence.

    Args:
        labels: Sequence (list or array) of per-frame labels, 1 = focused,
            anything else = unfocused.
        fps: Frame rate used to convert frame counts to seconds. The default
            of 30 is an approximation of the capture rate.

    Returns:
        ``True`` when no warning was raised, ``False`` otherwise.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError("fps must be positive")

    labels = np.asarray(labels)  # accept plain lists as well as arrays
    n = len(labels)
    n1 = int((labels == 1).sum())
    n0 = n - n1
    transitions = int(np.sum(np.diff(labels) != 0))
    duration_sec = n / fps  # approximate

    warnings = []

    print(f"\n{'='*50}")
    print(" DATA QUALITY REPORT")
    print(f"{'='*50}")
    print(f" Total samples : {n}")
    print(f" Focused : {n1} ({n1/max(n,1)*100:.1f}%)")
    print(f" Unfocused : {n0} ({n0/max(n,1)*100:.1f}%)")
    print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
    print(f" Transitions : {transitions}")
    if transitions > 0:
        # t transitions split the recording into t + 1 segments.
        avg_segment = n / (transitions + 1)
        print(f" Avg segment : {avg_segment:.0f} frames ({avg_segment/fps:.1f}s)")

    # checks
    if duration_sec < 120:
        warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")

    if n < 3000:
        warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")

    balance = n1 / max(n, 1)
    if balance < 0.3 or balance > 0.7:
        warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")

    if transitions < 10:
        warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")

    if transitions == 1:
        warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
                        "model will learn temporal position, not focus patterns")

    if warnings:
        print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
        for message in warnings:
            print(f" • {message}")
        print(f"\n Consider re-recording this session.")
    else:
        print(f"\n ✅ All checks passed!")

    print(f"{'='*50}\n")
    return len(warnings) == 0


# ---------------------------------------------------------------------------
# Main
def main():
    """Run the interactive recording loop and save the session to disk.

    Keys: ``1`` = record as focused, ``0`` = record as not focused,
    ``p`` = pause, ``q`` = save and quit. Frames are only recorded while a
    label is active and a face is detected. Recording also stops when the
    ``--duration`` limit is reached, the camera stops delivering frames, or
    the process receives Ctrl-C; in all of these cases the data gathered so
    far is saved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="session",
                        help="Your name or session ID")
    parser.add_argument("--camera", type=int, default=0,
                        help="Camera index")
    parser.add_argument("--duration", type=int, default=600,
                        help="Max recording time (seconds, default 10 min)")
    parser.add_argument("--output-dir", type=str,
                        default=os.path.join(_PROJECT_ROOT, "collected_data"),
                        help="Where to save .npz files")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    detector = FaceMeshDetector()
    head_pose = HeadPoseEstimator()
    eye_scorer = EyeBehaviourScorer()
    temporal = TemporalTracker()

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print("[COLLECT] ERROR: can't open camera")
        # Release what was already acquired before bailing out.
        cap.release()
        detector.close()
        return

    print("[COLLECT] Data Collection Tool")
    print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
    print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
    print("[COLLECT] Controls:")
    print(" 1 = FOCUSED (looking at screen normally)")
    print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
    print(" p = pause")
    print(" q = save & quit")
    print()
    print("[COLLECT] TIPS for good data:")
    print(" • Switch between 1 and 0 every 10-30 seconds")
    print(" • Aim for 20+ transitions total")
    print(" • Act out varied scenarios: reading, phone, talking, drowsy")
    print(" • Record at least 5 minutes")
    print()

    features_list = []
    labels_list = []
    n_focused = 0        # running count of label == 1 samples
    label = None         # None = paused
    transitions = 0      # count label switches
    prev_label = None
    status = "PAUSED -- press 1 (focused) or 0 (not focused)"
    t_start = time.time()
    prev_time = time.time()
    fps = 0.0

    try:
        while True:
            elapsed = time.time() - t_start
            if elapsed > args.duration:
                print(f"[COLLECT] Time limit ({args.duration}s)")
                break

            ret, frame = cap.read()
            if not ret:
                break

            h, w = frame.shape[:2]
            landmarks = detector.process(frame)
            face_ok = landmarks is not None

            if face_ok and label is not None:
                vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
                features_list.append(vec)
                labels_list.append(label)
                if label == 1:
                    n_focused += 1

                if prev_label is not None and label != prev_label:
                    transitions += 1
                prev_label = label

            # Exponentially smoothed frame rate for the on-screen readout.
            now = time.time()
            fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
            prev_time = now

            # --- draw UI ---
            n = len(labels_list)
            n1 = n_focused
            n0 = n - n1
            remaining = max(0, args.duration - elapsed)

            bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
            cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
            cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
            cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
                        (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)

            if n > 0:
                # Balance bar: green portion = share of focused samples.
                bar_w = min(w - 20, 300)
                bar_x = w - bar_w - 10
                bar_y = 58
                frac = n1 / n
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
                cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
                            FONT, 0.3, GRAY, 1, cv2.LINE_AA)

            if not face_ok:
                cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)

            # red dot = recording
            if label is not None and face_ok:
                cv2.circle(frame, (w - 20, 80), 8, RED, -1)

            # live warnings
            warn_y = h - 35
            if n > 100 and transitions < 3:
                cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
                            (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                warn_y -= 18
            if elapsed > 30 and n > 0:
                bal = n1 / n
                if bal < 0.25 or bal > 0.75:
                    cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
                                (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                    warn_y -= 18

            cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
                        (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)

            cv2.imshow("FocusGuard -- Data Collection", frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord("1"):
                label = 1
                status = "Recording: FOCUSED"
                print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("0"):
                label = 0
                status = "Recording: NOT FOCUSED"
                print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("p"):
                label = None
                status = "PAUSED"
                # The tracker receives no frames while paused; clear it so a
                # closure/yawn in progress is not measured across the pause.
                temporal.reset()
                print(f"[COLLECT] paused (n={n})")
            elif key == ord("q"):
                break
    except KeyboardInterrupt:
        # Fall through to the save step instead of discarding the session.
        print("\n[COLLECT] Interrupted -- saving what was recorded")
    finally:
        cap.release()
        cv2.destroyAllWindows()
        detector.close()

    if features_list:
        feats = np.stack(features_list)
        labs = np.array(labels_list, dtype=np.int64)

        ts = time.strftime("%Y%m%d_%H%M%S")
        fname = f"{args.name}_{ts}.npz"
        fpath = os.path.join(args.output_dir, fname)
        np.savez(fpath,
                 features=feats,
                 labels=labs,
                 feature_names=np.array(FEATURE_NAMES))

        print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
        print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")

        quality_report(labs)
    else:
        print("\n[COLLECT] No data collected")

    print("[COLLECT] Done")


if __name__ == "__main__":
    main()