File size: 3,012 Bytes
bd27421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""PDF page classifier for production inference."""

import json
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt

try:
    from .base_classifier import _BasePDFPageClassifier
except ImportError:
    from base_classifier import _BasePDFPageClassifier  # standalone / HF usage

try:
    import onnxruntime as ort
except ImportError as _e:
    raise ImportError(
        "onnxruntime is required for inference.\n"
        "Install with: pip install onnxruntime"
    ) from _e



class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and runs it with ONNX Runtime.
    Preprocessing (center-crop, resize, normalization) and the public
    ``predict`` interface are inherited from ``_BasePDFPageClassifier``;
    this class only supplies the ONNX Runtime backend via ``_run_batch``.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])

        # On onnxruntime builds with several execution providers (e.g. GPU),
        # pick the providers explicitly:
        clf = PDFPageClassifierONNX.from_pretrained(
            "outputs/run-42/deployment",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
    """

    def __init__(
        self,
        model_path: str,
        config: dict[str, Any],
        providers: "list[str] | None" = None,
    ) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment). Passed unchanged to the base class.
            providers: Optional ONNX Runtime execution providers in priority
                order, e.g. ``["CUDAExecutionProvider", "CPUExecutionProvider"]``.
                ``None`` (default) keeps ONNX Runtime's own default selection.
                ONNX Runtime >= 1.9 requires this to be set explicitly on builds
                that ship more than one provider (e.g. onnxruntime-gpu).
        """
        super().__init__(config)
        # ``providers=None`` behaves exactly like omitting the argument.
        self._session = ort.InferenceSession(model_path, providers=providers)
        # Only the first model input is fed (assumes a single-input model).
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        providers: "list[str] | None" = None,
    ) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:
          - ``config.json`` — written by save_for_deployment
          - ``model_int8.onnx`` (INT8 / QAT export) and/or ``model.onnx``
            (FP32) — exported by save_for_deployment. When both exist the
            INT8 model is loaded.

        Args:
            model_dir: Path to the deployment directory.
            providers: Optional ONNX Runtime execution providers forwarded to
                the constructor; ``None`` keeps ONNX Runtime's default.

        Returns:
            Initialised classifier (an instance of ``cls``).

        Raises:
            FileNotFoundError: If ``config.json`` is missing, or none of the
                candidate model files is present.
        """
        path = Path(model_dir)
        config_path = path / "config.json"

        if not config_path.is_file():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        # Prefer INT8 (QAT export) over FP32 when both are present.
        candidates = ("model_int8.onnx", "model.onnx")
        model_path = next(
            (path / name for name in candidates if (path / name).is_file()),
            None,
        )
        if model_path is None:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        # Forward ``providers`` only when given, so subclasses whose __init__
        # does not accept it keep working with the default call.
        if providers is None:
            return cls(str(model_path), config)
        return cls(str(model_path), config, providers=providers)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one forward pass of the ONNX session on a preprocessed batch.

        Args:
            batch_input: Batch fed to the model's first input. It is produced
                by the base class; presumably float32 as annotated — confirm.

        Returns:
            The first model output. ``run(None, ...)`` computes every model
            output, but only index 0 is returned.
        """
        return self._session.run(None, {self._input_name: batch_input})[0]