Spaces:
Sleeping
Sleeping
| import argparse | |
| import collections | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import cv2 | |
| import numpy as np | |
| _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| if _PROJECT_ROOT not in sys.path: | |
| sys.path.insert(0, _PROJECT_ROOT) | |
| from models.face_mesh import FaceMeshDetector | |
| from models.head_pose import HeadPoseEstimator | |
| from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar | |
# Font used for every piece of overlay text drawn on the preview frame.
FONT = cv2.FONT_HERSHEY_SIMPLEX

# Overlay colour palette. Tuples are in B, G, R channel order
# (note RED is (0, 0, 255)), which is what the cv2 drawing calls below expect.
WHITE = (255, 255, 255)
GRAY = (120, 120, 120)
GREEN = (0, 255, 0)
RED = (0, 0, 255)
YELLOW = (0, 255, 255)
ORANGE = (0, 165, 255)
# Column names of the per-frame feature vector, in the exact order in which
# extract_features() packs the values into its output array.
FEATURE_NAMES = [
    # eye aspect ratios
    "ear_left", "ear_right", "ear_avg",
    # gaze ratios
    "h_gaze", "v_gaze",
    # mouth aspect ratio
    "mar",
    # head pose angles
    "yaw", "pitch", "roll",
    # scorer outputs
    "s_face", "s_eye",
    # derived distances
    "gaze_offset", "head_deviation",
    # temporal statistics
    "perclos", "blink_rate", "closure_duration", "yawn_duration",
]
NUM_FEATURES = len(FEATURE_NAMES)
# Sanity check: fails at import time if a name is added or removed by accident.
assert NUM_FEATURES == 17
class TemporalTracker:
    """Keep running eye-closure, blink and yawn statistics across frames.

    Call :meth:`update` once per frame with the current average eye aspect
    ratio (EAR) and mouth aspect ratio (MAR).
    """

    # EAR strictly below this value counts as "eyes closed".
    EAR_BLINK_THRESH = 0.21
    # MAR strictly above this value counts as "yawning".
    MAR_YAWN_THRESH = 0.55
    # Number of most recent updates used for the PERCLOS fraction.
    PERCLOS_WINDOW = 60
    # Length, in seconds, of the window over which blinks are counted.
    BLINK_WINDOW_SEC = 30.0

    def __init__(self):
        # 1.0 / 0.0 flags (closed / open) for the last PERCLOS_WINDOW updates.
        self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
        # Timestamps at which the eyes re-opened (one entry per blink).
        self.blink_timestamps = collections.deque()
        self._eyes_closed = False
        self._closure_start = None
        self._yawn_start = None

    @staticmethod
    def _run_length(active, started_at, now):
        """Return ``(start, duration)`` for a condition that may be ongoing.

        While ``active`` is false the start is cleared and the duration is
        0.0; otherwise the start is set on first sight and the duration is
        the time elapsed since then.
        """
        if not active:
            return None, 0.0
        if started_at is None:
            started_at = now
        return started_at, now - started_at

    def update(self, ear_avg, mar, now=None):
        """Feed one frame and return the current temporal statistics.

        Args:
            ear_avg: Average eye aspect ratio for this frame.
            mar: Mouth aspect ratio for this frame.
            now: Timestamp in seconds; defaults to ``time.time()``.

        Returns:
            ``(perclos, blink_rate, closure_dur, yawn_dur)`` where perclos is
            the fraction of windowed frames with closed eyes, blink_rate is
            blinks per minute extrapolated from the blink window, and the two
            durations are seconds since the current closure / yawn began
            (0.0 when not closed / not yawning).
        """
        timestamp = time.time() if now is None else now

        eyes_shut = ear_avg < self.EAR_BLINK_THRESH
        window = self.ear_history
        window.append(1.0 if eyes_shut else 0.0)
        perclos = sum(window) / len(window) if window else 0.0

        # A blink is registered at the moment the eyes re-open.
        reopened = self._eyes_closed and not eyes_shut
        if reopened:
            self.blink_timestamps.append(timestamp)
        self._eyes_closed = eyes_shut

        # Forget blinks that have slid out of the counting window.
        blinks = self.blink_timestamps
        oldest_allowed = timestamp - self.BLINK_WINDOW_SEC
        while blinks and blinks[0] < oldest_allowed:
            blinks.popleft()
        blink_rate = len(blinks) * (60.0 / self.BLINK_WINDOW_SEC)

        self._closure_start, closure_dur = self._run_length(
            eyes_shut, self._closure_start, timestamp)
        self._yawn_start, yawn_dur = self._run_length(
            mar > self.MAR_YAWN_THRESH, self._yawn_start, timestamp)

        return perclos, blink_rate, closure_dur, yawn_dur
