# FocusGuardBaseModel / models / collect_features.py
import argparse
import collections
import math
import os
import sys
import time
import cv2
import numpy as np
# Make the project root (parent of this models/ directory) importable so the
# `from models...` imports below work when this file is run as a script.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
from models.face_mesh import FaceMeshDetector
from models.head_pose import HeadPoseEstimator
from models.eye_scorer import EyeBehaviourScorer, compute_gaze_ratio, compute_mar
# OpenCV drawing constants (colours are BGR tuples).
FONT = cv2.FONT_HERSHEY_SIMPLEX
GREEN = (0, 255, 0)
RED = (0, 0, 255)
WHITE = (255, 255, 255)
YELLOW = (0, 255, 255)
ORANGE = (0, 165, 255)
GRAY = (120, 120, 120)

# Column order of every saved feature vector (see extract_features and the
# .npz written by main).  NOTE(review): downstream consumers presumably index
# columns by this order — confirm before reordering.
FEATURE_NAMES = [
    "ear_left", "ear_right", "ear_avg", "h_gaze", "v_gaze", "mar",
    "yaw", "pitch", "roll", "s_face", "s_eye", "gaze_offset", "head_deviation",
    "perclos", "blink_rate", "closure_duration", "yawn_duration",
]
NUM_FEATURES = len(FEATURE_NAMES)
# Import-time guard against accidental edits to the list above.
assert NUM_FEATURES == 17
class TemporalTracker:
    """Accumulates per-frame eye/mouth measurements into temporal features.

    Produces PERCLOS (fraction of recent frames with eyes closed), blink
    rate (blinks per minute over a sliding time window), and the running
    durations of the current eye-closure and yawn episodes.
    """

    EAR_BLINK_THRESH = 0.21   # eye aspect ratio below this counts as "closed"
    MAR_YAWN_THRESH = 0.55    # mouth aspect ratio above this counts as "yawning"
    PERCLOS_WINDOW = 60       # PERCLOS averaging window, in frames
    BLINK_WINDOW_SEC = 30.0   # blink-rate sliding window, in seconds

    def __init__(self):
        self.ear_history = collections.deque(maxlen=self.PERCLOS_WINDOW)
        self.blink_timestamps = collections.deque()
        self._eyes_closed = False
        self._closure_start = None
        self._yawn_start = None

    def update(self, ear_avg, mar, now=None):
        """Feed one frame's measurements.

        Args:
            ear_avg: averaged eye aspect ratio for the frame.
            mar: mouth aspect ratio for the frame.
            now: timestamp in seconds; defaults to time.time().

        Returns:
            Tuple (perclos, blink_rate, closure_dur, yawn_dur).
        """
        now = time.time() if now is None else now
        is_closed = ear_avg < self.EAR_BLINK_THRESH

        # PERCLOS: share of the last PERCLOS_WINDOW frames spent eyes-closed.
        self.ear_history.append(1.0 if is_closed else 0.0)
        perclos = (sum(self.ear_history) / len(self.ear_history)
                   if self.ear_history else 0.0)

        # A blink completes on the closed -> open transition.
        if self._eyes_closed and not is_closed:
            self.blink_timestamps.append(now)
        self._eyes_closed = is_closed

        # Drop blinks that fell out of the window, then scale to per-minute.
        horizon = now - self.BLINK_WINDOW_SEC
        while self.blink_timestamps and self.blink_timestamps[0] < horizon:
            self.blink_timestamps.popleft()
        blink_rate = len(self.blink_timestamps) * (60.0 / self.BLINK_WINDOW_SEC)

        # Duration of the ongoing eye-closure episode, if any.
        if not is_closed:
            self._closure_start = None
            closure_dur = 0.0
        else:
            if self._closure_start is None:
                self._closure_start = now
            closure_dur = now - self._closure_start

        # Duration of the ongoing yawn episode, if any.
        if mar > self.MAR_YAWN_THRESH:
            if self._yawn_start is None:
                self._yawn_start = now
            yawn_dur = now - self._yawn_start
        else:
            self._yawn_start = None
            yawn_dur = 0.0

        return perclos, blink_rate, closure_dur, yawn_dur
