# Uploaded by Yingtao-Zheng — "Upload partially updated files" (commit 8bbb872)
from __future__ import annotations
import os
from pathlib import Path
import cv2
import numpy as np
from ultralytics import YOLO
try:
import mediapipe as mp
except Exception: # pragma: no cover
mp = None
def find_weights(project_root: Path) -> Path | None:
    """Locate trained classifier weights under *project_root*.

    A fixed list of known output locations is checked in priority order
    (a top-level ``weights/best.pt`` first, then the training-run
    directories). The first existing file wins.

    Returns the weights path, or ``None`` when no candidate exists.
    """
    search_order = (
        ("weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
        ("runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
    )
    for parts in search_order:
        candidate = project_root.joinpath(*parts)
        if candidate.is_file():
            return candidate
    return None
def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """Estimate the pupil position in a grayscale eye crop.

    The central 60% x 60% region is contrast-enhanced (CLAHE), blurred,
    and Otsu-thresholded (inverted, so dark pupil pixels become the
    foreground). Candidate contours are scored by circularity minus
    normalized distance from the image centre.

    Returns the best (x, y) pixel coordinate in full-image space, or
    ``None`` when no plausible contour is found.
    """
    height, width = gray.shape
    enhanced = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
    smoothed = cv2.GaussianBlur(enhanced, (7, 7), 0)

    # Restrict the search to a centred window covering 30% of each axis.
    center_x, center_y = width // 2, height // 2
    half_w, half_h = int(width * 0.3), int(height * 0.3)
    left, right = max(center_x - half_w, 0), min(center_x + half_w, width)
    top, bottom = max(center_y - half_h, 0), min(center_y + half_h, height)
    window = smoothed[top:bottom, left:right]

    # Inverted Otsu picks out the dark pupil; open/close removes speckle.
    _, mask = cv2.threshold(window, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    best_center = None
    best_score = -1.0
    for contour in contours:
        area = cv2.contourArea(contour)
        if area < 15:  # reject tiny noise blobs
            continue
        perimeter = cv2.arcLength(contour, True)
        if perimeter <= 0:
            continue
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:  # pupil should be roughly round
            continue
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            continue
        # Centroid, translated back into full-image coordinates.
        px = int(moments["m10"] / moments["m00"]) + left
        py = int(moments["m01"] / moments["m00"]) + top
        centrality = np.hypot(px - center_x, py - center_y) / max(width, height)
        score = circularity - centrality
        if score > best_score:
            best_score = score
            best_center = (px, py)
    return best_center
def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
    """Return True when the pupil lies near the horizontal image centre.

    Only horizontal displacement matters: the pupil x-coordinate must be
    within 12% of the image width from the centre column. The vertical
    coordinate and image height are ignored.
    """
    height, width = img_shape
    pupil_x, _ = pupil_center
    offset_ratio = abs(pupil_x - width // 2) / max(width, 1)
    return offset_ratio < 0.12
def classify_frame(model: YOLO, frame: np.ndarray) -> tuple[str, float]:
    """Classify a single frame with the YOLO classifier.

    NOTE: the frame is fed to the classifier as-is — it is assumed to
    already be an eye crop rather than a full scene.

    Returns the top-1 class name and its confidence.
    """
    result = model.predict(frame, imgsz=224, device="cpu", verbose=False)[0]
    top_index = int(result.probs.top1)
    confidence = float(result.probs.top1conf)
    return model.names[top_index], confidence
def annotate_frame(frame: np.ndarray, label: str, focused: bool, conf: float, time_sec: float):
    """Return a copy of *frame* with a status line drawn at the top-left.

    The overlay shows the label, focus flag (0/1), confidence, and the
    frame timestamp in seconds; the input frame is not modified.
    """
    status = f"{label} | focused={int(focused)} | conf={conf:.2f} | t={time_sec:.2f}s"
    annotated = frame.copy()
    cv2.putText(annotated, status, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
    return annotated
def write_segments(path: Path, segments: list[tuple[float, float, str]]):
    """Write segments to *path*, one ``start,end,label`` line each.

    Times are formatted with two decimal places (seconds).
    """
    lines = (f"{start:.2f},{end:.2f},{label}\n" for start, end, label in segments)
    with path.open("w") as handle:
        handle.writelines(lines)
def process_video(video_path: Path, model: YOLO | None):
    """Analyze one video and emit an annotated copy plus per-frame logs.

    For each frame, eye state ("open"/"closed") and a focus flag are
    estimated — via MediaPipe FaceMesh landmarks when mediapipe is
    available, otherwise via the YOLO classifier (when *model* is given)
    combined with the classical pupil-center heuristic. Outputs are
    written next to the input video:

    - ``<stem>_pred.mp4``         annotated video
    - ``<stem>_predictions.csv``  time_sec,label,focused,conf per frame
    - ``<stem>_segments.txt``     start,end,label runs of constant label
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        print(f"Failed to open {video_path}")
        return
    # CAP_PROP_FPS can come back 0/None for some containers; default to 30.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_path = video_path.with_name(video_path.stem + "_pred.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(str(out_path), fourcc, fps, (width, height))
    csv_path = video_path.with_name(video_path.stem + "_predictions.csv")
    seg_path = video_path.with_name(video_path.stem + "_segments.txt")
    frame_idx = 0
    last_label = None  # label of the segment currently being tracked
    seg_start = 0.0    # start time (seconds) of the current segment
    segments: list[tuple[float, float, str]] = []
    with csv_path.open("w") as fcsv:
        fcsv.write("time_sec,label,focused,conf\n")
        if mp is None:
            print("mediapipe is not installed. Falling back to classifier-only mode.")
        use_mp = mp is not None
        if use_mp:
            mp_face_mesh = mp.solutions.face_mesh
            face_mesh = mp_face_mesh.FaceMesh(
                static_image_mode=False,
                max_num_faces=1,
                refine_landmarks=True,  # enables the iris landmarks used below
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5,
            )
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            time_sec = frame_idx / fps
            # Defaults used when neither branch below overrides them.
            conf = 0.0
            pred_label = "open"
            focused = False
            if use_mp:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                res = face_mesh.process(rgb)
                if res.multi_face_landmarks:
                    lm = res.multi_face_landmarks[0].landmark
                    h, w = frame.shape[:2]
                    # Eye landmarks (MediaPipe FaceMesh)
                    left_eye = [33, 160, 158, 133, 153, 144]
                    right_eye = [362, 385, 387, 263, 373, 380]
                    left_iris = [468, 469, 470, 471]
                    right_iris = [473, 474, 475, 476]
                    def pts(idxs):
                        # Convert normalized landmarks to pixel coordinates.
                        return np.array([(int(lm[i].x * w), int(lm[i].y * h)) for i in idxs])
                    def ear(eye_pts):
                        # Eye aspect ratio (EAR) using 6 points: vertical gaps
                        # over horizontal corner-to-corner distance.
                        p1, p2, p3, p4, p5, p6 = eye_pts
                        v1 = np.linalg.norm(p2 - p6)
                        v2 = np.linalg.norm(p3 - p5)
                        h1 = np.linalg.norm(p1 - p4)
                        return (v1 + v2) / (2.0 * h1 + 1e-6)
                    le = pts(left_eye)
                    re = pts(right_eye)
                    le_ear = ear(le)
                    re_ear = ear(re)
                    ear_avg = (le_ear + re_ear) / 2.0
                    # openness threshold
                    pred_label = "open" if ear_avg > 0.22 else "closed"
                    # iris centers
                    li = pts(left_iris)
                    ri = pts(right_iris)
                    li_c = li.mean(axis=0).astype(int)
                    ri_c = ri.mean(axis=0).astype(int)
                    # eye centers (midpoint of corners)
                    le_c = ((le[0] + le[3]) / 2).astype(int)
                    re_c = ((re[0] + re[3]) / 2).astype(int)
                    # focus = iris close to eye center horizontally for both eyes,
                    # normalized by each eye's corner-to-corner width
                    le_dx = abs(li_c[0] - le_c[0]) / max(np.linalg.norm(le[0] - le[3]), 1)
                    re_dx = abs(ri_c[0] - re_c[0]) / max(np.linalg.norm(re[0] - re[3]), 1)
                    focused = (pred_label == "open") and (le_dx < 0.18) and (re_dx < 0.18)
                    # draw eye boundaries
                    cv2.polylines(frame, [le], True, (0, 255, 255), 1)
                    cv2.polylines(frame, [re], True, (0, 255, 255), 1)
                    # draw iris centers
                    cv2.circle(frame, tuple(li_c), 3, (0, 0, 255), -1)
                    cv2.circle(frame, tuple(ri_c), 3, (0, 0, 255), -1)
                else:
                    # No face found: treat the frame as eyes closed.
                    pred_label = "closed"
                    focused = False
            else:
                # Classifier-only path (mediapipe unavailable).
                if model is not None:
                    pred_label, conf = classify_frame(model, frame)
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                # Only bother locating the pupil when the eye is open.
                pupil_center = detect_pupil_center(gray) if pred_label.lower() == "open" else None
                focused = False
                if pred_label.lower() == "open" and pupil_center is not None:
                    focused = is_focused(pupil_center, gray.shape)
                if pred_label.lower() != "open":
                    focused = False
            # Collapse eye state + focus into one of three output labels.
            label = "open_focused" if (pred_label.lower() == "open" and focused) else "open_not_focused"
            if pred_label.lower() != "open":
                label = "closed_not_focused"
            fcsv.write(f"{time_sec:.2f},{label},{int(focused)},{conf:.4f}\n")
            # Track contiguous runs of identical labels for the segments file.
            if last_label is None:
                last_label = label
                seg_start = time_sec
            elif label != last_label:
                segments.append((seg_start, time_sec, last_label))
                seg_start = time_sec
                last_label = label
            annotated = annotate_frame(frame, label, focused, conf, time_sec)
            writer.write(annotated)
            frame_idx += 1
    # Close the final open segment at the video's end time.
    if last_label is not None:
        end_time = frame_idx / fps
        segments.append((seg_start, end_time, last_label))
    write_segments(seg_path, segments)
    cap.release()
    writer.release()
    print(f"Saved: {out_path}")
    print(f"CSV: {csv_path}")
    print(f"Segments: {seg_path}")
def main():
    """Entry point: load classifier weights (if present) and process videos.

    Default inputs are ``1.mp4`` / ``2.mp4`` in the project root; extra
    paths may be supplied via a comma-separated ``VIDEOS`` environment
    variable (relative entries are resolved against the project root).
    """
    project_root = Path(__file__).resolve().parent.parent
    weights = find_weights(project_root)
    model = None if weights is None else YOLO(str(weights))
    # Default inputs from the project root.
    videos = [project_root / name for name in ("1.mp4", "2.mp4") if (project_root / name).exists()]
    # Extra inputs from the VIDEOS env var (comma-separated paths).
    for entry in os.getenv("VIDEOS", "").split(","):
        entry = entry.strip()
        if not entry:
            continue
        candidate = Path(entry)
        if not candidate.is_absolute():
            candidate = project_root / candidate
        if candidate.exists():
            videos.append(candidate)
    if not videos:
        print("No videos found. Expected 1.mp4 / 2.mp4 in project root.")
        return
    for video in videos:
        process_video(video, model)
# Script entry point guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()