---
license: mit
base_model:
- google/efficientnet-b0
tags:
- image-classification
- document-analysis
- figure-classification
---

# EfficientNet-B0 Document Figure Classifier v2.5

This is an image classification model based on **Google EfficientNet-B0**, fine-tuned on a subset of the HuggingFace/finepdfs dataset to classify document figures into one of the following 26 categories:

1. **logo**
2. **photograph**
3. **icon**
4. **engineering_drawing**
5. **line_chart**
6. **bar_chart**
7. **other**
8. **table**
9. **flow_chart**
10. **screenshot_from_computer**
11. **signature**
12. **screenshot_from_manual**
13. **geographical_map**
14. **pie_chart**
15. **page_thumbnail**
16. **stamp**
17. **music**
18. **calendar**
19. **qr_code**
20. **bar_code**
21. **full_page_image**
22. **scatter_plot**
23. **chemistry_structure**
24. **topographical_map**
25. **crossword_puzzle**
26. **box_plot**
## Model Performance

**Note:** This model uses the same architecture and implementation as v2.0. The improved performance is achieved by training on a dataset 10 times larger than the one used for v2.0.

The model was evaluated on a held-out test set from the finepdfs dataset with the following metrics:

| Metric | v2.5 | v2.0 | Improvement |
|--------|------|------|-------------|
| **Accuracy** | 0.90703 | 0.87053 | +3.65% |
| **Balanced Accuracy** | 0.68836 | 0.60231 | +8.61% |
| **Macro F1** | 0.68942 | 0.60144 | +8.80% |
| **Weighted F1** | 0.90716 | 0.87270 | +3.45% |
| **Cohen's Kappa** | 0.87449 | 0.82563 | +4.89% |
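For reference, all of the aggregate metrics in the table can be derived from a confusion matrix over the predicted and ground-truth label ids. A minimal NumPy sketch with toy arrays (stand-ins for real evaluation outputs, using 3 classes instead of 26):

```python
import numpy as np

# Toy ground-truth and predicted label ids (stand-ins for real eval outputs).
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 0, 2, 2])
n = 3  # number of classes in this toy example (26 for the real model)

# Confusion matrix: rows = true class, columns = predicted class.
cm = np.zeros((n, n), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)

accuracy = cm.trace() / cm.sum()
recall = cm.diagonal() / cm.sum(axis=1)     # per-class recall
precision = cm.diagonal() / cm.sum(axis=0)  # per-class precision
balanced_accuracy = recall.mean()           # mean per-class recall

f1 = 2 * precision * recall / (precision + recall)
macro_f1 = f1.mean()                                  # unweighted class mean
weighted_f1 = (f1 * cm.sum(axis=1)).sum() / cm.sum()  # weighted by support

# Cohen's kappa: observed agreement corrected for chance agreement.
p_o = accuracy
p_e = (cm.sum(axis=1) * cm.sum(axis=0)).sum() / cm.sum() ** 2
kappa = (p_o - p_e) / (1 - p_e)

print(f"accuracy={accuracy:.3f} balanced={balanced_accuracy:.3f} "
      f"macro_f1={macro_f1:.3f} weighted_f1={weighted_f1:.3f} kappa={kappa:.3f}")
```

The gap between accuracy and balanced accuracy in the table reflects class imbalance: balanced accuracy weights every class equally, so rare classes such as `crossword_puzzle` pull it down even when overall accuracy is high.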

### Per-Label Performance

| Label | Precision (v2.5) | Recall (v2.5) | Precision (v2.0) | Recall (v2.0) |
|-------|------------------|---------------|------------------|---------------|
| **logo** | 0.92807 | 0.91816 | 0.88317 | 0.88728 |
| **photograph** | 0.90966 | 0.96029 | 0.88169 | 0.93359 |
| **icon** | 0.83605 | 0.82678 | 0.79281 | 0.72133 |
| **engineering_drawing** | 0.71689 | 0.81172 | 0.58795 | 0.71555 |
| **line_chart** | 0.73055 | 0.92117 | 0.75865 | 0.84576 |
| **bar_chart** | 0.88599 | 0.92720 | 0.72624 | 0.93883 |
| **other** | 0.41893 | 0.38213 | 0.28239 | 0.37312 |
| **table** | 0.98636 | 0.96765 | 0.97950 | 0.95250 |
| **flow_chart** | 0.75926 | 0.82425 | 0.61527 | 0.81518 |
| **screenshot_from_computer** | 0.85952 | 0.71980 | 0.80510 | 0.65844 |
| **signature** | 0.89020 | 0.85971 | 0.91852 | 0.80914 |
| **screenshot_from_manual** | 0.48559 | 0.34543 | 0.34748 | 0.20662 |
| **geographical_map** | 0.86780 | 0.85219 | 0.82959 | 0.80720 |
| **pie_chart** | 0.96880 | 0.94220 | 0.89903 | 0.93931 |
| **page_thumbnail** | 0.52008 | 0.35188 | 0.40194 | 0.21475 |
| **stamp** | 0.71269 | 0.41794 | 0.63492 | 0.26258 |
| **music** | 0.48037 | 0.57778 | 0.76955 | 0.51944 |
| **calendar** | 0.52880 | 0.28775 | 0.51176 | 0.24786 |
| **qr_code** | 0.95694 | 0.93240 | 0.97500 | 0.90909 |
| **bar_code** | 0.34244 | 0.84305 | 0.12087 | 0.82063 |
| **full_page_image** | 0.40323 | 0.65789 | 0.43750 | 0.28116 |
| **scatter_plot** | 0.66848 | 0.67213 | 0.60386 | 0.68306 |
| **chemistry_structure** | 0.72781 | 0.65426 | 0.77444 | 0.54787 |
| **topographical_map** | 0.83333 | 0.38462 | 0.68750 | 0.28205 |
| **crossword_puzzle** | 0.57143 | 0.21622 | 0.80000 | 0.21622 |
| **box_plot** | 0.85714 | 0.64286 | 1.00000 | 0.07143 |

## How to use - Transformers

Example of how to classify an image into one of the 26 classes using Transformers:

```python
import torch
import torchvision.transforms as transforms

from transformers import EfficientNetForImageClassification
from PIL import Image
import requests


urls = [
    'http://images.cocodataset.org/val2017/000000039769.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000001750.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000000001.jpg'
]

image_processor = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.47853944, 0.4732864, 0.47434163],
        ),
    ]
)

images = []
for url in urls:
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    image = image_processor(image)
    images.append(image)


model_id = 'docling-project/DocumentFigureClassifier-v2.5'

model = EfficientNetForImageClassification.from_pretrained(model_id)

labels = model.config.id2label

device = torch.device("cpu")

torch_images = torch.stack(images).to(device)

with torch.no_grad():
    logits = model(torch_images).logits  # (batch_size, num_classes)
    probs_batch = logits.softmax(dim=1)  # (batch_size, num_classes)
    probs_batch = probs_batch.cpu().numpy().tolist()

for idx, probs_image in enumerate(probs_batch):
    preds = [(labels[i], prob) for i, prob in enumerate(probs_image)]
    preds.sort(key=lambda t: t[1], reverse=True)
    print(f"{idx}: {preds}")
```
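If only the few most likely classes per image are needed rather than the full sorted distribution, `torch.topk` on the probabilities avoids sorting all 26 entries. A small sketch with dummy logits and a truncated stand-in for `model.config.id2label`:

```python
import torch

# Truncated stand-in for model.config.id2label (26 entries in the real model).
labels = {0: "logo", 1: "photograph", 2: "icon"}

# Dummy logits for a batch of 2 images over 3 classes (stand-in for model output).
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 0.1, 3.0]])

probs = logits.softmax(dim=1)
top_probs, top_ids = probs.topk(k=2, dim=1)  # top-2 classes per image

for img_idx in range(logits.size(0)):
    preds = [(labels[int(i)], round(float(p), 4))
             for i, p in zip(top_ids[img_idx], top_probs[img_idx])]
    print(f"{img_idx}: {preds}")
```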

## How to use - ONNX

Example of how to classify an image into one of the 26 classes using ONNX Runtime:

```python
import onnxruntime

import numpy as np
import torchvision.transforms as transforms

from PIL import Image
import requests

LABELS = [
    "logo",
    "photograph",
    "icon",
    "engineering_drawing",
    "line_chart",
    "bar_chart",
    "other",
    "table",
    "flow_chart",
    "screenshot_from_computer",
    "signature",
    "screenshot_from_manual",
    "geographical_map",
    "pie_chart",
    "page_thumbnail",
    "stamp",
    "music",
    "calendar",
    "qr_code",
    "bar_code",
    "full_page_image",
    "scatter_plot",
    "chemistry_structure",
    "topographical_map",
    "crossword_puzzle",
    "box_plot"
]


urls = [
    'http://images.cocodataset.org/val2017/000000039769.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000001750.jpg',
    'http://images.cocodataset.org/test-stuff2017/000000000001.jpg'
]

images = []
for url in urls:
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    images.append(image)


image_processor = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.47853944, 0.4732864, 0.47434163],
        ),
    ]
)


processed_images_onnx = [image_processor(image).unsqueeze(0) for image in images]

# ONNX Runtime needs numpy arrays as input
onnx_inputs = [item.numpy(force=True) for item in processed_images_onnx]

# pack into a batch
onnx_inputs = np.concatenate(onnx_inputs, axis=0)

ort_session = onnxruntime.InferenceSession(
    "./DocumentFigureClassifier-v2_5-onnx/model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)


for item in ort_session.run(None, {'input': onnx_inputs}):
    for x in item:
        pred = x.argmax()
        print(LABELS[pred])
```
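The ONNX Runtime example prints only the argmax label. To recover per-class probabilities from the raw logits that `ort_session.run` returns, a numerically stable softmax can be applied in NumPy; a minimal sketch with dummy logits standing in for the ONNX output:

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the per-row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Dummy logits for 2 images over 3 classes (stand-in for the ONNX output).
logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.1, 3.0]])

probs = softmax(logits)
print(probs.argmax(axis=1))  # highest-probability class id per image
```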

## Training Data

This model was trained on a subset of the HuggingFace/finepdfs dataset, a large-scale dataset for document understanding tasks.

## Citation

If you use this model in your work, please cite the following papers:

```
@article{Tan2019EfficientNetRM,
  title={EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks},
  author={Mingxing Tan and Quoc V. Le},
  journal={ArXiv},
  year={2019},
  volume={abs/1905.11946}
}

@techreport{Docling,
  author = {Deep Search Team},
  month = {8},
  title = {{Docling Technical Report}},
  url = {https://arxiv.org/abs/2408.09869},
  eprint = {2408.09869},
  doi = {10.48550/arXiv.2408.09869},
  version = {1.0.0},
  year = {2024}
}
```