# Upload partially updated files — commit 8bbb872 (Yingtao-Zheng)
from __future__ import annotations
from pathlib import Path
import os
import cv2
import numpy as np
from ultralytics import YOLO
def list_images(folder: Path):
    """Return the image files directly inside *folder*, sorted by path."""
    allowed = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
    found = [entry for entry in folder.iterdir() if entry.suffix.lower() in allowed]
    found.sort()
    return found
def find_weights(project_root: Path) -> Path | None:
    """
    Locate trained classifier weights under *project_root*.

    Checks a fixed list of known training output locations in priority
    order (best.pt preferred over last.pt) and returns the first file
    that exists, or None when no weights are found.
    """
    relative_candidates = (
        ("weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs", "classify", "runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
        ("runs_cls", "eye_open_closed_cpu", "weights", "best.pt"),
        ("runs_cls", "eye_open_closed_cpu", "weights", "last.pt"),
    )
    for parts in relative_candidates:
        candidate = project_root.joinpath(*parts)
        if candidate.is_file():
            return candidate
    return None
def detect_eyelid_boundary(gray: np.ndarray) -> np.ndarray | None:
    """
    Fit an ellipse to the largest plausible eye-boundary contour.

    Returns cv2's ellipse tuple (center(x,y), (axis1, axis2), angle),
    or None when no suitable contour is found.
    """
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    edge_map = cv2.Canny(smoothed, 40, 120)
    edge_map = cv2.dilate(edge_map, np.ones((3, 3), np.uint8), iterations=1)
    found, _ = cv2.findContours(edge_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return None
    # Largest contours first; fitEllipse needs at least 5 points, and tiny
    # contours (area <= 50) are treated as noise.
    for contour in sorted(found, key=cv2.contourArea, reverse=True):
        if len(contour) >= 5 and cv2.contourArea(contour) > 50:
            return cv2.fitEllipse(contour)
    return None
def detect_pupil_center(gray: np.ndarray) -> tuple[int, int] | None:
    """
    Locate the pupil as the most circular dark blob near the image center.

    Pipeline: CLAHE contrast enhancement, Gaussian blur, inverted Otsu
    threshold on a central ROI, morphological cleanup, then score each
    blob by circularity minus its normalized distance from image center.
    Returns (x, y) in full-image coordinates, or None.
    """
    height, width = gray.shape
    enhanced = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
    smoothed = cv2.GaussianBlur(enhanced, (7, 7), 0)
    # Restrict to a central box (60% of each dimension) to avoid
    # eyelashes and frame edges.
    center_x, center_y = width // 2, height // 2
    half_w, half_h = int(width * 0.3), int(height * 0.3)
    left, right = max(center_x - half_w, 0), min(center_x + half_w, width)
    top, bottom = max(center_y - half_h, 0), min(center_y + half_h, height)
    roi = smoothed[top:bottom, left:right]
    # Inverted Otsu threshold captures the dark pupil as foreground.
    _, mask = cv2.threshold(roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=2)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8), iterations=1)
    blobs, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not blobs:
        return None
    best_center = None
    best_score = -1.0
    for blob in blobs:
        area = cv2.contourArea(blob)
        if area < 15:  # too small to be a pupil
            continue
        perimeter = cv2.arcLength(blob, True)
        if perimeter <= 0:
            continue
        # 1.0 for a perfect circle, lower for irregular shapes.
        circularity = 4 * np.pi * (area / (perimeter * perimeter))
        if circularity < 0.3:
            continue
        moments = cv2.moments(blob)
        if moments["m00"] == 0:
            continue
        # Centroid, shifted back into full-image coordinates.
        blob_x = int(moments["m10"] / moments["m00"]) + left
        blob_y = int(moments["m01"] / moments["m00"]) + top
        # Prefer round blobs close to the image center.
        offset = np.hypot(blob_x - center_x, blob_y - center_y) / max(width, height)
        score = circularity - offset
        if score > best_score:
            best_score = score
            best_center = (blob_x, blob_y)
    return best_center