def extract_features(landmarks, w, h, head_pose, eye_scorer, temporal,
                     *, _pre=None):
    """Compute the 17-element per-frame feature vector (order = FEATURE_NAMES).

    Args:
        landmarks: face-mesh landmarks for the current frame.
        w, h: frame width/height in pixels (needed for head-pose estimation).
        head_pose: HeadPoseEstimator instance.
        eye_scorer: EyeBehaviourScorer instance.
        temporal: TemporalTracker accumulating PERCLOS/blink/yawn statistics
            (mutated by this call).
        _pre: optional dict of already-computed values keyed by "ear_left",
            "ear_right", "h_gaze", "v_gaze", "mar", "angles", "s_face",
            "s_eye"; present keys skip recomputation.

    Returns:
        np.ndarray of shape (17,), dtype float32.
    """
    # Imported here rather than at module top: these are private helpers of
    # eye_scorer used only by this function.
    from models.eye_scorer import _LEFT_EYE_EAR, _RIGHT_EYE_EAR, compute_ear

    p = _pre or {}

    # BUG FIX: the previous version used p.get(key, expensive_call()), which
    # evaluates the expensive default eagerly — every feature was recomputed
    # even when a cached value was supplied via _pre.  Explicit membership
    # tests make the cache actually short-circuit the work.
    ear_left = p["ear_left"] if "ear_left" in p else compute_ear(landmarks, _LEFT_EYE_EAR)
    ear_right = p["ear_right"] if "ear_right" in p else compute_ear(landmarks, _RIGHT_EYE_EAR)
    ear_avg = (ear_left + ear_right) / 2.0

    if "h_gaze" in p and "v_gaze" in p:
        h_gaze, v_gaze = p["h_gaze"], p["v_gaze"]
    else:
        h_gaze, v_gaze = compute_gaze_ratio(landmarks)

    mar = p["mar"] if "mar" in p else compute_mar(landmarks)

    # Head pose: fall back to zeros when estimation fails (angles falsy).
    angles = p.get("angles")
    if angles is None:
        angles = head_pose.estimate(landmarks, w, h)
    if angles:
        yaw, pitch, roll = angles[0], angles[1], angles[2]
    else:
        yaw = pitch = roll = 0.0

    s_face = p["s_face"] if "s_face" in p else head_pose.score(landmarks, w, h)
    s_eye = p["s_eye"] if "s_eye" in p else eye_scorer.score(landmarks)

    # Derived scalars: gaze distance from frame centre, and combined
    # yaw/pitch magnitude (roll excluded; cleaned downstream).
    gaze_offset = math.hypot(h_gaze - 0.5, v_gaze - 0.5)
    head_deviation = math.hypot(yaw, pitch)

    perclos, blink_rate, closure_dur, yawn_dur = temporal.update(ear_avg, mar)

    return np.array([
        ear_left, ear_right, ear_avg,
        h_gaze, v_gaze,
        mar,
        yaw, pitch, roll,
        s_face, s_eye,
        gaze_offset,
        head_deviation,
        perclos, blink_rate, closure_dur, yawn_dur,
    ], dtype=np.float32)
def quality_report(labels):
    """Print a summary of a recorded session and run sanity checks.

    Args:
        labels: 1-D numpy array of per-frame 0/1 focus labels.

    Returns:
        bool: True when every quality check passes, False if any warning fired.
    """
    total = len(labels)
    focused = int((labels == 1).sum())
    unfocused = total - focused
    switches = int(np.sum(np.diff(labels) != 0))
    secs = total / 30.0  # approximate at 30fps
    rule = "=" * 50

    print(f"\n{rule}")
    print(" DATA QUALITY REPORT")
    print(rule)
    print(f" Total samples : {total}")
    print(f" Focused : {focused} ({focused/max(total,1)*100:.1f}%)")
    print(f" Unfocused : {unfocused} ({unfocused/max(total,1)*100:.1f}%)")
    print(f" Duration : {secs:.0f}s ({secs/60:.1f} min)")
    print(f" Transitions : {switches}")
    if switches > 0:
        print(f" Avg segment : {total/switches:.0f} frames ({total/switches/30:.1f}s)")

    # Collect one message per failed sanity check.
    problems = []
    if secs < 120:
        problems.append(f"TOO SHORT: {secs:.0f}s — aim for 5-10 minutes (300-600s)")
    if total < 3000:
        problems.append(f"LOW SAMPLE COUNT: {total} frames — aim for 9000+ (5 min at 30fps)")
    share = focused / max(total, 1)
    if not 0.3 <= share <= 0.7:
        problems.append(f"IMBALANCED: {share:.0%} focused — aim for 35-65% focused")
    if switches < 10:
        problems.append(f"TOO FEW TRANSITIONS: {switches} — switch every 10-30s, aim for 20+")
    if switches == 1:
        problems.append("SINGLE BLOCK: you recorded one unfocused + one focused block — "
                        "model will learn temporal position, not focus patterns")

    if problems:
        print(f"\n ⚠️ WARNINGS ({len(problems)}):")
        for msg in problems:
            print(f" • {msg}")
        print("\n Consider re-recording this session.")
    else:
        print("\n ✅ All checks passed!")
    print(f"{rule}\n")
    return not problems
