from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np


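class EyeClassifier(ABC):
    """Interface for eye-state classifiers that score a batch of eye crops."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for the backend (e.g. "geometric", "yolo")."""

    @abstractmethod
    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        """Return an attentiveness score for a list of BGR eye crops.

        Higher values mean the eyes are more likely open (attentive).
        """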


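class GeometricOnlyClassifier(EyeClassifier):
    """Pass-through classifier for geometric-only mode.

    Always returns 1.0 (fully attentive), regardless of the crops it receives.
    """

    @property
    def name(self) -> str:
        return "geometric"

    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        return 1.0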


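class YOLOv11Classifier(EyeClassifier):
    """Eye-state classifier backed by an Ultralytics YOLO classification checkpoint."""

    def __init__(self, checkpoint_path: str, device: str = "cpu"):
        # Imported lazily so ultralytics is only required for this backend.
        from ultralytics import YOLO

        self._model = YOLO(checkpoint_path)
        self._device = device

        # Locate the class meaning "eyes open"; `names` maps index -> label.
        names = self._model.names
        self._attentive_idx = None
        for idx, cls_name in names.items():
            if cls_name in ("open", "attentive"):
                self._attentive_idx = idx
                break
        if self._attentive_idx is None:
            # No recognised label: assume the highest class index is "attentive".
            self._attentive_idx = max(names.keys())
        print(f"[YOLO] Classes: {names}, attentive_idx={self._attentive_idx}")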

    @property
    def name(self) -> str:
        return "yolo"

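    def predict_score(self, crops_bgr: list[np.ndarray]) -> float:
        # Nothing to score: default to attentive.
        if not crops_bgr:
            return 1.0
        # Classify all crops in one batch; each result carries per-class probabilities.
        results = self._model.predict(crops_bgr, device=self._device, verbose=False)
        # The score is the mean probability of the attentive class across crops.
        scores = [float(r.probs.data[self._attentive_idx]) for r in results]
        return sum(scores) / len(scores) if scores else 1.0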
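

# Sketch (hypothetical helper, not called by the classes above): the class-index
# lookup from YOLOv11Classifier.__init__ rewritten as a pure function, so the
# matching rule can be checked without loading a checkpoint. It assumes `names`
# is the {index: label} dict that ultralytics exposes as `model.names`.
def _resolve_attentive_idx(names: dict[int, str]) -> int:
    # Prefer an explicitly named "open" or "attentive" class.
    for idx, cls_name in names.items():
        if cls_name in ("open", "attentive"):
            return idx
    # Otherwise fall back to the highest class index, as the constructor does.
    return max(names.keys())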


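def load_eye_classifier(
    path: str | None = None,
    backend: str = "yolo",
    device: str = "cpu",
) -> EyeClassifier:
    """Build an eye classifier.

    Returns GeometricOnlyClassifier when no checkpoint path is given or when
    backend is "geometric"; otherwise loads the checkpoint with YOLOv11Classifier.
    """
    if backend not in ("yolo", "geometric"):
        raise ValueError(f"Unknown eye-classifier backend: {backend!r}")
    if path is None or backend == "geometric":
        return GeometricOnlyClassifier()

    try:
        return YOLOv11Classifier(path, device=device)
    except ImportError:
        print("[CLASSIFIER] ultralytics is required for the YOLO backend: pip install ultralytics")
        raise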
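

# Sketch (hypothetical helper, not called above): the score aggregation from
# YOLOv11Classifier.predict_score, written against plain per-crop probability
# vectors so it runs without ultralytics. Each row stands in for one result's
# `probs.data`; the score is the mean probability of the attentive class, and
# an empty input defaults to 1.0 (attentive), matching predict_score.
def _mean_attentive_score(prob_rows: "list[np.ndarray]", attentive_idx: int) -> float:
    if not prob_rows:
        return 1.0
    # Pick the attentive-class probability from each crop, then average.
    scores = [float(row[attentive_idx]) for row in prob_rows]
    return sum(scores) / len(scores)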