def is_focused(pupil_center: tuple[int, int], img_shape: tuple[int, int]) -> bool:
    """
    Report whether the pupil sits close enough to the image center.

    The pupil counts as focused when its offset from center is under
    12% of the image width AND under 12% of the image height.
    """
    height, width = img_shape
    px, py = pupil_center
    offset_x = abs(px - width // 2) / max(width, 1)
    offset_y = abs(py - height // 2) / max(height, 1)
    return offset_x < 0.12 and offset_y < 0.12
def annotate(img_bgr: np.ndarray, ellipse, pupil_center, focused: bool, cls_label: str, conf: float):
    """Draw the eyelid ellipse, pupil dot, and a status label on a copy of the image."""
    canvas = img_bgr.copy()
    if ellipse is not None:
        # Yellow ellipse marking the detected eyelid boundary.
        cv2.ellipse(canvas, ellipse, (0, 255, 255), 2)
    if pupil_center is not None:
        # Red filled dot at the detected pupil center.
        cv2.circle(canvas, pupil_center, 4, (0, 0, 255), -1)
    label = f"{cls_label} ({conf:.2f}) | focused={int(focused)}"
    cv2.putText(canvas, label, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return canvas
def main():
    """
    Classify test eye images as open/closed with a trained YOLO classifier,
    then run classical eyelid/pupil detection and save annotated outputs.

    Reads test images from Dataset/ (or DATA/), writes annotated JPEGs to
    runs_focus/, and prints a per-image summary line.
    """
    project_root = Path(__file__).resolve().parent.parent
    data_dir = project_root / "Dataset"
    alt_data_dir = project_root / "DATA"
    out_dir = project_root / "runs_focus"
    out_dir.mkdir(parents=True, exist_ok=True)
    weights = find_weights(project_root)
    if weights is None:
        print("Weights not found. Train first.")
        return
    # Support Dataset/test/{open,closed}, Dataset/{open,closed}, and the
    # alternate "close" spelling of the closed-eye folder.
    def resolve_test_dirs(root: Path):
        test_open = root / "test" / "open"
        test_closed = root / "test" / "closed"
        if test_open.exists() and test_closed.exists():
            return test_open, test_closed
        test_open = root / "open"
        test_closed = root / "closed"
        if test_open.exists() and test_closed.exists():
            return test_open, test_closed
        alt_closed = root / "close"
        if test_open.exists() and alt_closed.exists():
            return test_open, alt_closed
        return None, None
    test_open, test_closed = resolve_test_dirs(data_dir)
    if (test_open is None or test_closed is None) and alt_data_dir.exists():
        test_open, test_closed = resolve_test_dirs(alt_data_dir)
    # BUG FIX: when neither layout resolves, test_open/test_closed are None
    # and the original code crashed with AttributeError on None.exists().
    # Check for None and print the expected canonical locations instead.
    if test_open is None or test_closed is None:
        print("Test folders missing. Expected:")
        print(data_dir / "test" / "open")
        print(data_dir / "test" / "closed")
        return
    test_files = list_images(test_open) + list_images(test_closed)
    print("Total test images:", len(test_files))
    # Optional cap on the number of images via MAX_IMAGES (0 = no limit).
    max_images = int(os.getenv("MAX_IMAGES", "0"))
    if max_images > 0:
        test_files = test_files[:max_images]
        print("Limiting to MAX_IMAGES:", max_images)
    model = YOLO(str(weights))
    results = model.predict(test_files, imgsz=224, device="cpu", verbose=False)
    names = model.names
    for r in results:
        probs = r.probs
        top_idx = int(probs.top1)
        top_conf = float(probs.top1conf)
        pred_label = names[top_idx]
        img = cv2.imread(r.path)
        if img is None:
            # Unreadable image: skip rather than crash.
            continue
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        ellipse = detect_eyelid_boundary(gray)
        pupil_center = detect_pupil_center(gray)
        focused = False
        # Focus is only meaningful for an open eye with a detected pupil.
        if pred_label.lower() == "open" and pupil_center is not None:
            focused = is_focused(pupil_center, gray.shape)
        annotated = annotate(img, ellipse, pupil_center, focused, pred_label, top_conf)
        out_path = out_dir / (Path(r.path).stem + "_annotated.jpg")
        cv2.imwrite(str(out_path), annotated)
        print(f"{Path(r.path).name}: pred={pred_label} conf={top_conf:.3f} focused={focused}")
    print(f"\nAnnotated outputs saved to: {out_dir}")
# Script entry point: run the full classify-and-annotate pipeline.
if __name__ == "__main__":
    main()