# ---------------------------------------------------------------------------
# Main
def main():
    """Run the interactive labelled-data collection loop.

    Opens the camera, runs face-mesh detection per frame, and while a label
    is active (key 1 = focused, 0 = not focused, p = pause, q = save & quit)
    appends one feature vector + label per frame.  On exit the session is
    saved as a .npz file and a data-quality report is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", type=str, default="session",
                        help="Your name or session ID")
    parser.add_argument("--camera", type=int, default=0,
                        help="Camera index")
    parser.add_argument("--duration", type=int, default=600,
                        help="Max recording time (seconds, default 10 min)")
    parser.add_argument("--output-dir", type=str,
                        default=os.path.join(_PROJECT_ROOT, "collected_data"),
                        help="Where to save .npz files")
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Vision pipeline components (project-local models).
    detector = FaceMeshDetector()
    head_pose = HeadPoseEstimator()
    eye_scorer = EyeBehaviourScorer()
    temporal = TemporalTracker()

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        print("[COLLECT] ERROR: can't open camera")
        return

    print("[COLLECT] Data Collection Tool")
    print(f"[COLLECT] Session: {args.name}, max {args.duration}s")
    print(f"[COLLECT] Features per frame: {NUM_FEATURES}")
    print("[COLLECT] Controls:")
    print(" 1 = FOCUSED (looking at screen normally)")
    print(" 0 = NOT FOCUSED (phone, away, eyes closed, yawning)")
    print(" p = pause")
    print(" q = save & quit")
    print()
    print("[COLLECT] TIPS for good data:")
    print(" • Switch between 1 and 0 every 10-30 seconds")
    print(" • Aim for 20+ transitions total")
    print(" • Act out varied scenarios: reading, phone, talking, drowsy")
    print(" • Record at least 5 minutes")
    print()

    # Recording state.
    features_list = []
    labels_list = []
    label = None  # None = paused
    transitions = 0  # count label switches
    prev_label = None
    status = "PAUSED -- press 1 (focused) or 0 (not focused)"
    t_start = time.time()
    prev_time = time.time()
    fps = 0.0
    try:
        while True:
            elapsed = time.time() - t_start
            if elapsed > args.duration:
                print(f"[COLLECT] Time limit ({args.duration}s)")
                break
            ret, frame = cap.read()
            if not ret:
                break
            h, w = frame.shape[:2]
            landmarks = detector.process(frame)
            face_ok = landmarks is not None

            # Record a sample only when a face is visible AND a label is active;
            # a label change while recording counts as one transition.
            if face_ok and label is not None:
                vec = extract_features(landmarks, w, h, head_pose, eye_scorer, temporal)
                features_list.append(vec)
                labels_list.append(label)
                if prev_label is not None and label != prev_label:
                    transitions += 1
                prev_label = label

            # Exponentially smoothed FPS estimate (0.9 decay).
            now = time.time()
            fps = 0.9 * fps + 0.1 * (1.0 / max(now - prev_time, 1e-6))
            prev_time = now

            # --- draw UI ---
            n = len(labels_list)
            n1 = sum(1 for x in labels_list if x == 1)
            n0 = n - n1
            remaining = max(0, args.duration - elapsed)
            # Status bar colour mirrors the active label: green/red/gray.
            bar_color = GREEN if label == 1 else (RED if label == 0 else (80, 80, 80))
            cv2.rectangle(frame, (0, 0), (w, 70), (0, 0, 0), -1)
            cv2.putText(frame, status, (10, 22), FONT, 0.55, bar_color, 2, cv2.LINE_AA)
            cv2.putText(frame, f"Samples: {n} (F:{n1} U:{n0}) Switches: {transitions}",
                        (10, 48), FONT, 0.42, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"FPS:{fps:.0f}", (w - 80, 22), FONT, 0.45, WHITE, 1, cv2.LINE_AA)
            cv2.putText(frame, f"{int(remaining)}s left", (w - 80, 48), FONT, 0.42, YELLOW, 1, cv2.LINE_AA)

            # Class-balance bar: green fill = focused share of samples so far.
            if n > 0:
                bar_w = min(w - 20, 300)
                bar_x = w - bar_w - 10
                bar_y = 58
                frac = n1 / n
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + bar_w, bar_y + 8), (40, 40, 40), -1)
                cv2.rectangle(frame, (bar_x, bar_y), (bar_x + int(bar_w * frac), bar_y + 8), GREEN, -1)
                cv2.putText(frame, f"{frac:.0%}F", (bar_x + bar_w + 4, bar_y + 8),
                            FONT, 0.3, GRAY, 1, cv2.LINE_AA)
            if not face_ok:
                cv2.putText(frame, "NO FACE", (w // 2 - 60, h // 2), FONT, 0.7, RED, 2, cv2.LINE_AA)
            # red dot = recording
            if label is not None and face_ok:
                cv2.circle(frame, (w - 20, 80), 8, RED, -1)

            # Live data-quality warnings, stacked upward from the bottom.
            warn_y = h - 35
            if n > 100 and transitions < 3:
                cv2.putText(frame, "! Switch more often (aim for 20+ transitions)",
                            (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                warn_y -= 18
            if elapsed > 30 and n > 0:
                bal = n1 / n
                if bal < 0.25 or bal > 0.75:
                    cv2.putText(frame, f"! Imbalanced ({bal:.0%} focused) - record more of the other",
                                (10, warn_y), FONT, 0.38, ORANGE, 1, cv2.LINE_AA)
                    warn_y -= 18
            cv2.putText(frame, "1:focused 0:unfocused p:pause q:save+quit",
                        (10, h - 10), FONT, 0.38, GRAY, 1, cv2.LINE_AA)
            cv2.imshow("FocusGuard -- Data Collection", frame)

            # Keyboard controls (1/0 set label, p pauses, q saves & quits).
            key = cv2.waitKey(1) & 0xFF
            if key == ord("1"):
                label = 1
                status = "Recording: FOCUSED"
                print(f"[COLLECT] -> FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("0"):
                label = 0
                status = "Recording: NOT FOCUSED"
                print(f"[COLLECT] -> NOT FOCUSED (n={n}, transitions={transitions})")
            elif key == ord("p"):
                label = None
                status = "PAUSED"
                print(f"[COLLECT] paused (n={n})")
            elif key == ord("q"):
                break
    finally:
        # Release camera/window resources even on exceptions or Ctrl-C.
        cap.release()
        cv2.destroyAllWindows()
        detector.close()

    # Persist the session (if anything was recorded) and report its quality.
    if len(features_list) > 0:
        feats = np.stack(features_list)
        labs = np.array(labels_list, dtype=np.int64)
        ts = time.strftime("%Y%m%d_%H%M%S")
        fname = f"{args.name}_{ts}.npz"
        fpath = os.path.join(args.output_dir, fname)
        np.savez(fpath,
                 features=feats,
                 labels=labs,
                 feature_names=np.array(FEATURE_NAMES))
        print(f"\n[COLLECT] Saved {len(labs)} samples -> {fpath}")
        print(f" Shape: {feats.shape} ({NUM_FEATURES} features)")
        quality_report(labs)
    else:
        print("\n[COLLECT] No data collected")
    print("[COLLECT] Done")
# Script entry point.
if __name__ == "__main__":
    main()