def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
                     *, _pre=None):
    """Build the per-frame feature vector described by ``FEATURE_NAMES``.

    Args:
        landmarks: Face landmarks for the frame, passed through to the
            EAR / gaze / MAR helpers and to ``head_pose`` / ``eye_scorer``.
        w: Frame width, forwarded to ``head_pose``.
        h: Frame height, forwarded to ``head_pose``.
        head_pose: Object providing ``estimate(landmarks, w, h)`` and
            ``score(landmarks, w, h)``.
        eye_scorer: Object providing ``score(landmarks)``.
        temporal: Object providing ``update(ear_avg, mar)`` that returns
            ``(perclos, blink_rate, closure_dur, yawn_dur)``.
        _pre: Optional dict of values the caller has already computed.
            Recognised keys: ``ear_left``, ``ear_right``, ``h_gaze`` and
            ``v_gaze`` (only used when both are present), ``mar``,
            ``angles``, ``s_face``, ``s_eye``. A value supplied here is used
            as-is and the corresponding computation is skipped.

    Returns:
        A 1-D ``float32`` array of 17 values in ``FEATURE_NAMES`` order.
    """
    p = _pre or {}

    # Fallbacks are evaluated lazily: dict.get(key, f(...)) would call f()
    # even when the key is present, defeating the purpose of ``_pre``.
    if "ear_left" in p and "ear_right" in p:
        ear_left, ear_right = p["ear_left"], p["ear_right"]
    else:
        # Imported only when EAR actually has to be computed here.
        from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear
        ear_left = (p["ear_left"] if "ear_left" in p
                    else compute_ear(landmarks, _LEFT_EYE_EAR))
        ear_right = (p["ear_right"] if "ear_right" in p
                     else compute_ear(landmarks, _RIGHT_EYE_EAR))
    ear_avg = (ear_left + ear_right) / 2.0

    if "h_gaze" in p and "v_gaze" in p:
        h_gaze, v_gaze = p["h_gaze"], p["v_gaze"]
    else:
        h_gaze, v_gaze = compute_gaze_ratio(landmarks)

    mar = p["mar"] if "mar" in p else compute_mar(landmarks)

    angles = p.get("angles")
    if angles is None:
        angles = head_pose.estimate(landmarks, w, h)
    # Checked by None/length rather than truthiness so array-like results do
    # not raise "truth value is ambiguous"; missing pose falls back to zeros.
    if angles is not None and len(angles) >= 3:
        yaw, pitch, roll = angles[0], angles[1], angles[2]
    else:
        yaw = pitch = roll = 0.0

    s_face = p["s_face"] if "s_face" in p else head_pose.score(landmarks, w, h)
    s_eye = p["s_eye"] if "s_eye" in p else eye_scorer.score(landmarks)

    # Distance of the gaze ratios from the (0.5, 0.5) centre.
    gaze_offset = math.sqrt((h_gaze - 0.5) ** 2 + (v_gaze - 0.5) ** 2)
    head_deviation = math.sqrt(yaw ** 2 + pitch ** 2)  # cleaned downstream

    perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)

    return np.array([
        ear_left, ear_right, ear_avg,
        h_gaze, v_gaze,
        mar,
        yaw, pitch, roll,
        s_face, s_eye,
        gaze_offset,
        head_deviation,
        perclos, blink_rate, closure_dur, yawn_dur,
    ], dtype=np.float32)
