"""
Example usage of the DocumentClassifier ONNX model for document classification.
"""

import argparse
import os
import time
import traceback
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image


class DocumentClassifierONNX:
    """ONNX Runtime wrapper for the DocumentClassifier model.

    Loads the model, converts images into the tensor layout the rest of this
    class assumes (channels-first, optionally batched), runs inference and
    decodes the raw output into ranked document categories.
    """

    # Labels indexed by model output position.
    # NOTE(review): the order is assumed to match the exported model's class
    # indices -- confirm against the training label map.
    DEFAULT_CATEGORIES = (
        "article", "form", "letter", "memo", "news", "presentation",
        "resume", "scientific", "specification", "table", "other",
    )

    # ONNX Runtime element-type strings mapped to the NumPy dtype to feed.
    _ONNX_DTYPES = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int8)": np.int8,
        "tensor(int16)": np.int16,
        "tensor(int32)": np.int32,
        "tensor(int64)": np.int64,
        "tensor(uint8)": np.uint8,
        "tensor(uint16)": np.uint16,
        "tensor(bool)": np.bool_,
    }

    def __init__(self, model_path: str = "DocumentClassifier.onnx"):
        """
        Initialize DocumentClassifier ONNX model

        Args:
            model_path: Path to ONNX model file
        """
        print(f"Loading DocumentClassifier model: {model_path}")
        self.session = ort.InferenceSession(model_path)

        # Only the first input / first output are used by this wrapper.
        model_input = self.session.get_inputs()[0]
        model_outputs = self.session.get_outputs()
        self.input_name = model_input.name
        self.input_shape = model_input.shape
        self.input_type = model_input.type
        self.output_names = [output.name for output in model_outputs]
        self.output_shape = model_outputs[0].shape

        self.categories = list(self.DEFAULT_CATEGORIES)

        print("Model loaded successfully")
        print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f" Output: {self.output_shape}")
        print(f" Categories: {len(self.categories)}")

    @staticmethod
    def _is_static_dim(dim) -> bool:
        """Return True when *dim* is a concrete positive size (not symbolic/None)."""
        return isinstance(dim, (int, np.integer)) and dim > 0

    def _resolved_input_shape(self) -> tuple:
        """Return the model input shape with every dynamic dimension set to 1."""
        return tuple(
            int(dim) if self._is_static_dim(dim) else 1
            for dim in self.input_shape
        )

    def _shape_matches(self, actual_shape: tuple) -> bool:
        """Check *actual_shape* against the model input, ignoring dynamic dims."""
        if len(actual_shape) != len(self.input_shape):
            return False
        return all(
            not self._is_static_dim(expected) or expected == actual
            for actual, expected in zip(actual_shape, self.input_shape)
        )

    def _input_dtype(self) -> np.dtype:
        """Return the NumPy dtype corresponding to the model's input type."""
        dtype = self._ONNX_DTYPES.get(self.input_type)
        if dtype is None:
            # Unknown type string: keep the historical float32/int64 heuristic.
            dtype = np.float32 if "float" in self.input_type else np.int64
        return np.dtype(dtype)

    def _model_target_size(self) -> tuple:
        """Return (height, width) from a static 4-D input, else (224, 224).

        Assumes the model input is laid out as (batch, channels, height,
        width), which is the layout preprocess_image produces.
        """
        shape = self.input_shape
        if (
            len(shape) == 4
            and self._is_static_dim(shape[2])
            and self._is_static_dim(shape[3])
        ):
            return int(shape[2]), int(shape[3])
        return 224, 224

    @staticmethod
    def _to_three_channels(image_array: np.ndarray) -> np.ndarray:
        """Convert a grayscale / single-channel / RGB / RGBA array to 3 channels."""
        if image_array.ndim == 3 and image_array.shape[2] == 1:
            image_array = image_array[:, :, 0]
        if image_array.ndim == 2:
            # Replicate the single plane into three identical channels.
            return np.repeat(image_array[:, :, np.newaxis], 3, axis=2)
        if image_array.ndim == 3 and image_array.shape[2] == 4:
            # Drop the alpha channel; make contiguous for OpenCV.
            return np.ascontiguousarray(image_array[:, :, :3])
        if image_array.ndim == 3 and image_array.shape[2] == 3:
            return image_array
        raise ValueError(f"Unsupported image shape: {image_array.shape}")

    def create_dummy_input(self) -> np.ndarray:
        """Create a random input tensor for testing.

        Dynamic dimensions of the model input are replaced by 1 and the
        tensor uses the dtype matching the model's declared input type.
        """
        shape = self._resolved_input_shape()
        dtype = self._input_dtype()
        if np.issubdtype(dtype, np.floating):
            return np.random.randn(*shape).astype(dtype)
        return np.random.randint(0, 255, shape).astype(dtype)

    def preprocess_image(self, image: Union[str, np.ndarray], target_size: tuple = (224, 224)) -> np.ndarray:
        """
        Preprocess image for DocumentClassifier inference

        The image is converted to three channels, resized with bicubic
        interpolation, scaled by 1/255, transposed to channels-first and, when
        the model input is 4-D, given a leading batch dimension.

        Args:
            image: Image path or numpy array (2-D grayscale, or H x W x 1/3/4).
                Array values are divided by 255, so 0-255 data is presumably
                expected -- confirm for float inputs.
            target_size: Target image size (height, width)

        Returns:
            float32 tensor shaped for the model input.

        Raises:
            ValueError: If the array layout is unsupported or the resulting
                tensor does not fit the model's static input dimensions.
        """
        if isinstance(image, str):
            # Context manager releases the file handle once pixels are loaded.
            with Image.open(image) as pil_image:
                image_array = np.array(pil_image.convert("RGB"))
        else:
            image_array = np.asarray(image)

        print(f" Processing image: {image_array.shape}")

        rgb = self._to_three_channels(image_array)
        # cv2.resize takes (width, height), hence the reversal.
        resized = cv2.resize(rgb, target_size[::-1], interpolation=cv2.INTER_CUBIC)
        normalized = resized.astype(np.float32) / 255.0
        chw = np.transpose(normalized, (2, 0, 1))
        batched = np.expand_dims(chw, axis=0) if len(self.input_shape) == 4 else chw

        if not self._shape_matches(batched.shape):
            # Never substitute random data for the caller's image: that would
            # report a classification of noise as if it were the document.
            raise ValueError(
                f"Preprocessed shape {batched.shape} does not match model input "
                f"{tuple(self.input_shape)}; pass a matching target_size"
            )

        print(f" Preprocessed: {batched.shape}")
        return batched

    def predict(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run DocumentClassifier prediction and return the first model output.

        A warning is printed (inference is still attempted) when the tensor
        does not fit the model's static input dimensions.
        """
        if not self._shape_matches(input_tensor.shape):
            print(f"Warning: Input shape {input_tensor.shape} != expected {tuple(self.input_shape)}")

        outputs = self.session.run(None, {self.input_name: input_tensor})
        return outputs[0]

    def decode_output(self, logits: np.ndarray, top_k: int = 3) -> Dict:
        """
        Decode model output logits to document categories

        Logits with more than two dimensions are averaged over every axis
        after the first two; for batched output only the first sample is
        decoded. Extra logits beyond the known categories are dropped, and
        missing ones are treated as impossible (probability 0).

        Args:
            logits: Model output logits
            top_k: Number of top predictions to return (at least 1)

        Returns:
            Dictionary with classification results

        Raises:
            ValueError: If top_k is smaller than 1 or logits is empty.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")

        scores = np.asarray(logits, dtype=np.float64)
        if scores.size == 0:
            raise ValueError("logits must not be empty")

        if scores.ndim > 2:
            # Global average over all trailing (e.g. spatial) axes.
            scores = scores.mean(axis=tuple(range(2, scores.ndim)))
        if scores.ndim == 2:
            scores = scores[0]
        scores = scores.reshape(-1)

        num_categories = len(self.categories)
        if scores.size > num_categories:
            scores = scores[:num_categories]
        elif scores.size < num_categories:
            # -inf (not 0) so categories the model cannot emit never win.
            padded = np.full(num_categories, -np.inf)
            padded[:scores.size] = scores
            scores = padded

        probabilities = self._softmax(scores)
        ranked = np.argsort(probabilities)[::-1][:top_k]

        predictions = [
            {
                "rank": rank,
                "category": self.categories[idx],
                "confidence": float(probabilities[idx]),
                "index": int(idx),
            }
            for rank, idx in enumerate(ranked, start=1)
        ]

        return {
            "predicted_category": predictions[0]["category"],
            "confidence": predictions[0]["confidence"],
            "top_predictions": predictions,
            "all_probabilities": probabilities.tolist(),
        }

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Apply softmax to convert logits to probabilities"""
        # Subtracting the maximum keeps np.exp from overflowing.
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x)

    def classify(self, image: Union[str, np.ndarray], top_k: int = 3) -> Dict:
        """
        Classify document type from image

        Args:
            image: Image path or numpy array
            top_k: Number of top predictions to include in the result

        Returns:
            Dictionary with classification results plus a "processing_info"
            entry holding the input and output tensor shapes.
        """
        print("Processing document image...")
        # Resize to the model's own spatial size when it declares one.
        input_tensor = self.preprocess_image(image, target_size=self._model_target_size())

        print("Running classification...")
        logits = self.predict(input_tensor)

        print("Decoding results...")
        result = self.decode_output(logits, top_k=top_k)

        result["processing_info"] = {
            "input_shape": input_tensor.shape,
            "output_shape": logits.shape,
            "inference_successful": True,
        }
        return result

    def benchmark(self, num_iterations: int = 100) -> Dict[str, float]:
        """Benchmark model performance on a random input tensor.

        Runs five untimed warm-up inferences, then times num_iterations
        inferences with a monotonic high-resolution clock.

        Args:
            num_iterations: Number of timed inferences (at least 1)

        Returns:
            Timing statistics in milliseconds plus throughput in inferences
            per second.

        Raises:
            ValueError: If num_iterations is smaller than 1.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

        print(f"Running benchmark with {num_iterations} iterations...")

        dummy_input = self.create_dummy_input()

        # Warm-up runs are excluded from the statistics.
        for _ in range(5):
            self.predict(dummy_input)

        durations = []
        for i in range(num_iterations):
            start = time.perf_counter()
            self.predict(dummy_input)
            durations.append(time.perf_counter() - start)

            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{num_iterations}")

        samples = np.array(durations)
        mean_seconds = float(np.mean(samples))
        return {
            "mean_time_ms": mean_seconds * 1000,
            "std_time_ms": float(np.std(samples) * 1000),
            "min_time_ms": float(np.min(samples) * 1000),
            "max_time_ms": float(np.max(samples) * 1000),
            "median_time_ms": float(np.median(samples) * 1000),
            "throughput_fps": 1.0 / mean_seconds if mean_seconds > 0 else float("inf"),
            "total_iterations": num_iterations,
        }


def main(argv: Optional[List[str]] = None) -> int:
    """Run the DocumentClassifier command-line example.

    Depending on the flags this benchmarks the model, classifies an image,
    or (with neither flag) classifies a random dummy image as a demo.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        0 when every requested step succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="DocumentClassifier ONNX Example")
    parser.add_argument("--model", type=str, default="DocumentClassifier.onnx",
                        help="Path to DocumentClassifier ONNX model")
    parser.add_argument("--image", type=str,
                        help="Path to document image file")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")

    args = parser.parse_args(argv)

    if not os.path.exists(args.model):
        print(f"Error: Model file not found: {args.model}")
        print("Please ensure the ONNX model file is in the current directory.")
        return 1

    print("=" * 60)
    print("DocumentClassifier ONNX Example")
    print("=" * 60)

    # Top-level boundary of the script: report any load failure and stop.
    try:
        classifier = DocumentClassifierONNX(args.model)
    except Exception as e:
        print(f"Error loading model: {e}")
        return 1

    # Remember failures so the final summary and exit status stay truthful.
    failed = False

    if args.benchmark:
        print("\nRunning performance benchmark...")
        try:
            stats = classifier.benchmark(args.iterations)
        except Exception as e:
            print(f"Benchmark failed: {e}")
            failed = True
        else:
            print("\nBenchmark Results:")
            print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
            print(f" Median inference time: {stats['median_time_ms']:.2f} ms")
            print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
            print(f" Throughput: {stats['throughput_fps']:.1f} FPS")

    if args.image:
        if not os.path.exists(args.image):
            print(f"Error: Image file not found: {args.image}")
            return 1

        print(f"\nClassifying document: {args.image}")

        try:
            result = classifier.classify(args.image)
        except Exception as e:
            print(f"Error classifying document: {e}")
            traceback.print_exc()
            failed = True
        else:
            print("\nClassification completed:")
            print(f" Document type: {result['predicted_category']}")
            print(f" Confidence: {result['confidence']:.3f}")
            print("\nTop predictions:")
            for pred in result['top_predictions']:
                print(f" {pred['rank']}. {pred['category']}: {pred['confidence']:.3f}")

    if not args.image and not args.benchmark:
        print("\nRunning demo with dummy data...")

        try:
            # Random RGB noise standing in for a scanned page.
            dummy_image = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8)
            result = classifier.classify(dummy_image)
        except Exception as e:
            print(f"Demo failed: {e}")
            failed = True
        else:
            print("Demo completed:")
            print(f" Predicted type: {result['predicted_category']}")
            print(f" Confidence: {result['confidence']:.3f}")
            print(f" Processing info: {result['processing_info']}")
            print("\nNote: This was a demonstration with random data.")

    if failed:
        print("\nExample finished with errors.")
    else:
        print("\nExample completed successfully!")
    print("\nUsage examples:")
    print(" Classify document: python example.py --image document.jpg")
    print(" Run benchmark: python example.py --benchmark --iterations 50")
    print(" Both: python example.py --image document.jpg --benchmark")

    return 1 if failed else 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a None return
    # exits with status 0.
    raise SystemExit(main())