# latex-ocr / image_processing_latex_ocr.py
# Uploaded by harryrobert ("Upload folder using huggingface_hub"), commit 3372a56 (verified).
import torch
import numpy as np
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import logging
# Module-level logger from transformers' logging utilities, named after this module.
logger = logging.get_logger(name=__name__)
class LaTeXOCRImageProcessor(BaseImageProcessor):
    """Image processor for the LaTeX-OCR model.

    Processing steps:

    * Each image is converted to RGB.
    * It is resized to a fixed height (``image_height``), keeping its aspect
      ratio.
    * The resulting width is capped at ``max_image_width`` and floored to a
      multiple of ``patch_size``, but is never below one patch.
    * Pixels are scaled to ``[0, 1]`` and laid out channels-first (C, H, W).
    * Because widths differ between images, narrower images in a batch are
      right-padded with ``pad_value`` so the batch can be stacked into one
      tensor.

    Args:
        image_height: Target height in pixels for every image.
        max_image_width: Upper bound on the resized width in pixels.
        patch_size: Width granularity; the resized width is a multiple of it.
        pad_value: Fill value (in the ``[0, 1]`` scaled range) used to
            right-pad narrower images in a mixed-width batch.
        **kwargs: Forwarded to ``BaseImageProcessor``.
    """

    model_type = "latex_ocr"

    def __init__(
        self,
        image_height=64,
        max_image_width=1024,
        patch_size=16,
        pad_value=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_height = image_height
        self.max_image_width = max_image_width
        self.patch_size = patch_size
        # NOTE(review): the default 1.0 pads with white, on the assumption
        # that formula images have a light background. Confirm this against
        # how the model was trained.
        self.pad_value = pad_value

    def preprocess(self, images, **kwargs) -> BatchFeature:
        """Convert one image or a batch of images into model-ready pixel values.

        Args:
            images: A ``PIL.Image.Image`` or ``np.ndarray``, or a list/tuple
                of them. Arrays are converted with ``PIL.Image.fromarray``,
                so they must have a layout and dtype that function accepts.
            **kwargs: ``return_tensors`` selects the output tensor type.
                It defaults to ``"pt"`` when omitted or ``None``. Other keys
                are ignored.

        Returns:
            BatchFeature whose ``pixel_values`` has shape
            ``(batch, 3, image_height, widest_width_in_batch)``.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        arrays = [self._to_chw_array(img) for img in images]
        pixel_values = self._pad_to_common_width(arrays)
        # Keep the historical "pt" default, including when callers forward
        # an explicit return_tensors=None.
        tensor_type = kwargs.get("return_tensors") or "pt"
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=tensor_type)

    def _to_chw_array(self, img):
        """Resize one image to the target height; return a float32 CHW array in [0, 1]."""
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if img.mode != "RGB":
            img = img.convert("RGB")
        w, h = img.size
        # Keep the aspect ratio at the target height; max(h, 1) avoids
        # dividing by zero.
        new_w = int(round(w * self.image_height / max(h, 1)))
        new_w = min(new_w, self.max_image_width)
        # Floor to a whole number of patches, but keep at least one patch.
        new_w = max((new_w // self.patch_size) * self.patch_size, self.patch_size)
        if (w, h) != (new_w, self.image_height):
            img = img.resize((new_w, self.image_height), Image.BILINEAR)
        arr = np.array(img).astype(np.float32) / 255.0
        # HWC -> CHW
        return np.transpose(arr, (2, 0, 1))

    def _pad_to_common_width(self, arrays):
        """Right-pad CHW arrays with ``self.pad_value`` to the widest width and stack them."""
        if not arrays:
            # Nothing to stack; pass the empty list through unchanged.
            return arrays
        max_w = max(a.shape[-1] for a in arrays)
        padded = []
        for a in arrays:
            extra = max_w - a.shape[-1]
            if extra:
                a = np.pad(
                    a,
                    ((0, 0), (0, 0), (0, extra)),
                    mode="constant",
                    constant_values=self.pad_value,
                )
            padded.append(a)
        return np.stack(padded)