def quality_report(labels, fps=30.0):
    """Print a data-quality summary for a recorded session.

    Args:
        labels: 1-D sequence of per-frame labels. Entries equal to 1 are
            counted as focused; every other entry is counted as unfocused.
            Lists are accepted as well as NumPy arrays.
        fps: Frame rate used to convert sample counts into seconds. The
            recording does not store timestamps, so durations are estimates.

    Returns:
        ``True`` when no warning was raised, ``False`` otherwise.

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError("fps must be positive")

    labels = np.asarray(labels)
    n = len(labels)
    n1 = int((labels == 1).sum())
    n0 = n - n1
    transitions = int(np.sum(np.diff(labels) != 0))
    duration_sec = n / fps  # approximate: assumes a constant frame rate
    warnings = []

    print(f"\n{'='*50}")
    print(" DATA QUALITY REPORT")
    print(f"{'='*50}")
    print(f" Total samples : {n}")
    print(f" Focused : {n1} ({n1/max(n,1)*100:.1f}%)")
    print(f" Unfocused : {n0} ({n0/max(n,1)*100:.1f}%)")
    print(f" Duration : {duration_sec:.0f}s ({duration_sec/60:.1f} min)")
    print(f" Transitions : {transitions}")
    if transitions > 0:
        # k label changes split the recording into k + 1 segments.
        avg_segment = n / (transitions + 1)
        print(f" Avg segment : {avg_segment:.0f} frames ({avg_segment/fps:.1f}s)")

    # checks
    if duration_sec < 120:
        warnings.append(f"TOO SHORT: {duration_sec:.0f}s — aim for 5-10 minutes (300-600s)")
    if n < 3000:
        warnings.append(f"LOW SAMPLE COUNT: {n} frames — aim for 9000+ (5 min at 30fps)")
    balance = n1 / max(n, 1)
    if balance < 0.3 or balance > 0.7:
        warnings.append(f"IMBALANCED: {balance:.0%} focused — aim for 35-65% focused")
    if transitions < 10:
        warnings.append(f"TOO FEW TRANSITIONS: {transitions} — switch every 10-30s, aim for 20+")
    if transitions == 1:
        warnings.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
                        "model will learn temporal position, not focus patterns")

    if warnings:
        print(f"\n ⚠️ WARNINGS ({len(warnings)}):")
        for msg in warnings:
            print(f" • {msg}")
        print("\n Consider re-recording this session.")
    else:
        print("\n ✅ All checks passed!")
    print(f"{'='*50}\n")
    return not warnings
| # --------------------------------------------------------------------------- | |
| # Main | |
def main():
    """Run the interactive webcam data-collection session.

    Opens the selected camera, shows a live preview with an overlay, and
    while a label is active (key ``1`` or ``0``) appends one feature vector
    and label per frame in which a face was detected. ``p`` pauses and ``q``
    ends the session. When the loop ends (quit key, time limit, camera read
    failure or Ctrl+C) the collected samples are written to a timestamped
    ``.npz`` file in ``--output-dir`` and a quality report is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="session",
                        help="Your name or session ID")
    parser.add_argument("--camera", type=int, default=0,
                        help="Camera index")
    parser.add_argument("--duration", type=int, default=600,
                        help="Max recording time (seconds, default 10 min)")
    parser.add_argument("--output-dir", type=str,
                        default=os.path.join(_PROJECT_ROOT, "data", "collected_data"),
                        help="Where to save .npz files")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    detector = FaceMeshDetector()
    head_pose = HeadPoseEstimator()
    eye_scorer = EyeBehaviourScorer()
    temporal = TemporalTracker()

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print("[COLLECT] ERROR: can't open camera")
        # Release what was already created before bailing out.
        cap.release()
        detector.close()
        return

    print("[COLLECT] Data Collection Tool")
    print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
    print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
    print("[COLLECT] Controls:")
    print(" 1 = FOCUSED (looking at screen normally)")
    print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
    print(" p = pause")
    print(" q = save & quit")
    print()
    print("[COLLECT] TIPS for good data:")
    print(" • Switch between 1 and 0 every 10-30 seconds")
    print(" • Aim for 20+ transitions total")
    print(" • Act out varied scenarios: reading, phone, talking, drowsy")
    print(" • Record at least 5 minutes")
    print()

    features_list = []
    labels_list = []
    n_focused = 0  # running count of label == 1, avoids rescanning the list each frame
    label = None  # None = paused
    transitions = 0  # count label switches
    prev_label = None
    status = "PAUSED -- press 1 (focused) or 0 (not focused)"
    t_start = time.time()
    prev_time = time.time()
    fps = 0.0
    # Horizontal room kept to the right of the balance bar for its "NN%F" label.
    pct_label_w = 40

    try:
        while True:
            elapsed = time.time() - t_start
            if elapsed > args.duration:
                print(f"[COLLECT] Time limit ({args.duration}s)")
                break

            ret, frame = cap.read()
            if not ret:
                break

            h, w = frame.shape[:2]
            landmarks = detector.process(frame)
            face_ok = landmarks is not None

            # Record only while a label is active and a face is visible.
            if face_ok and label is not None:
                vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
                features_list.append(vec)
                labels_list.append(label)
                if label == 1:
                    n_focused += 1
                if prev_label is not None and label != prev_label:
                    transitions += 1
                prev_label = label

            # Exponentially smoothed frames-per-second estimate.
            now = time.time()
            fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
            prev_time = now

            # --- draw UI ---
            n = len(labels_list)
            n1 = n_focused
            n0 = n - n1
            remaining = max(0, args.duration - elapsed)
            bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))

            cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
            cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
            cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
                        (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)

            if n > 0:
                # Focused/unfocused balance bar. It is shifted left by
                # pct_label_w so the percentage text stays inside the frame.
                bar_w = max(0, min(w - 20 - pct_label_w, 300))
                bar_x = w - bar_w - 10 - pct_label_w
                bar_y = 58
                frac = n1 / n
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
                cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
                            FONT, 0.3, GRAY, 1, cv2.LINE_AA)

            if not face_ok:
                cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)

            # red dot = recording
            if label is not None and face_ok:
                cv2.circle(frame, (w - 20, 80), 8, RED, -1)

            # live warnings, stacked upwards from the bottom of the frame
            warn_y = h - 35
            if n > 100 and transitions < 3:
                cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
                            (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                warn_y -= 18
            if elapsed > 30 and n > 0:
                bal = n1 / n
                if bal < 0.25 or bal > 0.75:
                    cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
                                (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                    warn_y -= 18

            cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
                        (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)
            cv2.imshow("FocusGuard -- Data Collection", frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord("1"):
                label = 1
                status = "Recording: FOCUSED"
                print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("0"):
                label = 0
                status = "Recording: NOT FOCUSED"
                print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("p"):
                label = None
                status = "PAUSED"
                print(f"[COLLECT] paused (n={n})")
            elif key == ord("q"):
                break
    except KeyboardInterrupt:
        # Ctrl+C in the terminal should end the session like 'q' does,
        # not throw away everything that was recorded.
        print("\n[COLLECT] Interrupted -- saving what was recorded")
    finally:
        cap.release()
        cv2.destroyAllWindows()
        detector.close()

    if len(features_list) > 0:
        feats = np.stack(features_list)
        labs = np.array(labels_list, dtype=np.int64)
        ts = time.strftime("%Y%m%d_%H%M%S")
        fname = f"{args.name}_{ts}.npz"
        fpath = os.path.join(args.output_dir, fname)
        np.savez(fpath,
                 features=feats,
                 labels=labs,
                 feature_names=np.array(FEATURE_NAMES))
        print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
        print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")
        quality_report(labs)
    else:
        print("\n[COLLECT] No data collected")
    print("[COLLECT] Done")
# Start the collection tool only when this file is executed as a script,
# so importing the module (e.g. for extract_features) has no side effects.
if __name__ == "__main__":
    main()