Spaces:
Running
Running
| from __future__ import annotations | |
| from abc import ABC, abstractmethod | |
| import numpy as np | |
class EyeClassifier(ABC):
    """Abstract interface for eye-state classifier backends.

    Concrete backends must implement both :meth:`name` and
    :meth:`predict_score`. Both are declared abstract, so this base class
    cannot be instantiated on its own, and a subclass that forgets one of them
    fails at construction time rather than silently returning ``None``.
    """

    @abstractmethod
    def name(self) -> str:
        """Return a short identifier string for this backend."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Score a batch of eye crops and return a single float.

        Args:
            crops_bgr: Image crops of the eyes. The ``_bgr`` suffix suggests
                OpenCV-style BGR channel order (assumption; confirm with
                callers).

        Returns:
            One aggregate score for all crops. Presumably in ``[0, 1]``, with
            higher meaning "more attentive/open" (NOTE(review): confirm the
            intended range with consumers of this score).
        """
class GeometricOnlyClassifier(EyeClassifier):
    """Backend that performs no learned inference at all.

    It identifies itself as ``"geometric"`` and always reports a constant
    score of ``1.0``, whatever crops it is given.
    """

    def name(self) -> str:
        """Return the backend identifier ``"geometric"``."""
        label = "geometric"
        return label

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Ignore ``crops_bgr`` and return the constant score ``1.0``."""
        constant_score = 1.0
        return constant_score
class YOLOv11Classifier(EyeClassifier):
    """Eye-state classifier backed by an Ultralytics YOLO classification checkpoint.

    :meth:`predict_score` returns the mean probability that the model assigns
    to the "attentive" class across all usable eye crops.
    """

    # Class names (compared case-insensitively, surrounding whitespace
    # ignored) that identify the "attentive" output of the checkpoint.
    _ATTENTIVE_NAMES = ("open", "attentive")

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        """Load the checkpoint and resolve which output index means "attentive".

        Args:
            checkpoint_path: Path handed to ``ultralytics.YOLO``.
            device: Device string forwarded to ``YOLO.predict``.

        Raises:
            ImportError: If the ``ultralytics`` package is not installed.
        """
        # Imported lazily so this module can be imported (and the geometric
        # backend used) without ultralytics being installed.
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device
        names = self._model.names
        self._attentive_idx = self._find_attentive_idx(names)
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")

    @classmethod
    def _find_attentive_idx(cls, names) -> int:
        """Return the index whose class name is in ``_ATTENTIVE_NAMES``.

        ``names`` is the model's ``{index: class_name}`` mapping. If no name
        matches, the largest index is used as a fallback.
        NOTE(review): the fallback assumes the attentive class has the highest
        index (for example alphabetical "closed" < "open"); confirm this for
        the checkpoints in use.
        """
        for idx, cls_name in names.items():
            if str(cls_name).strip().lower() in cls._ATTENTIVE_NAMES:
                return idx
        return max(names.keys())

    def name(self) -> str:
        """Return the backend identifier ``"yolo"``."""
        return "yolo"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return the mean "attentive" probability over ``crops_bgr``.

        ``None`` or zero-size crops are skipped, because the YOLO pre-processing
        cannot handle empty images. If no usable crop remains (including an
        empty or ``None`` input), the neutral score ``1.0`` is returned,
        matching the original empty-input behavior.
        """
        if not crops_bgr:
            return 1.0
        usable = [crop for crop in crops_bgr if crop is not None and np.size(crop) > 0]
        if not usable:
            return 1.0
        results = self._model.predict(usable, device=self._device, verbose=False)
        # Each classification result carries per-class probabilities in
        # ``probs.data``; pick the attentive column for every crop.
        scores = [float(r.probs.data[self._attentive_idx]) for r in results]
        return sum(scores) / len(scores) if scores else 1.0
def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Construct the eye classifier selected by ``path`` and ``backend``.

    Args:
        path: Checkpoint path for the YOLO backend. ``None`` forces the
            geometric backend.
        backend: ``"geometric"`` selects :class:`GeometricOnlyClassifier`.
            Any other value builds a :class:`YOLOv11Classifier` from ``path``.
        device: Device string forwarded to the YOLO backend.

    Returns:
        The constructed classifier instance.

    Raises:
        ImportError: Re-raised, after printing an install hint, when the YOLO
            backend is requested but ``ultralytics`` is not installed.
    """
    wants_geometric = backend == "geometric" or path is None
    if wants_geometric:
        return GeometricOnlyClassifier()

    try:
        classifier = YOLOv11Classifier(path, device=device)
    except ImportError:
        print("[CLASSIFIER] ultralytics required for YOLO. pip install ultralytics")
        raise
    return classifier