"""PDF page classifier for production inference."""
import json
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
try:
from .base_classifier import _BasePDFPageClassifier
except ImportError:
from base_classifier import _BasePDFPageClassifier # standalone / HF usage
try:
import onnxruntime as ort
except ImportError as _e:
raise ImportError(
"onnxruntime is required for inference.\n"
"Install with: pip install onnxruntime"
) from _e
class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and exposes a simple ``predict``
    interface. All preprocessing (center-crop, resize, normalization) is
    performed in pure PIL + numpy, matching the pipeline used during training.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(self, model_path: str, config: dict[str, Any]) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment).
        """
        super().__init__(config)
        # One session is created up front and reused for every batch.
        self._session = ort.InferenceSession(model_path)
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(cls, model_dir: str) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:

        - ``model_int8.onnx`` or ``model.onnx`` — exported by
          save_for_deployment (the INT8 model is preferred when both exist)
        - ``config.json`` — written by save_for_deployment

        Args:
            model_dir: Path to the deployment directory.

        Returns:
            Initialised PDFPageClassifierONNX.

        Raises:
            FileNotFoundError: If ``config.json`` or an ONNX model file is
                missing from ``model_dir``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        # Prefer INT8 (QAT export) over FP32 when both are present
        candidates = ["model_int8.onnx", "model.onnx"]
        for candidate in candidates:
            if (path / candidate).exists():
                model_path = path / candidate
                break
        else:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        return cls(str(model_path), config)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one preprocessed batch through the ONNX session.

        Args:
            batch_input: Float32 batch in the layout the exported graph
                expects (presumably NCHW — confirm against export_onnx).

        Returns:
            The first graph output (logits) as a float32 array.
        """
        # run(None, ...) requests all outputs; the classifier only has one.
        return self._session.run(None, {self._input_name: batch_input})[0]