| | import math |
| | from copy import deepcopy |
| | from io import BytesIO |
| | from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union |
| |
|
| | import numpy as np |
| | from transformers.image_utils import get_image_size, to_numpy_array |
| | from typing_extensions import override |
| |
|
| | from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER |
| | from ..extras.packages import is_pillow_available, is_pyav_available |
| |
|
| |
|
| | if is_pillow_available(): |
| | from PIL import Image |
| | from PIL.Image import Image as ImageObject |
| |
|
| |
|
| | if is_pyav_available(): |
| | import av |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | import torch |
| | from av.stream import Stream |
| | from transformers import PreTrainedTokenizer, ProcessorMixin |
| | from transformers.image_processing_utils import BaseImageProcessor |
| |
|
| | class EncodedImage(TypedDict): |
| | path: Optional[str] |
| | bytes: Optional[bytes] |
| |
|
| | ImageInput = Union[str, EncodedImage, ImageObject] |
| | VideoInput = str |
| |
|
| |
|
| | def _get_paligemma_token_type_ids( |
| | imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" |
| | ) -> List[List[int]]: |
| | r""" |
| | Gets paligemma token type ids for computing loss. |
| | |
| | Returns: |
| | batch_token_type_ids: shape (batch_size, sequence_length) |
| | """ |
| | batch_token_type_ids = [] |
| | for imglen, seqlen in zip(imglens, seqlens): |
| | image_seqlen = imglen * getattr(processor, "image_seqlen") |
| | batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) |
| |
|
| | return batch_token_type_ids |
| |
|
| |
|
| | class BasePlugin: |
| | def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None: |
| | self.image_token = image_token |
| | self.video_token = video_token |
| |
|
| | def _validate_input( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | ) -> None: |
| | r""" |
| | Validates if this model accepts the input modalities. |
| | """ |
| | if len(images) != 0 and self.image_token is None: |
| | raise ValueError("This model does not support image input.") |
| |
|
| | if len(videos) != 0 and self.video_token is None: |
| | raise ValueError("This model does not support video input.") |
| |
|
| | def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": |
| | r""" |
| | Pre-processes a single image. |
| | """ |
| | image_resolution: int = kwargs.get("image_resolution") |
| | if max(image.width, image.height) > image_resolution: |
| | resize_factor = image_resolution / max(image.width, image.height) |
| | width, height = int(image.width * resize_factor), int(image.height * resize_factor) |
| | image = image.resize((width, height), resample=Image.NEAREST) |
| |
|
| | if image.mode != "RGB": |
| | image = image.convert("RGB") |
| |
|
| | return image |
| |
|
| | def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int: |
| | r""" |
| | Computes video sample frames according to fps. |
| | """ |
| | video_fps: float = kwargs.get("video_fps") |
| | video_maxlen: int = kwargs.get("video_maxlen") |
| | total_frames = video_stream.frames |
| | sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps |
| | sample_frames = min(total_frames, video_maxlen, sample_frames) |
| | return math.floor(sample_frames) |
| |
|
| | def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]: |
| | r""" |
| | Regularizes images to avoid error. Including reading and pre-processing. |
| | """ |
| | results = [] |
| | for image in images: |
| | if isinstance(image, str): |
| | image = Image.open(image) |
| | elif isinstance(image, dict): |
| | if image["bytes"] is not None: |
| | image = Image.open(BytesIO(image["bytes"])) |
| | else: |
| | image = Image.open(image["path"]) |
| |
|
| | if not isinstance(image, ImageObject): |
| | raise ValueError("Expect input is a list of Images, but got {}.".format(type(image))) |
| |
|
| | results.append(self._preprocess_image(image, **kwargs)) |
| |
|
| | return results |
| |
|
| | def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]: |
| | r""" |
| | Regularizes videos to avoid error. Including reading, resizing and converting. |
| | """ |
| | results = [] |
| | for video in videos: |
| | container = av.open(video, "r") |
| | video_stream = next(stream for stream in container.streams if stream.type == "video") |
| | total_frames = video_stream.frames |
| | sample_frames = self._get_video_sample_frames(video_stream, **kwargs) |
| | sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) |
| | frames: List["ImageObject"] = [] |
| | container.seek(0) |
| | for frame_idx, frame in enumerate(container.decode(video_stream)): |
| | if frame_idx in sample_indices: |
| | frames.append(frame.to_image()) |
| |
|
| | frames = self._regularize_images(frames, **kwargs) |
| | results.append(frames) |
| |
|
| | return results |
| |
|
| | def _get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: "ProcessorMixin", |
| | ) -> Dict[str, "torch.Tensor"]: |
| | r""" |
| | Processes visual inputs. |
| | |
| | Returns: (llava and paligemma) |
| | pixel_values: tensor with shape (B, C, H, W) |
| | |
| | Returns: (qwen2-vl) |
| | pixel_values: tensor with shape (num_patches, patch_dim) |
| | image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height |
| | |
| | It holds num_patches == torch.prod(image_grid_thw) |
| | """ |
| | image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") |
| | video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor) |
| | input_dict = {"images": None} |
| | if len(images) != 0: |
| | images = self._regularize_images( |
| | images, |
| | image_resolution=getattr(processor, "image_resolution", 512), |
| | ) |
| | input_dict["images"] = images |
| |
|
| | if len(videos) != 0: |
| | videos = self._regularize_videos( |
| | videos, |
| | image_resolution=getattr(processor, "video_resolution", 128), |
| | video_fps=getattr(processor, "video_fps", 1.0), |
| | video_maxlen=getattr(processor, "video_maxlen", 64), |
| | ) |
| | input_dict["videos"] = videos |
| |
|
| | mm_inputs = {} |
| | if image_processor != video_processor: |
| | if input_dict.get("images") is not None: |
| | mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt")) |
| | if input_dict.get("videos") is not None: |
| | mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt")) |
| | elif input_dict.get("images") is not None or input_dict.get("videos") is not None: |
| | mm_inputs.update(image_processor(**input_dict, return_tensors="pt")) |
| |
|
| | return mm_inputs |
| |
|
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | r""" |
| | Pre-processes input messages before tokenization for VLMs. |
| | """ |
| | self._validate_input(images, videos) |
| | return messages |
| |
|
| | def process_token_ids( |
| | self, |
| | input_ids: List[int], |
| | labels: Optional[List[int]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | tokenizer: "PreTrainedTokenizer", |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Tuple[List[int], Optional[List[int]]]: |
| | r""" |
| | Pre-processes token ids after tokenization for VLMs. |
| | """ |
| | self._validate_input(images, videos) |
| | return input_ids, labels |
| |
|
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | r""" |
| | Builds batched multimodal inputs for VLMs. |
| | """ |
| | self._validate_input(images, videos) |
| | return {} |
| |
|
| |
|
| | class LlavaPlugin(BasePlugin): |
| | @override |
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | self._validate_input(images, videos) |
| | num_image_tokens = 0 |
| | image_seqlen = getattr(processor, "image_seqlen") |
| | messages = deepcopy(messages) |
| | for message in messages: |
| | content = message["content"] |
| | while IMAGE_PLACEHOLDER in content: |
| | num_image_tokens += 1 |
| | content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) |
| |
|
| | message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) |
| |
|
| | if len(images) != num_image_tokens: |
| | raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) |
| |
|
| | return messages |
| |
|
| | @override |
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | self._validate_input(images, videos) |
| | return self._get_mm_inputs(images, videos, processor) |
| |
|
| |
|
| | class LlavaNextPlugin(BasePlugin): |
| | @override |
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | self._validate_input(images, videos) |
| | num_image_tokens = 0 |
| | messages = deepcopy(messages) |
| | mm_inputs = self._get_mm_inputs(images, videos, processor) |
| | if "image_sizes" in mm_inputs: |
| | image_sizes = iter(mm_inputs["image_sizes"]) |
| | if "pixel_values" in mm_inputs: |
| | height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) |
| | for message in messages: |
| | content = message["content"] |
| | while self.image_token in content: |
| | image_size = next(image_sizes) |
| | orig_height, orig_width = image_size |
| | image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) |
| | if processor.vision_feature_select_strategy == "default": |
| | image_seqlen -= 1 |
| | num_image_tokens += 1 |
| | content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1) |
| |
|
| | message["content"] = content.replace("{{image}}", self.image_token) |
| |
|
| | if len(images) != num_image_tokens: |
| | raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) |
| | return messages |
| |
|
| | @override |
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | self._validate_input(images, videos) |
| | res = self._get_mm_inputs(images, videos, processor) |
| | return res |
| |
|
| |
|
| | class LlavaNextVideoPlugin(BasePlugin): |
| | @override |
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | self._validate_input(images, videos) |
| | num_image_tokens = 0 |
| | num_video_tokens = 0 |
| | messages = deepcopy(messages) |
| | mm_inputs = self._get_mm_inputs(images, videos, processor) |
| | if "pixel_values" in mm_inputs: |
| | image_sizes = iter(mm_inputs["image_sizes"]) |
| | height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) |
| | for message in messages: |
| | content = message["content"] |
| |
|
| | while self.image_token in content: |
| | image_size = next(image_sizes) |
| | orig_height, orig_width = image_size |
| | image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) |
| | if processor.vision_feature_select_strategy == "default": |
| | image_seqlen -= 1 |
| | num_image_tokens += 1 |
| | content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1) |
| |
|
| | message["content"] = content.replace("{{image}}", self.image_token) |
| |
|
| | if "pixel_values_videos" in mm_inputs: |
| | pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) |
| | height, width = get_image_size(pixel_values_video[0]) |
| | num_frames = pixel_values_video.shape[0] |
| | image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) |
| | video_seqlen = image_seqlen // 4 * num_frames |
| |
|
| | for message in messages: |
| | content = message["content"] |
| | while self.video_token in content: |
| | num_video_tokens += 1 |
| | content = content.replace(self.video_token, "{{video}}", 1) |
| | message["content"] = content.replace("{{video}}", self.video_token * video_seqlen) |
| |
|
| | if len(images) != num_image_tokens: |
| | raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) |
| |
|
| | if len(videos) != num_video_tokens: |
| | raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) |
| |
|
| | return messages |
| |
|
| | @override |
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | self._validate_input(images, videos) |
| | return self._get_mm_inputs(images, videos, processor) |
| |
|
| |
|
| | class PaliGemmaPlugin(BasePlugin): |
| | @override |
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | self._validate_input(images, videos) |
| | num_image_tokens = 0 |
| | messages = deepcopy(messages) |
| | for message in messages: |
| | content = message["content"] |
| | while IMAGE_PLACEHOLDER in content: |
| | num_image_tokens += 1 |
| | content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) |
| |
|
| | message["content"] = content.replace("{{image}}", "") |
| |
|
| | if len(images) != num_image_tokens: |
| | raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) |
| |
|
| | return messages |
| |
|
| | @override |
| | def process_token_ids( |
| | self, |
| | input_ids: List[int], |
| | labels: Optional[List[int]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | tokenizer: "PreTrainedTokenizer", |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Tuple[List[int], Optional[List[int]]]: |
| | self._validate_input(images, videos) |
| | num_images = len(images) |
| | image_seqlen = num_images * getattr(processor, "image_seqlen") |
| | image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) |
| | input_ids = [image_token_id] * image_seqlen + input_ids |
| | if labels is not None: |
| | labels = [IGNORE_INDEX] * image_seqlen + labels |
| |
|
| | return input_ids, labels |
| |
|
| | @override |
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | self._validate_input(images, videos) |
| | mm_inputs = self._get_mm_inputs(images, videos, processor) |
| | mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor) |
| | return mm_inputs |
| |
|
| |
|
| | class Qwen2vlPlugin(BasePlugin): |
| | @override |
| | def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": |
| | image = super()._preprocess_image(image, **kwargs) |
| | if min(image.width, image.height) < 28: |
| | width, height = max(image.width, 28), max(image.height, 28) |
| | image = image.resize((width, height), resample=Image.NEAREST) |
| |
|
| | if image.width / image.height > 200: |
| | width, height = image.height * 180, image.height |
| | image = image.resize((width, height), resample=Image.NEAREST) |
| |
|
| | if image.height / image.width > 200: |
| | width, height = image.width, image.width * 180 |
| | image = image.resize((width, height), resample=Image.NEAREST) |
| |
|
| | return image |
| |
|
| | @override |
| | def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int: |
| | sample_frames = super()._get_video_sample_frames(video_stream, **kwargs) |
| | sample_frames = sample_frames // 2 * 2 |
| | return sample_frames |
| |
|
| | @override |
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | self._validate_input(images, videos) |
| | image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") |
| | merge_length: int = getattr(image_processor, "merge_size") ** 2 |
| | mm_inputs = self._get_mm_inputs(images, videos, processor) |
| | image_grid_thw = mm_inputs.get("image_grid_thw", []) |
| | video_grid_thw = mm_inputs.get("video_grid_thw", []) |
| |
|
| | num_image_tokens, num_video_tokens = 0, 0 |
| | messages = deepcopy(messages) |
| | for message in messages: |
| | content = message["content"] |
| | while IMAGE_PLACEHOLDER in content: |
| | if num_image_tokens >= len(image_grid_thw): |
| | raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER)) |
| |
|
| | content = content.replace( |
| | IMAGE_PLACEHOLDER, |
| | "<|vision_start|>{}<|vision_end|>".format( |
| | self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length) |
| | ), |
| | 1, |
| | ) |
| | num_image_tokens += 1 |
| |
|
| | while VIDEO_PLACEHOLDER in content: |
| | if num_video_tokens >= len(video_grid_thw): |
| | raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER)) |
| |
|
| | content = content.replace( |
| | VIDEO_PLACEHOLDER, |
| | "<|vision_start|>{}<|vision_end|>".format( |
| | self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length) |
| | ), |
| | 1, |
| | ) |
| | num_video_tokens += 1 |
| |
|
| | message["content"] = content |
| |
|
| | if len(images) != num_image_tokens: |
| | raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) |
| |
|
| | if len(videos) != num_video_tokens: |
| | raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER)) |
| |
|
| | return messages |
| |
|
| | @override |
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | self._validate_input(images, videos) |
| | return self._get_mm_inputs(images, videos, processor) |
| |
|
| |
|
| | class VideoLlavaPlugin(BasePlugin): |
| | @override |
| | def process_messages( |
| | self, |
| | messages: Sequence[Dict[str, str]], |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> List[Dict[str, str]]: |
| | self._validate_input(images, videos) |
| | num_image_tokens = 0 |
| | num_video_tokens = 0 |
| | messages = deepcopy(messages) |
| | mm_inputs = self._get_mm_inputs(images, videos, processor) |
| | num_frames = 0 |
| | exist_images = "pixel_values_images" in mm_inputs |
| | exist_videos = "pixel_values_videos" in mm_inputs |
| | if exist_videos or exist_images: |
| | if exist_images: |
| | height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0])) |
| | num_frames = 1 |
| | if exist_videos: |
| | pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) |
| | height, width = get_image_size(pixel_values_video[0]) |
| | num_frames = pixel_values_video.shape[0] |
| | image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1 |
| | video_seqlen = image_seqlen * num_frames |
| | if processor.vision_feature_select_strategy == "default": |
| | image_seqlen -= 1 |
| | for message in messages: |
| | content = message["content"] |
| | while self.image_token in content: |
| | num_image_tokens += 1 |
| | content = content.replace(self.image_token, "{{image}}", 1) |
| | while self.video_token in content: |
| | num_video_tokens += 1 |
| | content = content.replace(self.video_token, "{{video}}", 1) |
| |
|
| | content = content.replace("{{image}}", self.image_token * image_seqlen) |
| | message["content"] = content.replace("{{video}}", self.video_token * video_seqlen) |
| |
|
| | if len(images) != num_image_tokens: |
| | raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token)) |
| |
|
| | if len(videos) != num_video_tokens: |
| | raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token)) |
| |
|
| | return messages |
| |
|
| | @override |
| | def get_mm_inputs( |
| | self, |
| | images: Sequence["ImageInput"], |
| | videos: Sequence["VideoInput"], |
| | imglens: Sequence[int], |
| | vidlens: Sequence[int], |
| | seqlens: Sequence[int], |
| | processor: Optional["ProcessorMixin"], |
| | ) -> Dict[str, Union[List[int], "torch.Tensor"]]: |
| | self._validate_input(images, videos) |
| | return self._get_mm_inputs(images, videos, processor) |
| |
|
| |
|
| | PLUGINS = { |
| | "base": BasePlugin, |
| | "llava": LlavaPlugin, |
| | "llava_next": LlavaNextPlugin, |
| | "llava_next_video": LlavaNextVideoPlugin, |
| | "paligemma": PaliGemmaPlugin, |
| | "qwen2_vl": Qwen2vlPlugin, |
| | "video_llava": VideoLlavaPlugin, |
| | } |
| |
|
| |
|
| | def get_mm_plugin( |
| | name: str, |
| | image_token: Optional[str] = None, |
| | video_token: Optional[str] = None, |
| | ) -> "BasePlugin": |
| | plugin_class = PLUGINS.get(name, None) |
| | if plugin_class is None: |
| | raise ValueError("Multimodal plugin `{}` not found.".format(name)) |
| |
|
| | return plugin_class(image_token, video_token) |
| |
|