# MiniMax-M3 / image_processor.py
# xuebi
# Initial commit
# 3a41b31
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
"""
MiniMax VL family HuggingFace-compatible Processor, ImageProcessor, VideoProcessor.
"""
import math
from typing import List, Tuple
import torch
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import PILImageResampling, SizeDict
from transformers.processing_utils import (
ImagesKwargs,
Unpack,
)
from transformers.utils import TensorType
# Largest allowed long-side / short-side ratio; smart_resize raises ValueError
# for images whose aspect ratio exceeds this value.
MAX_RATIO = 200
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` that is nearest to ``number``.

    The number of multiples is chosen with the built-in ``round``, so an exact
    half (e.g. ``42 / 28 == 1.5``) follows Python's round-half-to-even rule.
    """
    multiple_count = round(number / factor)
    return multiple_count * factor
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    multiple_count = math.ceil(number / factor)
    return multiple_count * factor
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    multiple_count = math.floor(number / factor)
    return multiple_count * factor
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 4 * 28 * 28,
    max_pixels: int = 451584,
) -> tuple[int, int]:
    """Pick a resize target whose sides are positive multiples of ``factor``.

    Both sides are first rounded to the nearest multiple of ``factor``. If the
    resulting area exceeds ``max_pixels`` the image is scaled down (sides
    floored to a multiple of ``factor``); if it is below ``min_pixels`` it is
    scaled up (sides ceiled to a multiple of ``factor``). The scale is uniform,
    so the aspect ratio is kept as closely as the rounding allows.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        factor: Both returned sides are multiples of this value.
        min_pixels: Lower bound on the target area.
        max_pixels: Upper bound on the target area.

    Returns:
        ``(resized_height, resized_width)``. Each side is at least ``factor``;
        for an elongated image combined with a small ``max_pixels`` this lower
        bound takes priority, so the area may end up above ``max_pixels``.

    Raises:
        ValueError: If the long side / short side ratio exceeds ``MAX_RATIO``.
    """
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, "
            f"got {aspect_ratio}"
        )

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring can collapse the short side to 0 (e.g. a 100x1000 image with
        # max_pixels=3136), which is not a valid resize target. Clamp to one
        # factor, mirroring the clamp applied to the rounded sides above.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    return h_bar, w_bar
# ==============================================================================
# MiniMax M3 VL Image Processor Fast (Fast Mode - Torch based)
# ==============================================================================
class MiniMaxM3VLImageProcessorKwargs(ImagesKwargs, total=False):
    """Optional per-call overrides accepted by the MiniMax M3 VL image processor.

    All keys are optional (``total=False``) and extend the keys inherited from
    ``ImagesKwargs``.
    """

    # Side length, in pixels, of one square spatial patch.
    patch_size: int
    # Number of frames folded together into one temporal patch.
    temporal_patch_size: int
    # Patches are grouped in merge_size x merge_size windows; resize targets
    # are multiples of patch_size * merge_size.
    merge_size: int
    # Upper bound on the resized image area that is handed to smart_resize.
    max_pixels: int
class MiniMaxM3VLImageProcessor(BaseImageProcessorFast):
    """Torch-based image processor producing flattened vision patches.

    Images are optionally resized with ``smart_resize`` (sides become multiples
    of ``patch_size * merge_size``), rescaled and normalized, then cut into
    patches. The output holds ``pixel_values`` (all patches of all images
    concatenated along dim 0) and ``image_grid_thw`` (one
    ``[grid_t, grid_h, grid_w]`` row per image).
    """

    do_resize = True
    resample = PILImageResampling.BICUBIC
    # Required by base class validation, not used as resize bound.
    size = {"height": 672, "width": 672}
    default_to_square = False
    do_rescale = True
    rescale_factor = 1 / 255
    do_normalize = True
    image_mean = [0.48145466, 0.4578275, 0.40821073]
    image_std = [0.26862954, 0.26130258, 0.27577711]
    do_convert_rgb = True
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    max_pixels = 451584  # 672 * 672
    valid_kwargs = MiniMaxM3VLImageProcessorKwargs
    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(self, **kwargs: Unpack[MiniMaxM3VLImageProcessorKwargs]):
        """Forward all configuration overrides to the base class."""
        super().__init__(**kwargs)

    def preprocess(
        self, images, **kwargs: Unpack[MiniMaxM3VLImageProcessorKwargs]
    ) -> BatchFeature:
        """Preprocess ``images``; delegates to the base class implementation."""
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: List[torch.Tensor],
        do_resize: bool,
        size: SizeDict,
        resample: PILImageResampling | InterpolationMode | int | None,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: float | List[float] | None,
        image_std: float | List[float] | None,
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        max_pixels: int,
        disable_grouping: bool | None,
        return_tensors: str | TensorType | None,
        **kwargs,
    ) -> BatchFeature:
        """Resize, normalize and patchify a list of image tensors.

        Args:
            images: Image tensors whose last two dims are (height, width).
            do_resize: Whether to resize each image to the ``smart_resize``
                target before patchifying.
            size: Accepted for base-class compatibility; not used here.
            resample: Interpolation passed to ``self.resize``.
            do_rescale, rescale_factor, do_normalize, image_mean, image_std:
                Forwarded to ``self.rescale_and_normalize``.
            patch_size: Side length of one spatial patch in pixels.
            temporal_patch_size: Frames per temporal patch; the last frame is
                repeated until the frame count is a multiple of this value.
            merge_size: Patches are ordered in merge_size x merge_size windows.
            max_pixels: Upper area bound handed to ``smart_resize``.
            disable_grouping: Forwarded to ``group_images_by_shape``.
            return_tensors: Tensor type of the returned ``BatchFeature``.

        Returns:
            ``BatchFeature`` with ``pixel_values`` of shape
            ``(total_patches, channel * temporal_patch_size * patch_size**2)``
            and ``image_grid_thw`` of shape ``(num_images, 3)``.
        """
        # Pass 1: resize. Images of identical shape are processed as one batch.
        grouped_images, grouped_images_index = group_images_by_shape(
            images, disable_grouping=disable_grouping
        )
        resized_images_grouped = {}
        factor = patch_size * merge_size
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=factor,
                    max_pixels=max_pixels,
                )
                stacked_images = self.resize(
                    stacked_images,
                    size=SizeDict(height=resized_height, width=resized_width),
                    resample=resample,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Pass 2: regroup by the (possibly new) shapes, normalize and patchify.
        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            patches = self.rescale_and_normalize(
                stacked_images,
                do_rescale,
                rescale_factor,
                do_normalize,
                image_mean,
                image_std,
            )
            if patches.ndim == 4:
                # Insert a frame axis so still images share the 5-D layout.
                patches = patches.unsqueeze(1)
            if patches.shape[1] % temporal_patch_size != 0:
                # Pad the frame axis by repeating the last frame.
                repeats = patches[:, -1:].repeat(
                    1,
                    temporal_patch_size - (patches.shape[1] % temporal_patch_size),
                    1,
                    1,
                    1,
                )
                patches = torch.cat([patches, repeats], dim=1)

            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

            # Split frames into temporal patches and H/W into
            # (merge window index, position inside window, pixel inside patch).
            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            # -> (batch, grid_t, h_window, w_window, merge_h, merge_w,
            #     channel, temporal, patch_h, patch_w): patches belonging to
            # the same merge window become adjacent in the flattened sequence.
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )
            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_images = reorder_images(
            processed_images_grouped, grouped_images_index
        )
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids, dtype=torch.long)
        return BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw},
            tensor_type=return_tensors,
        )

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """Return how many spatial patches an image of the given size yields.

        The result is ``grid_h * grid_w`` as computed by ``_preprocess``: the
        image size is first passed through ``smart_resize`` when resizing is
        enabled, and used as-is when ``do_resize`` is False, so the count
        matches the patches that preprocessing actually produces.

        Args:
            height: Original image height in pixels.
            width: Original image width in pixels.
            images_kwargs: Optional overrides for ``patch_size``,
                ``merge_size``, ``max_pixels`` and ``do_resize``. Missing or
                ``None`` entries fall back to the processor's own settings.

        Returns:
            Number of patches (before any merge_size grouping).
        """
        images_kwargs = images_kwargs or {}

        def setting(name):
            # An explicit None means "not provided": use the processor default.
            value = images_kwargs.get(name)
            return getattr(self, name) if value is None else value

        patch_size = setting("patch_size")
        merge_size = setting("merge_size")
        if setting("do_resize"):
            height, width = smart_resize(
                height,
                width,
                factor=patch_size * merge_size,
                max_pixels=setting("max_pixels"),
            )
        grid_h, grid_w = height // patch_size, width // patch_size
        return grid_h * grid_w