# integration_test2/models/eye_classifier.py
# Abdelrahman Almatrooshi
# Add missing eye_crop and eye_classifier modules
# e0507e7
from __future__ import annotations
from abc import ABC, abstractmethod
import numpy as np
class EyeClassifier(ABC):
    """Abstract interface shared by all eye-state classifier backends.

    A concrete backend must provide a short ``name`` and a
    ``predict_score`` method that reduces a batch of eye crops to one float.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of this backend (e.g. used in log output)."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Reduce ``crops_bgr`` to a single score.

        Args:
            crops_bgr: Eye crops as numpy arrays; presumably BGR-ordered
                images, as the parameter name suggests (confirm with callers).

        Returns:
            One float for the whole batch. The exact meaning is
            backend-defined.
        """
class GeometricOnlyClassifier(EyeClassifier):
    """Model-free backend that always reports a constant score of 1.0.

    The crops passed to ``predict_score`` are never inspected.
    """

    # Constant returned for every batch, regardless of its contents.
    _CONSTANT_SCORE = 1.0

    @property
    def name(self) -> str:
        """Backend identifier: ``"geometric"``."""
        backend_id = "geometric"
        return backend_id

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return 1.0; ``crops_bgr`` is intentionally ignored."""
        return self._CONSTANT_SCORE
class YOLOv11Classifier(EyeClassifier):
    """Eye-state classifier backed by an Ultralytics YOLO checkpoint.

    The score for a batch of crops is the mean probability the model assigns
    to the "attentive" class. That class is the first one whose label, after
    stripping and lower-casing, is ``open`` or ``attentive``.
    """

    # Labels (compared stripped and case-insensitively) that mark the
    # attentive class.
    _ATTENTIVE_LABELS = ("open", "attentive")

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load ``checkpoint_path`` and locate the attentive class index.

        Args:
            checkpoint_path: Passed unchanged to ``ultralytics.YOLO``.
            device: Device string forwarded to every ``predict`` call.

        Raises:
            ImportError: If ``ultralytics`` is not installed.
            ValueError: If the loaded model exposes no class names.
        """
        # Lazy import, so the module stays importable without ultralytics
        # when only the geometric backend is used.
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device
        names = self._model.names
        self._attentive_idx = self._find_attentive_idx(names)
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @classmethod
    def _find_attentive_idx(cls, names) -> int:
        """Return the class index that represents the attentive state.

        ``names`` is normally an ``{index: label}`` mapping. A plain sequence
        of labels is also accepted, in which case its positions are the
        indices.

        Raises:
            ValueError: If ``names`` is empty.
        """
        if hasattr(names, "items"):
            pairs = list(names.items())
        else:
            pairs = list(enumerate(names))
        if not pairs:
            raise ValueError("YOLO checkpoint exposes no class names")
        for idx, label in pairs:
            # Case-insensitive, so labels such as "Open" are still recognised.
            # This matters because an ASCII sort can put "Open" before
            # "closed", which would defeat the max-index fallback below.
            if str(label).strip().lower() in cls._ATTENTIVE_LABELS:
                return idx
        # Fallback heuristic: assume the highest class index is attentive.
        return max(idx for idx, _ in pairs)

    @property
    def name(self) -> str:
        """Backend identifier: ``"yolo"``."""
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean attentive-class probability over ``crops_bgr``.

        Returns 1.0 when no crops are given, or when the model yields no
        results.
        """
        if not crops_bgr:
            return 1.0
        results = self._model.predict(crops_bgr, device=self._device, verbose=False)
        # NOTE(review): presumably ``r.probs`` is None when the checkpoint is
        # not a classification model; that would surface here as an
        # AttributeError. Confirm the checkpoint's task.
        scores = [float(r.probs.data[self._attentive_idx]) for r in results]
        return sum(scores) / len(scores) if scores else 1.0
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build the eye classifier selected by ``path`` and ``backend``.

    Args:
        path: Checkpoint path. When ``None``, the geometric backend is used
            whatever ``backend`` says.
        backend: ``"geometric"`` selects the constant-score backend.
            NOTE(review): any other value, not only ``"yolo"``, leads to a
            YOLO load attempt.
        device: Device string handed to the YOLO backend.

    Returns:
        A ``GeometricOnlyClassifier`` or a ``YOLOv11Classifier``.

    Raises:
        ImportError: Re-raised after printing an install hint when
            ``ultralytics`` is missing.
    """
    use_geometric = path is None or backend == "geometric"
    if use_geometric:
        return GeometricOnlyClassifier()

    try:
        classifier = YOLOv11Classifier(path, device=device)
    except ImportError:
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise
    return classifier