| | import ast |
| | import contextlib |
| | import gc |
| | import json |
| | import os |
| | from dataclasses import dataclass |
| | from functools import partial |
| | from itertools import chain |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn as nn |
| | from einops import rearrange |
| | from timm.layers import LayerNorm, LayerNorm2d |
| | from timm.models.regnet import RegStage |
| | from torch.nn import CrossEntropyLoss |
| | from transformers import ( |
| | AutoConfig, |
| | AutoModel, |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | PreTrainedModel, |
| | ) |
| | from transformers.generation.utils import GenerationMixin |
| | from transformers.modeling_utils import ( |
| | is_fsdp_enabled, |
| | is_local_dist_rank_0, |
| | no_init_weights, |
| | ) |
| | from transformers.models.auto import CONFIG_MAPPING |
| | from transformers.utils import ModelOutput |
| |
|
| | from .configuration_hyperclovax import HCXVisionConfig |
| | from .image_processing_hyperclovax import select_best_resolution |
| |
|
| | EOT = "<|endofturn|>" |
| | IMAGE_LOC = "<|dummy3|>" |
| | VIDEO_LOC = "<|_unuse_missing_100270|>" |
| |
|
| |
|
| | def get_rank(): |
| | if dist.is_initialized(): |
| | return dist.get_rank() |
| | return 0 |
| |
|
| |
|
| | def get_world_size(): |
| | if dist.is_initialized(): |
| | return dist.get_world_size() |
| | return 1 |
| |
|
| |
|
| | def unpad_image(tensor: torch.Tensor, original_size: Tuple[int, int]) -> torch.Tensor: |
| | """Unpads a PyTorch tensor of a padded and resized image. |
| | |
| | This function removes padding from a tensor image that was previously padded and resized. |
| | The padding is removed based on the aspect ratio difference between the original and current image dimensions. |
| | |
| | Args: |
| | tensor: The image tensor, assumed to be in CxHxW format. |
| | original_size: The original size of the image as (width, height). |
| | |
| | Returns: |
| | The unpadded image tensor. |
| | |
| | Examples: |
| | >>> import torch |
| | >>> # Example 1: Unpadding with height padding |
| | >>> padded_tensor = torch.randn(1, 64, 48) # Padded tensor (C=1, H=64, W=48) |
| | >>> original_size = (32, 32) # Original size (width=32, height=32) |
| | >>> unpadded_tensor = unpad_image(padded_tensor, original_size) |
| | >>> unpadded_tensor.shape |
| | torch.Size([1, 48, 48]) |
| | >>> # Example 2: Unpadding with width padding |
| | >>> padded_tensor = torch.randn(1, 48, 64) # Padded tensor (C=1, H=48, W=64) |
| | >>> original_size = (32, 32) # Original size (width=32, height=32) |
| | >>> unpadded_tensor = unpad_image(padded_tensor, original_size) |
| | >>> unpadded_tensor.shape |
| | torch.Size([1, 48, 48]) |
| | """ |
| | original_width, original_height = original_size |
| | current_height, current_width = tensor.shape[1:] |
| |
|
| | original_aspect_ratio = original_width / original_height |
| | current_aspect_ratio = current_width / current_height |
| |
|
| | if original_aspect_ratio > current_aspect_ratio: |
| | scale_factor = current_width / original_width |
| | new_height = int(original_height * scale_factor) |
| | padding = (current_height - new_height) // 2 |
| | unpadded_tensor = tensor[:, padding : current_height - padding, :] |
| | else: |
| | scale_factor = current_height / original_height |
| | new_width = int(original_width * scale_factor) |
| | padding = (current_width - new_width) // 2 |
| | unpadded_tensor = tensor[:, :, padding : current_width - padding] |
| |
|
| | return unpadded_tensor |
| |
|
| |
|
| | def get_anyres_image_grid_shape( |
| | image_size: Tuple[int, int], |
| | grid_pinpoints: Union[str, List[Tuple[int, int]]], |
| | patch_size: int, |
| | ) -> Tuple[int, int]: |
| | """Calculates the image patch grid shape after any-resolution preprocessing. |
| | |
| | Selects the optimal resolution from predefined grid pinpoints based on input image |
| | dimensions using `select_best_resolution`, then computes the grid layout by |
| | dividing the selected resolution by the patch size using integer division. |
| | |
| | Args: |
| | image_size (Tuple[int, int]): Original image dimensions in (width, height) format. |
| | grid_pinpoints (Union[str, List[Tuple[int, int]]]): Accepts either: |
| | - List of (height, width) resolution tuples |
| | - String representation of list (e.g., "[(224, 224), (336, 336)]") |
| | patch_size (int): Spatial dimension of square patches for grid division. |
| | |
| | Returns: |
| | Tuple[int, int]: Grid dimensions as (num_patches_width, num_patches_height). |
| | |
| | Examples: |
| | >>> # Basic case with list input |
| | >>> get_anyres_image_grid_shape((1000, 800), [(224, 224), (448, 448)], 112) |
| | (4, 4) |
| | |
| | >>> # Basic case with string input |
| | >>> get_anyres_image_grid_shape((600, 400), "[(336, 336), (672, 672)]", 112) |
| | (6, 6) |
| | |
| | >>> # Case where resolution is not perfectly divisible by patch_size |
| | >>> # select_best_resolution picks (224, 224). 224 // 100 = 2 |
| | >>> get_anyres_image_grid_shape((500, 500), [(224, 224)], 100) |
| | (2, 2) |
| | |
| | >>> # Different patch size |
| | >>> # select_best_resolution picks (448, 448). 448 // 224 = 2 |
| | >>> get_anyres_image_grid_shape((1200, 900), [(448, 448), (224, 224)], 224) |
| | (2, 2) |
| | |
| | Note: |
| | String-formatted grid_pinpoints are converted via ast.literal_eval, so invalid formats |
| | may raise syntax exceptions. The actual resolution selected depends on the |
| | implementation of `select_best_resolution`; the doctests above assume it returns the |
| | pinpoint that best matches the input image (the larger pinpoint in the first two examples). |
| | """ |
| | possible_resolutions = grid_pinpoints if isinstance(grid_pinpoints, list) else ast.literal_eval(grid_pinpoints) |
| |
|
| | original_width, original_height = image_size |
| | height, width = select_best_resolution((original_height, original_width), possible_resolutions) |
| | return width // patch_size, height // patch_size |
| |
|
| |
|
| | def reshape_and_unpad_image_features( |
| | image_feature: torch.Tensor, |
| | height: int, |
| | width: int, |
| | image_size: Tuple[int, int], |
| | possible_resolutions: List[Tuple[int, int]], |
| | grid_size: int, |
| | unpad: bool, |
| | image_newline: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """Reshapes and processes image features with optional unpadding operation. |
| | |
| | Processes input image features by: |
| | 1. Separating base features from spatial features |
| | 2. Reshaping spatial features into a 5D tensor (num_patch_height, num_patch_width, height, width, channels) |
| | 3. Performing either unpadding operation or simple reshaping based on 'unpad' flag |
| | 4. Concatenating processed features with base features |
| | |
| | Args: |
| | image_feature: Input tensor containing image features with shape |
| | [1 + num_patches, feature_dim] where the first element is the base feature |
| | height: Original image height in pixels |
| | width: Original image width in pixels |
| | image_size: Target image size as (width, height) tuple |
| | possible_resolutions: List of possible [height, width] resolutions for multi-scale processing |
| | grid_size: Grid dimension for patch arrangement |
| | unpad: Flag to enable unpadding operation |
| | image_newline: Special token tensor used as separator when unpadding |
| | |
| | Returns: |
| | torch.Tensor: Processed image features tensor with shape [1 + num_processed_patches, feature_dim] |
| | |
| | Raises: |
| | AssertionError: If base feature dimension doesn't match height*width |
| | """ |
| | base_image_feature = image_feature[0] |
| | image_feature = image_feature[1:] |
| |
|
| | assert ( |
| | height * width == base_image_feature.shape[0] |
| | ), f"height: {height}, width: {width}, base_image_feature.shape[0]: {base_image_feature.shape[0]}" |
| |
|
| | num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_size, possible_resolutions, grid_size) |
| | image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) |
| |
|
| | if unpad: |
| | image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() |
| | image_feature = image_feature.flatten(1, 2).flatten(2, 3) |
| | image_feature = unpad_image(image_feature, image_size) |
| | image_feature = torch.cat( |
| | ( |
| | image_feature, |
| | image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device), |
| | ), |
| | dim=-1, |
| | ) |
| | image_feature = image_feature.flatten(1, 2).transpose(0, 1) |
| | else: |
| | image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() |
| | image_feature = image_feature.flatten(0, 3) |
| | image_feature = torch.cat((base_image_feature, image_feature), dim=0) |
| |
|
| | return image_feature |
| |
|
| |
|
| | def anyres_postprocessing( |
| | image_forward_outs: List[torch.FloatTensor], |
| | image_sizes: List[List[int]], |
| | possible_resolutions: List[Tuple[int, int]], |
| | patch_size: int, |
| | grid_size: int, |
| | image_newline: torch.FloatTensor, |
| | num_queries_vis_abstractor: int = -1, |
| | unpad: bool = False, |
| | ) -> List[torch.FloatTensor]: |
| | """Processes 2D visual features into 1D sequences with post-processing steps. |
| | |
| | Performs AnyRes postprocessing by flattening 2D visual features from grid partitions into 1D sequences, adding |
| | newline embeddings (at row boundaries when unpadding multi-grid features, or after single-grid features), and |
| | optionally removing padding regions based on the original image sizes. |
| | |
| | Args: |
| | image_forward_outs (List[torch.FloatTensor]): List with one tensor per image sample, each of shape |
| | (num_grids_for_that_sample, patches_per_grid, feature_dim), containing visual features. |
| | image_sizes (List[List[int]]): A list where each element is a list `[width, height]` representing the original |
| | dimensions of the corresponding image sample. Used for unpadding. |
| | possible_resolutions (List[Tuple[int, int]]): A list of supported resolution tuples `(height, width)` used by |
| | `reshape_and_unpad_image_features` for spatial reconstruction, especially during unpadding. |
| | patch_size (int): The spatial dimension (height and width) of the square patches the image was divided into. |
| | grid_size (int): The spatial dimension (height and width) of the square grid onto which patches are mapped. |
| | `grid_size` should be divisible by `patch_size`. |
| | image_newline (torch.FloatTensor): A learnable tensor representing the newline embedding, with shape |
| | (feature_dim,). Appended after rows of unpadded multi-grid features and after single-grid features. |
| | num_queries_vis_abstractor (int, optional): If a visual abstractor with a fixed number of output queries is used |
| | instead of grid patching, this specifies the number of queries. Must be a perfect square if > 0. |
| | Defaults to -1 (indicating standard grid patching is used). |
| | unpad (bool, optional): If `True`, removes padding tokens from image features based on `image_sizes` and |
| | `possible_resolutions`. Does not apply to video features. Defaults to False. |
| | |
| | Returns: |
| | List[torch.FloatTensor]: A list of tensors, where each tensor represents the processed 1D sequence of visual |
| | features for a single sample from the input batch. The length of the sequence varies depending on processing |
| | (unpadding, newlines, video flattening). |
| | |
| | Raises: |
| | AssertionError: If `num_queries_vis_abstractor` is greater than 0 but not a perfect square. |
| | """ |
| | height = width = grid_size // patch_size |
| |
|
| | if num_queries_vis_abstractor > 0: |
| | assert (num_queries_vis_abstractor**0.5).is_integer(), "num_queries_vis_abstractor must be a perfect square" |
| | height = width = int(num_queries_vis_abstractor**0.5) |
| |
|
| | |
| | new_image_features = [] |
| | for image_idx, image_feature in enumerate(image_forward_outs): |
| | if image_feature.shape[0] > 1: |
| | image_feature = reshape_and_unpad_image_features( |
| | image_feature=image_feature, |
| | height=height, |
| | width=width, |
| | image_size=image_sizes[image_idx], |
| | possible_resolutions=possible_resolutions, |
| | grid_size=grid_size, |
| | unpad=unpad, |
| | image_newline=image_newline, |
| | ) |
| | else: |
| | image_feature = image_feature[0] |
| | image_feature = torch.cat((image_feature, image_newline[None].to(image_feature.device)), dim=0) |
| | new_image_features.append(image_feature) |
| | image_features = new_image_features |
| | return image_features |
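| | # Shape walk-through (illustrative sketch only; sizes are assumptions, e.g. a 336-px grid of 14-px |
| | # patches giving 24x24 tokens per grid and a feature dim of 1024). One sample made of 1 base grid |
| | # plus a 2x2 local grid, processed without unpadding: |
| | # >>> import torch |
| | # >>> feats = [torch.randn(5, 24 * 24, 1024)]  # 1 base grid + 4 local grids |
| | # >>> out = anyres_postprocessing( |
| | # ...     image_forward_outs=feats, |
| | # ...     image_sizes=[[672, 672]], |
| | # ...     possible_resolutions=[[672, 672]], |
| | # ...     patch_size=14, |
| | # ...     grid_size=336, |
| | # ...     image_newline=torch.randn(1024), |
| | # ...     unpad=False, |
| | # ... ) |
| | # >>> out[0].shape  # 576 base tokens + 2*24*2*24 local-grid tokens |
| | # torch.Size([2880, 1024]) |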
| |
|
| |
|
| | @dataclass |
| | class HCXVisionOutput(ModelOutput): |
| | """Output class for vision models, containing various computation results. |
| | |
| | Args: |
| | loss (Optional[torch.FloatTensor], optional): Total cross-entropy loss calculated from logits and labels. |
| | loss_per_sample (Optional[torch.FloatTensor], optional): Per-sample loss values for advanced loss processing. |
| | logits (torch.FloatTensor): Classification scores (before SoftMax) of shape (batch_size, num_classes). |
| | past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]], optional): Contains precomputed hidden-states |
| | that can be used (see `past_key_values` input) to speed up sequential decoding. |
| | hidden_states (Optional[Tuple[torch.FloatTensor]], optional): |
| | Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of |
| | shape (batch_size, sequence_length, hidden_size). |
| | Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
| | attentions (Optional[Tuple[torch.FloatTensor]], optional): Tuple of torch.FloatTensor (one for each layer) |
| | of shape (batch_size, num_heads, sequence_length, sequence_length). Attentions weights after the attention |
| | softmax, used to compute the weighted average in the self-attention heads. |
| | """ |
| |
|
| | loss: Optional[torch.FloatTensor] = None |
| | loss_per_sample: Optional[torch.FloatTensor] = None |
| | logits: torch.FloatTensor = None |
| | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
| | hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| | attentions: Optional[Tuple[torch.FloatTensor]] = None |
| |
|
| |
|
| | class HCXVisionForCausalLM(PreTrainedModel, GenerationMixin): |
| | """HCX Vision model for causal language modeling with vision-language capabilities. |
| | |
| | This class combines a vision model with a language model to create a multimodal model |
| | capable of processing images or videos and generating text based on the visual inputs. |
| | |
| | Attributes: |
| | config_class: Configuration class for the model. |
| | vision_model_name: Name of the vision model component. |
| | _no_split_modules: List of modules that should not be split during parallel processing. |
| | supports_gradient_checkpointing: Whether the model supports gradient checkpointing. |
| | _skip_keys_device_placement: Keys to skip during device placement. |
| | """ |
| |
|
| | config_class = HCXVisionConfig |
| | vision_model_name = "vision_model" |
| | _no_split_modules = ["SiglipEncoderLayer", "LlamaDecoderLayer", "HyperCLOVAXDecoderLayer"] |
| | supports_gradient_checkpointing = True |
| | _skip_keys_device_placement = "past_key_values" |
| | _supports_flash_attn_2 = True |
| | _supports_sdpa = True |
| |
|
| | def __init__( |
| | self, |
| | config: HCXVisionConfig, |
| | **kwargs: Optional[Any], |
| | ) -> None: |
| | """Initialize the HCXVisionForCausalLM model. |
| | |
| | Args: |
| | config: Configuration object for the model containing parameters for both |
| | vision and language components. |
| | **kwargs: Additional keyword arguments: |
| | - use_liger: Whether to use liger kernel for hyperclovax models. |
| | - use_fused_ce: Whether to use fused cross-entropy loss. |
| | - use_sum_loss: Whether to use sum reduction for loss instead of mean. |
| | - is_safetensor_save: Whether to save model using safetensors format. |
| | |
| | Raises: |
| | ValueError: If vision_config is not defined or if text_config is not defined. |
| | """ |
| | super().__init__(config) |
| |
|
| | |
| | text_config = self._init_text_config(config) |
| | vision_config = self._init_vision_config(config) |
| | |
| | |
| | config.possible_resolutions = self._init_possible_resolutions(config, vision_config) |
| |
|
| | |
| | with no_init_weights(): |
| | self.vision_model = AutoModel.from_config(vision_config, trust_remote_code=True) |
| |
|
| | self.mm_projector = self._init_mm_projector(config, text_config, vision_config) |
| |
|
| | self.language_model = AutoModelForCausalLM.from_config(text_config) |
| | self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size", text_config.vocab_size) |
| | self.language_model.lm_head = nn.Linear(text_config.hidden_size, self.lm_head_vocab_size, bias=False) |
| |
|
| | if config.anyres: |
| | self.image_newline = nn.Parameter(torch.empty(text_config.hidden_size, dtype=self.dtype)) |
| |
|
| | |
| | # Read `use_liger` from kwargs before it is consulted below. |
| | self.use_liger = kwargs.pop("use_liger", False) |
| | if text_config.model_type in ["llama", "hyperclovax", "gpt2"]: |
| | self.language_model.gradient_checkpointing_enable() |
| | if text_config.model_type == "hyperclovax" and self.use_liger: |
| | self.language_model._get_apply_liger_kernel_converter()(model=self.language_model) |
| |
|
| | |
| | self.vision_config = vision_config = self.vision_model.config |
| | self.text_config = text_config = self.language_model.config |
| | config.update({"vision_config": vision_config}) |
| | config.update({"text_config": text_config}) |
| |
|
| | |
| | self.use_fused_ce = kwargs.pop("use_fused_ce", False) |
| | self.use_meansum_loss = kwargs.pop("use_meansum_loss", False) |
| | self.freeze_before_sampler = kwargs.pop("freeze_before_sampler", False) |
| | self.use_turnmeansum_loss = kwargs.pop("use_turnmeansum_loss", False) |
| | self.vision_input_chunk_size = kwargs.pop("vision_input_chunk_size", None) |
| | self.is_safetensor_save = kwargs.get("is_safetensor_save", True) |
| |
|
| | use_sum_loss = bool(kwargs.pop("use_sum_loss", False)) |
| | self.reduction = self._init_reduction_type(use_sum_loss) |
| |
|
| | self.vision_model_use_no_grad = None |
| |
|
| | self._backward_compatibility_gradient_checkpointing() |
| |
|
| | def _init_weights(self, module): |
| | |
| | if ( |
| | isinstance(module, nn.Conv2d) |
| | or isinstance(module, nn.Embedding) |
| | or isinstance(module, nn.Linear) |
| | ): |
| | module.weight.data.normal_(mean=0.0, std=0.02) |
| | if hasattr(module, "bias") and module.bias is not None: |
| | module.bias.data.zero_() |
| |
|
| | elif isinstance(module, nn.LayerNorm): |
| | module.bias.data.zero_() |
| | module.weight.data.fill_(1.0) |
| | elif isinstance(module, nn.Parameter): |
| | embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype) |
| | module.data.normal_(mean=0.0, std=embed_std) |
| |
|
| | def _init_reduction_type(self, use_sum_loss): |
| | assert not ( |
| | self.use_meansum_loss and self.use_turnmeansum_loss |
| | ), "use_meansum_loss and use_turnmeansum_loss cannot both be True; only one or neither may be True." |
| | if self.use_meansum_loss or self.use_turnmeansum_loss: |
| | reduction = "none" |
| | elif use_sum_loss: |
| | reduction = "sum" |
| | else: |
| | reduction = "mean" |
| | return reduction |
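| | # Resulting reduction for the flag combinations handled above: |
| | #   use_meansum_loss or use_turnmeansum_loss -> "none"  (per-token loss, reduced downstream) |
| | #   use_sum_loss                             -> "sum" |
| | #   otherwise                                -> "mean" |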
| |
|
| | def _init_vision_config(self, config): |
| | vision_model_type = config.vision_config.model_type |
| | if vision_model_type in CONFIG_MAPPING: |
| | vision_config = CONFIG_MAPPING[vision_model_type](**config.vision_config.to_dict()) |
| | vision_config.auto_map = {} |
| | else: |
| | if config.vision_model_name_or_path is not None: |
| | vision_config = AutoConfig.from_pretrained(config.vision_model_name_or_path, trust_remote_code=True) |
| | elif config.vision_config._name_or_path is not None: |
| | vision_config = AutoConfig.from_pretrained(config.vision_config._name_or_path, trust_remote_code=True) |
| | else: |
| | raise ValueError("vision_config is not defined") |
| |
|
| | vision_config.anyres = config.anyres |
| | vision_config.max_num_grids = config.max_num_grids |
| | return vision_config |
| |
|
| | def _init_text_config(self, config): |
| | if hasattr(config, "text_config") and config.text_config is not None: |
| | model_type = config.text_config.model_type |
| | text_config = CONFIG_MAPPING[model_type](**config.text_config.to_dict()) |
| | else: |
| | raise ValueError("text_config is not defined") |
| | text_config._attn_implementation = config._attn_implementation |
| | if text_config.model_type != "hyperclovax": |
| | text_config.logits_scaling = 1.0 |
| | return text_config |
| |
|
| | def _init_possible_resolutions(self, config, vision_config): |
| | """possible_resolution should be matched with preprocessor_config.json""" |
| | if not getattr(config, "possible_resolutions", []): |
| | possible_resolutions = [] |
| | if config.anyres: |
| | assert config.max_num_grids > 0 |
| | for i in range(1, config.max_num_grids + 1): |
| | for j in range(1, config.max_num_grids + 1): |
| | if i == 1 and j == 1 and not config.use_1x1_grid: |
| | continue |
| | if i * j <= config.max_num_grids: |
| | possible_resolutions.append([i, j]) |
| |
|
| | possible_resolutions = [ |
| | [ys * vision_config.image_size, xs * vision_config.image_size] for ys, xs in possible_resolutions |
| | ] |
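| | # Illustrative example (values assumed, not read from any particular config): with max_num_grids=4, |
| | # use_1x1_grid=False and vision_config.image_size=336, the loops above yield |
| | # [[336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]]. |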
| | return possible_resolutions |
| | else: |
| | return config.possible_resolutions |
| |
|
| | def _init_mm_projector(self, config, text_config, vision_config): |
| | input_hidden_size = vision_config.hidden_size |
| | if config.mm_projector_type == "linear": |
| | mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size) |
| | mm_projector.dtype = next(mm_projector.parameters()).dtype |
| | elif config.mm_projector_type == "cabstractor": |
| | mm_projector = HCXVisionCAbstractor( |
| | num_queries=config.num_queries_vis_abstractor_image, |
| | num_input_tokens=(vision_config.image_size // vision_config.patch_size) ** 2, |
| | encoder_hidden_size=input_hidden_size, |
| | hidden_size=input_hidden_size, |
| | output_hidden_size=text_config.hidden_size, |
| | pos_emb=config.proj_pos_emb, |
| | prenorm=config.proj_prenorm, |
| | ) |
| | else: |
| | mm_projector = HCXVisionMlp( |
| | config.mm_projector_type, |
| | input_hidden_size, |
| | hidden_features=input_hidden_size, |
| | out_features=text_config.hidden_size, |
| | ) |
| | return mm_projector |
| |
|
| | def forward( |
| | self, |
| | input_ids: Optional[torch.LongTensor] = None, |
| | pixel_values_images: Optional[List[List[torch.FloatTensor]]] = None, |
| | image_sizes_images: Optional[List[List[Tuple[int, int]]]] = None, |
| | pixel_values_videos: Optional[List[List[torch.FloatTensor]]] = None, |
| | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
| | attention_mask: Optional[torch.FloatTensor] = None, |
| | position_ids: Optional[torch.LongTensor] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | labels: Optional[torch.LongTensor] = None, |
| | use_cache: Optional[bool] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | **kwargs, |
| | ) -> Union[Tuple, HCXVisionOutput]: |
| | """Forward pass of the model. |
| | |
| | This method processes the input tokens and images, combines them into a unified |
| | representation, and generates text output based on the inputs. |
| | |
| | Args: |
| | input_ids: Input token IDs. Positions where an image or video is inserted hold the corresponding |
| | placeholder token (e.g. "<|dummy3|>" for images). |
| | pixel_values_images: List of lists of 4D image tensors. Each outer list corresponds to a sample in the |
| | batch and contains that sample's image grids. |
| | image_sizes_images: List of lists of (width, height) tuples giving the original size of each image in |
| | `pixel_values_images`. |
| | pixel_values_videos: List of lists of 4D tensors containing the frame grids of each video. |
| | past_key_values: Pre-computed key and value states of the attention layers for faster inference. |
| | attention_mask: Mask to avoid performing attention on padding token indices. |
| | position_ids: Indices of each input token in the position embeddings. |
| | inputs_embeds: Input embeddings. If provided, input_ids will not be used. |
| | labels: Labels for computing the language modeling loss. |
| | use_cache: Whether to use past key/values for faster inference. |
| | output_attentions: Whether to return attention weights of each layer. |
| | output_hidden_states: Whether to return hidden states of each layer. |
| | return_dict: Whether to return a ModelOutput instead of a tuple. |
| | **kwargs: Additional keyword arguments. |
| | |
| | Returns: |
| | If return_dict=True, returns an HCXVisionOutput object containing: |
| | - loss: Language modeling loss if labels are provided, otherwise None. |
| | - loss_per_sample: Per-sample loss if labels are provided, otherwise None. |
| | - logits: Prediction scores of the language modeling head. |
| | - past_key_values: Past key/values for faster inference if use_cache=True. |
| | - hidden_states: Hidden states of all layers if output_hidden_states=True. |
| | - attentions: Attention weights of all layers if output_attentions=True. |
| | If return_dict=False, returns a tuple containing the above items except loss_per_sample. |
| | """ |
| | output_attentions = ( |
| | output_attentions if output_attentions is not None else self.config.vision_config.output_attentions |
| | ) |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states |
| | ) |
| | return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
|
| | if inputs_embeds is None and past_key_values is None: |
| | if pixel_values_images is not None or pixel_values_videos is not None: |
| | inputs_embeds = self.extract_inputs_embeds( |
| | input_ids=input_ids, |
| | pixel_values_images=pixel_values_images, |
| | image_sizes_images=image_sizes_images, |
| | pixel_values_videos=pixel_values_videos, |
| | ) |
| | else: |
| | inputs_embeds = self.get_input_embeddings()(input_ids) |
| |
|
| | if inputs_embeds is not None: |
| | input_ids = None |
| |
|
| | |
| | |
| | outputs = self.language_model.base_model( |
| | input_ids=input_ids, |
| | inputs_embeds=inputs_embeds, |
| | attention_mask=attention_mask, |
| | position_ids=position_ids, |
| | past_key_values=past_key_values, |
| | use_cache=use_cache, |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | return_dict=return_dict, |
| | ) |
| |
|
| | hidden_states = outputs[0] |
| | hidden_states = hidden_states * self.text_config.logits_scaling |
| |
|
| | loss = None |
| | loss_per_sample = None |
| | logits = self.language_model.lm_head(hidden_states) |
| | if labels is not None: |
| | |
| | shift_logits = logits[..., :-1, :].contiguous() |
| | shift_labels = labels[..., 1:].contiguous() |
| |
|
| | |
| | loss_fct = CrossEntropyLoss(reduction="none") |
| | shift_logits = shift_logits.view(-1, self.lm_head_vocab_size) |
| | shift_labels = shift_labels.view(-1) |
| |
|
| | |
| | shift_labels = shift_labels.to(shift_logits.device) |
| | loss = loss_fct(shift_logits, shift_labels) |
| | if get_rank() == 0: |
| | loss_per_sample = loss.view(logits.shape[0], -1).sum(axis=1) / ( |
| | shift_labels.view(logits.shape[0], -1) != self.config.ignore_index |
| | ).sum(axis=1) |
| | loss = loss[shift_labels != self.config.ignore_index].mean() |
| | if not return_dict: |
| | output = (logits,) + outputs[1:] |
| | return (loss,) + output if loss is not None else output |
| |
|
| | return HCXVisionOutput( |
| | loss=loss, |
| | loss_per_sample=loss_per_sample, |
| | logits=logits, |
| | past_key_values=outputs.past_key_values, |
| | hidden_states=outputs.hidden_states, |
| | attentions=outputs.attentions, |
| | ) |
| |
|
| | |
| | def get_input_embeddings(self): |
| | return self.language_model.get_input_embeddings() |
| |
|
| | |
| | def set_input_embeddings(self, value): |
| | self.language_model.set_input_embeddings(value) |
| |
|
| | |
| | def get_output_embeddings(self): |
| | return self.language_model.get_output_embeddings() |
| |
|
| | |
| | def set_output_embeddings(self, new_embeddings): |
| | self.language_model.set_output_embeddings(new_embeddings) |
| |
|
| | |
| | def set_decoder(self, decoder): |
| | self.language_model.set_decoder(decoder) |
| |
|
| | |
| | def get_decoder(self): |
| | return self.language_model.get_decoder() |
| |
|
| | |
| | def tie_weights(self): |
| | return self.language_model.tie_weights() |
| |
|
| | |
| | def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: |
| | model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) |
| | self.config.text_config.vocab_size = model_embeds.num_embeddings |
| | self.vocab_size = model_embeds.num_embeddings |
| | return model_embeds |
| |
|
| | def extract_inputs_embeds( |
| | self, |
| | input_ids: Optional[torch.LongTensor] = None, |
| | pixel_values_images: Optional[List[List[torch.FloatTensor]]] = None, |
| | image_sizes_images: Optional[List[List[Tuple[int, int]]]] = None, |
| | pixel_values_videos: Optional[List[List[torch.FloatTensor]]] = None, |
| | ): |
| | """Extract input embeddings by processing text tokens and visual features. |
| | |
| | This method processes the input tokens and image features, extracts the visual features |
| | using the vision model, and combines them with the text token embeddings to create |
| | a unified input representation for the language model. |
| | |
| | Args: |
| | input_ids: Input token IDs containing the image/video placeholder tokens. |
| | pixel_values_images: List of lists of image tensors, one inner list per sample. |
| | image_sizes_images: List of lists of (width, height) tuples for the images in `pixel_values_images`. |
| | pixel_values_videos: List of lists of tensors holding the frame grids of each video. |
| | |
| | Returns: |
| | Combined embeddings of text tokens and visual features, or None if the batch contains no image or video inputs. |
| | """ |
| | |
| | len_pixel_values_images = [len(pixel_value) for pixel_value in pixel_values_images] if pixel_values_images else [] |
| | len_pixel_values_videos = [len(pixel_value) for pixel_value in pixel_values_videos] if pixel_values_videos else [] |
| |
|
| | if sum(len_pixel_values_images) + sum(len_pixel_values_videos) == 0: |
| | return None |
| |
|
| | inputs_embeds = self.get_input_embeddings()(input_ids) |
| |
|
| | if sum(len_pixel_values_images) > 0: |
| | image_features_batch = self.forward_images( |
| | pixel_values_images, image_sizes_images, len_pixel_values_images |
| | ) |
| | for i, image_features in enumerate(image_features_batch): |
| | if len(image_features) > 0: |
| | image_token_indices = (input_ids[i] == self.config.image_token_id).nonzero().squeeze() |
| | inputs_embeds[i][image_token_indices] = torch.cat(image_features).to(inputs_embeds.dtype) |
| |
|
| | if sum(len_pixel_values_videos) > 0: |
| | video_features_batch = self.forward_videos(pixel_values_videos, len_pixel_values_videos) |
| | for i, video_features in enumerate(video_features_batch): |
| | if len(video_features) > 0: |
| | video_token_indices = (input_ids[i] == self.config.video_token_id).nonzero().squeeze() |
| | inputs_embeds[i][video_token_indices] = torch.cat(video_features).to(inputs_embeds.dtype) |
| |
|
| | return inputs_embeds |
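| | # Minimal sketch of the placeholder-substitution step above (illustrative only; the id, sizes and |
| | # values are made up, and IMAGE_TOKEN_ID stands in for self.config.image_token_id): |
| | # >>> import torch |
| | # >>> IMAGE_TOKEN_ID = 7 |
| | # >>> input_ids_row = torch.tensor([1, 7, 7, 2])  # two image placeholder positions |
| | # >>> embeds_row = torch.zeros(4, 8)              # (seq_len, hidden_size) |
| | # >>> visual_tokens = torch.ones(2, 8)            # projector output for this sample |
| | # >>> idx = (input_ids_row == IMAGE_TOKEN_ID).nonzero().squeeze() |
| | # >>> embeds_row[idx] = visual_tokens             # rows 1 and 2 now hold visual features |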
| |
|
| | def forward_images( |
| | self, |
| | pixel_values_images: List[List[torch.FloatTensor]], |
| | image_sizes_images: List[List[Tuple[int, int]]], |
| | len_pixel_values_images: List[int], |
| | ) -> List[List[torch.Tensor]]: |
| | if sum(len_pixel_values_images) == 0: |
| | return None |
| |
|
| | concat_pixel_values_images = torch.cat(list(chain(*pixel_values_images)), dim=0) |
| |
|
| | visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 |
| | context_vision_model = torch.no_grad() if self.vision_model_use_no_grad else contextlib.nullcontext() |
| | with context_vision_model: |
| | if self.config.use_nth_layer == -1: |
| | |
| | self.vision_model.vision_model.post_layernorm = nn.Identity() |
| | image_forward_outs = self.vision_model(concat_pixel_values_images) |
| | image_forward_outs = image_forward_outs.last_hidden_state[:, visual_token_idx:] |
| | else: |
| | image_forward_outs = self.vision_model(concat_pixel_values_images, output_hidden_states=True) |
| | image_forward_outs = image_forward_outs.hidden_states[self.config.use_nth_layer][:, visual_token_idx:] |
| |
|
| | image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) |
| | image_forward_outs = self.mm_projector(image_forward_outs) |
| |
|
| | |
| | split_sizes = [pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)] |
| | image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0) |
| |
|
| | |
| | image_features = anyres_postprocessing( |
| | image_forward_outs=image_forward_outs, |
| | image_sizes=[image_size for image_sizes in image_sizes_images for image_size in image_sizes], |
| | num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image, |
| | unpad=self.config.unpad, |
| | patch_size=self.vision_config.patch_size, |
| | grid_size=self.vision_config.image_size, |
| | image_newline=self.image_newline, |
| | possible_resolutions=self.config.possible_resolutions, |
| | ) |
| |
|
| | |
| | image_features = [ |
| | image_features[sum(len_pixel_values_images[:i]) : sum(len_pixel_values_images[: i + 1])] |
| | for i in range(len(len_pixel_values_images)) |
| | ] |
| |
|
| | return image_features |
| |
|
| | def forward_videos( |
| | self, |
| | pixel_values_videos: List[List[torch.FloatTensor]], |
| | len_pixel_values_videos: List[int], |
| | ) -> List[torch.Tensor]: |
| |
|
| | len_video_grids = sum(len_pixel_values_videos) |
| | if len_video_grids == 0: |
| | return None |
| |
|
| | |
| | concat_pixel_values_videos = torch.cat(list(chain(*pixel_values_videos)), dim=0) |
| |
|
| | visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 |
| | context_vision_model = torch.no_grad() if self.vision_model_use_no_grad else contextlib.nullcontext() |
| | with context_vision_model: |
| | if self.config.use_nth_layer == -1: |
| | |
| | self.vision_model.vision_model.post_layernorm = nn.Identity() |
| | video_forward_outs = self.vision_model(concat_pixel_values_videos) |
| | video_forward_outs = video_forward_outs.last_hidden_state[:, visual_token_idx:] |
| | else: |
| | video_forward_outs = self.vision_model(concat_pixel_values_videos, output_hidden_states=True) |
| | video_forward_outs = video_forward_outs.hidden_states[self.config.use_nth_layer][:, visual_token_idx:] |
| |
|
| | video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype) |
| |
|
| | |
| | |
| | grid_idx = 0 |
| | num_grids = [grid_idx] |
| | num_queries_vis_abstractors = [] |
| | len_total_frames = video_forward_outs.shape[0] |
| |
|
| | if self.config.first_last_frames_slow: |
| | |
| | |
| | assert len_total_frames != 0 |
| | if len_total_frames <= 2: |
| | num_queries_vis_abstractors.append(self.config.num_queries_vis_abstractor_video_slow) |
| | grid_idx += len_total_frames |
| | num_grids.append(grid_idx) |
| | else: |
| | num_queries_vis_abstractors.append(self.config.num_queries_vis_abstractor_video_slow) |
| | grid_idx += 1 |
| | num_grids.append(grid_idx) |
| |
|
| | num_queries_vis_abstractors.append(self.config.num_queries_vis_abstractor_video_fast) |
| | grid_idx += len_total_frames - 2 |
| | num_grids.append(grid_idx) |
| |
|
| | num_queries_vis_abstractors.append(self.config.num_queries_vis_abstractor_video_slow) |
| | grid_idx += 1 |
| | num_grids.append(grid_idx) |
| | else: |
| | |
| | for pixel_values_frames in pixel_values_videos: |
| | for pixel_values_frame in pixel_values_frames: |
| | if len(pixel_values_frame) > 0: |
| | num_queries_vis_abstractors.append(self.config.num_queries_vis_abstractor_video_slow) |
| | grid_idx += 1 |
| | num_grids.append(grid_idx) |
| | num_queries_vis_abstractors.append(self.config.num_queries_vis_abstractor_video_fast) |
| | grid_idx = grid_idx + len(pixel_values_frame) - 1 |
| | num_grids.append(grid_idx) |
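| | # Worked example of the bookkeeping above (illustrative numbers): with first_last_frames_slow=True, |
| | # 6 frame grids in total, num_queries_vis_abstractor_video_slow=81 and ..._video_fast=9, the branch yields |
| | # num_queries_vis_abstractors = [81, 9, 81] and num_grids = [0, 1, 5, 6], i.e. the first and last grids |
| | # keep the "slow" query budget while grids 1-4 are pooled down to the "fast" budget. |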
| |
|
| | video_forward_outs = self.mm_projector(video_forward_outs, num_queries_vis_abstractors, num_grids) |
| |
|
| | |
| | |
| | video_features = [] |
| | target_features = [] |
| | target_group_size = 0 |
| | group_counter = 0 |
| | video_groups = [ |
| | len(frame) for frames in pixel_values_videos for frame in frames |
| | ] |
| |
|
| | for forward_out in video_forward_outs: |
| | target_group_size += len(forward_out) |
| | target_features.append(forward_out.flatten(0, 1)) |
| |
|
| | video_group_size = video_groups[group_counter] |
| | if video_group_size == target_group_size: |
| | video_features.append(torch.cat(target_features, dim=0)) |
| | target_features = [] |
| | group_counter += 1 |
| | target_group_size = 0 |
| |
|
| | elif video_group_size < target_group_size: |
| | raise RuntimeError(f"video_group_size < target_group_size!! [{video_group_size} < {target_group_size}]") |
| |
|
| | assert len(target_features) == 0, f"target_features is not empty!! {target_features}" |
| | assert len(video_groups) == len(video_features) |
| |
|
| | |
| | video_features = [ |
| | video_features[sum(len_pixel_values_videos[:i]) : sum(len_pixel_values_videos[: i + 1])] |
| | for i in range(len(len_pixel_values_videos)) |
| | ] |
| |
|
| | return video_features |
| |
|
| | @torch.no_grad() |
| | def generate( |
| | self, |
| | input_ids: Optional[torch.LongTensor] = None, |
| | pixel_values_images: Optional[List[List[torch.FloatTensor]]] = None, |
| | image_sizes_images: Optional[List[List[Tuple[int, int]]]] = None, |
| | pixel_values_videos: Optional[List[List[torch.FloatTensor]]] = None, |
| | pad_token_id: Optional[int] = None, |
| | eos_token_id: Optional[int] = None, |
| | bad_words_ids: Optional[List[List[int]]] = None, |
| | max_length: int = 196, |
| | min_length: int = 2, |
| | do_sample: bool = True, |
| | num_beams: int = 1, |
| | top_p: float = 0.6, |
| | top_k: int = 0, |
| | temperature: float = 0.5, |
| | repetition_penalty: float = 1.0, |
| | length_penalty: float = 1.0, |
| | use_cache: bool = True, |
| | verbose: bool = False, |
| | **kwargs, |
| | ) -> torch.LongTensor: |
| | """Generate text based on input tokens and images. |
| | |
| | This method generates text based on the provided input tokens and images using |
| | beam search and/or sampling strategies. |
| | |
| | Args: |
| | input_ids: Input token IDs containing the image/video placeholder tokens. |
| | pixel_values_images: List of lists of image tensors, one inner list per sample. |
| | image_sizes_images: List of lists of (width, height) tuples for the images in `pixel_values_images`. |
| | pixel_values_videos: List of lists of tensors holding the frame grids of each video. |
| | pad_token_id: Token ID used for padding. |
| | eos_token_id: Token ID used to signal the end of a sequence. |
| | bad_words_ids: List of token ID sequences that should not be generated. |
| | max_length: Maximum number of new tokens to generate (forwarded to the language model as `max_new_tokens`). |
| | min_length: Minimum length of the sequence to be generated (input length + min_new_tokens). |
| | do_sample: Whether to use sampling for generation (otherwise uses greedy decoding). |
| | num_beams: Number of beams for beam search. 1 means no beam search. |
| | top_p: Nucleus sampling parameter. Tokens with cumulative probability > top_p are kept. |
| | top_k: Number of highest probability tokens to keep for top-k-filtering. |
| | temperature: Value used to modulate the next token probabilities. |
| | repetition_penalty: Penalty applied to tokens that have already appeared in the sequence. |
| | length_penalty: Exponential penalty applied to sequence length. |
| | use_cache: Whether to use past key/values for faster inference. |
| | verbose: If True, prints the decoded input query and prediction for debugging. |
| | **kwargs: Additional keyword arguments. |
| | |
| | Returns: |
| | Generated token IDs. |
| | """ |
| | |
| | if pad_token_id is None: |
| | pad_token_id = self.tokenizer.pad_token_id |
| | if eos_token_id is None: |
| | eos_token_id = self.tokenizer.encode(EOT)[0] |
| | if bad_words_ids is None: |
| | bad_words_ids = [ |
| | [ |
| | self.config.text_config.bos_token_id, |
| | ], |
| | [ |
| | self.config.text_config.eos_token_id, |
| | ], |
| | ] |
| |
|
| | if (pixel_values_images is None or all(len(pixel_values) == 0 for pixel_values in pixel_values_images)) and ( |
| | pixel_values_videos is None or all(len(pixel_values) == 0 for pixel_values in pixel_values_videos) |
| | ): |
| | return self.language_model.generate( |
| | input_ids, pad_token_id=pad_token_id, eos_token_id=eos_token_id, bad_words_ids=bad_words_ids, **kwargs |
| | ) |
| |
|
| | inputs_embeds = self.extract_inputs_embeds( |
| | input_ids=input_ids, |
| | pixel_values_images=pixel_values_images, |
| | image_sizes_images=image_sizes_images, |
| | pixel_values_videos=pixel_values_videos, |
| | ) |
| |
|
| | inputs_embeds = inputs_embeds.to(device=self.language_model.device, dtype=self.language_model.dtype) |
| |
|
| | |
| | pred = self.language_model.generate( |
| | inputs_embeds=inputs_embeds, |
| | pad_token_id=pad_token_id, |
| | eos_token_id=eos_token_id, |
| | bad_words_ids=bad_words_ids, |
| | max_new_tokens=max_length, |
| | min_length=min_length, |
| | num_beams=num_beams, |
| | do_sample=(False if temperature == 0.0 else do_sample), |
| | top_k=top_k, |
| | top_p=top_p, |
| | temperature=temperature, |
| | repetition_penalty=repetition_penalty, |
| | length_penalty=length_penalty, |
| | early_stopping=(False if num_beams <= 1 else True), |
| | use_cache=use_cache, |
| | ) |
| |
|
| | if verbose: |
| | llm_query = self.tokenizer.batch_decode( |
| | [ |
| | [token_id for token_id in input_ids_row if token_id != self.tokenizer.pad_token_id] |
| | for input_ids_row in input_ids.detach().cpu().tolist() |
| | ], |
| | skip_special_tokens=False, |
| | )[0] |
| | llm_pred = self.tokenizer.batch_decode( |
| | [ |
| | [token_id for token_id in pred_row if token_id != self.tokenizer.pad_token_id] |
| | for pred_row in pred.detach().cpu().tolist() |
| | ], |
| | skip_special_tokens=False, |
| | )[0] |
| | print(f"# [info] llm_query: {llm_query}") |
| | print(f"# [info] llm_pred: {llm_pred}") |
| |
|
| | return pred |
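| | # Illustrative usage (the checkpoint path and preprocessing are placeholders; the tensors come from |
| | # the checkpoint's paired image processor / tokenizer): |
| | # >>> model = HCXVisionForCausalLM.from_pretrained("path/to/checkpoint") |
| | # >>> output_ids = model.generate( |
| | # ...     input_ids=input_ids,                      # contains the image placeholder tokens |
| | # ...     pixel_values_images=pixel_values_images,  # List[List[Tensor]] from the image processor |
| | # ...     image_sizes_images=image_sizes_images,    # original (width, height) of each image |
| | # ...     max_length=196, |
| | # ... ) |
| | # >>> print(model.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]) |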
| |
|
| | def to_vision_model_device(self, input_tensor: Union[torch.Tensor, List]) -> Union[torch.Tensor, List]: |
| | """Move input tensors to the vision model's device. |
| | This method recursively moves input tensors or lists of tensors to the vision model's device. |
| | |
| | Args: |
| | input_tensor: Input tensor or list of tensors to be moved to the vision model's device. |
| | |
| | Returns: |
| | The input tensor or list of tensors moved to the vision model's device. |
| | |
| | Raises: |
| | TypeError: If the input is neither a tensor nor a list. |
| | """ |
| | if isinstance(input_tensor, list): |
| | return [self.to_vision_model_device(item) for item in input_tensor] |
| | elif isinstance(input_tensor, torch.Tensor): |
| | return input_tensor.to(self.vision_model.device) |
| | else: |
| | raise TypeError("Unsupported data type. Only tensors and lists are allowed.") |
| |
|
| | def prepare_inputs_for_generation( |
| | self, |
| | input_ids: torch.LongTensor, |
| | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
| | attention_mask: Optional[torch.FloatTensor] = None, |
| | inputs_embeds: Optional[torch.FloatTensor] = None, |
| | **kwargs, |
| | ) -> Dict[str, Any]: |
| | """Prepare inputs for the generation algorithm. |
| | |
| | This method prepares the input for each generation step based on the model's needs. |
| | |
| | Args: |
| | input_ids: Input token IDs. |
| | past_key_values: Pre-computed key and value states for faster inference. |
| | attention_mask: Mask to avoid performing attention on padding token indices. |
| | inputs_embeds: Input embeddings. If provided, input_ids will not be used. |
| | **kwargs: Additional keyword arguments. |
| | |
| | Returns: |
| | Dictionary containing the prepared inputs for the model. |
| | """ |
| | input_ids = kwargs.get("decoder_input_ids", input_ids) |
| |
|
| | if past_key_values: |
| | input_ids = input_ids[:, -1:] |
| |
|
| | |
| | if inputs_embeds is not None and past_key_values is None: |
| | model_inputs = {"inputs_embeds": inputs_embeds} |
| | else: |
| | model_inputs = {"input_ids": input_ids} |
| |
|
| | model_inputs.update( |
| | { |
| | "past_key_values": past_key_values, |
| | "use_cache": kwargs.get("use_cache"), |
| | "attention_mask": attention_mask, |
| | "pixel_values": kwargs.get("pixel_values", None), |
| | } |
| | ) |
| | return model_inputs |
| |
|
| | @classmethod |
| | def from_config(cls, config, vision_model_name_or_path): |
| | return cls(config, vision_model_name_or_path) |
| |
|
| | @classmethod |
| | def from_pretrained( |
| | cls, |
| | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, |
| | *model_args, |
| | **kwargs, |
| | ) -> "HCXVisionForCausalLM": |
| | assert pretrained_model_name_or_path is not None |
| |
|
| | save_only_vision = kwargs.pop("save_only_vision", False) |
| | save_only_qformer = kwargs.pop("save_only_qformer", False) |
| | save_shard_size = kwargs.pop("save_shard_size", "5GB") |
| |
|
| | if pretrained_model_name_or_path is not None: |
| | model: HCXVisionForCausalLM = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) |
| | model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) |
| |
|
| | image_token_id = model.tokenizer.encode(IMAGE_LOC, add_special_tokens=False) |
| | assert ( |
| | len(image_token_id) == 1 |
| | ), f'"<|dummy3|>" was not encoded into a single special token. Encoding result: {image_token_id}' |
| | model.config.image_token_id = image_token_id[0] |
| | |
| | video_token_id = model.tokenizer.encode(VIDEO_LOC, add_special_tokens=False) |
| | assert ( |
| | len(video_token_id) == 1 |
| | ), f'"<|_unuse_missing_100270|>" was not encoded into a single special token. Encoding result: {video_token_id}' |
| | model.config.video_token_id = video_token_id[0] |
| |
|
| | model.save_only_vision = save_only_vision |
| | model.save_only_qformer = save_only_qformer |
| | model.save_shard_size = save_shard_size |
| |
|
| | return model |
| |
|
| | def get_language_model(self): |
| | return self.language_model.base_model |
| |
|
| | def get_vision_model(self): |
| | return self.vision_model |
| |
|
| | def save_pretrained( |
| | self, |
| | save_directory: Union[str, os.PathLike], |
| | *args, |
| | **kwargs, |
| | ): |
| | state_dict = kwargs["state_dict"] if "state_dict" in kwargs else self.state_dict() |
| | partial_state_dict = self.get_pretrained_state_dict( |
| | state_dict, |
| | save_directory, |
| | ) |
| | kwargs["state_dict"] = partial_state_dict |
| | kwargs["safe_serialization"] = self.is_safetensor_save |
| | kwargs.setdefault("max_shard_size", self.save_shard_size) |
| | super().save_pretrained(save_directory, *args, **kwargs) |
| |
|
| | def get_pretrained_state_dict(self, state_dict, save_dir): |
| | vision_key = "vision_model." |
| | llm_keys = ["language_model."] |
| | head_key = "lm_head." |
| |
|
| | for key in list(state_dict.keys()): |
| | if self.save_only_vision: |
| | for llm_key in llm_keys: |
| | if llm_key in key: |
| | state_dict.pop(key) |
| | if key.startswith(head_key): |
| | state_dict.pop(key) |
| |
|
| | elif self.save_only_qformer: |
| | if f"{vision_key}" in key: |
| | state_dict.pop(key) |
| |
|
| | return state_dict |
| |
|
| |
|
| |
|
| | class HCXVisionMlp(nn.Module): |
| | def __init__( |
| | self, |
| | mm_projector_type, |
| | in_features, |
| | hidden_features=None, |
| | out_features=None, |
| | act_layer=nn.GELU, |
| | ): |
| | super().__init__() |
| | out_features = out_features or in_features |
| | hidden_features = hidden_features or in_features |
| | self.mm_projector_type = mm_projector_type |
| | if self.mm_projector_type == "mlp": |
| | self.fc1 = nn.Linear(in_features, hidden_features) |
| | self.act = act_layer() |
| | self.fc2 = nn.Linear(hidden_features, out_features) |
| | elif self.mm_projector_type == "inverted_mlp": |
| | self.fc1 = nn.Linear(in_features, 2 * hidden_features) |
| | self.act = act_layer() |
| | self.fc2 = nn.Linear(2 * hidden_features, out_features) |
| | else: |
| | raise NotImplementedError("{} is not implemented".format(self.mm_projector_type)) |
| |
|
| | def forward(self, x): |
| | x = self.fc1(x) |
| | x = self.act(x) |
| | x = self.fc2(x) |
| | return x |
| |
|
| |
|
| | class HCXVisionCAbstractor(nn.Module): |
| | """ |
| | This module is based on C-Abstractor, whose license is under apache-2.0. |
| | You can check the original code at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py |
| | and we made necessary modifications. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | num_queries: int, |
| | num_input_tokens: int, |
| | encoder_hidden_size: int, |
| | hidden_size: int, |
| | output_hidden_size: int, |
| | pos_emb: bool = True, |
| | prenorm: bool = False, |
| | ): |
| | super().__init__() |
| | self.num_input_tokens = num_input_tokens |
| | self.output_hidden_size = output_hidden_size |
| |
|
| | |
| | if pos_emb: |
| | self.pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size)) |
| | self.pos_emb.data.normal_(mean=0.0, std=0.02) |
| | else: |
| | self.pos_emb = None |
| |
|
| | |
| | if prenorm: |
| | self.prenorm = LayerNorm(encoder_hidden_size) |
| | else: |
| | self.prenorm = None |
| |
|
| | self.build_net(num_queries, encoder_hidden_size, hidden_size, output_hidden_size) |
| | self.dtype = next(self.parameters()).dtype |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | num_queries_vis_abstractors: Optional[List[int]] = None, |
| | num_grids: Optional[List[int]] = None, |
| | ) -> torch.Tensor: |
| | """ |
| | Args: |
| | x: (B, L, encoder_hidden_size) tensor from the visual backbone (e.g. a SigLIP/CLIP visual encoder), with |
| | the cls token already removed so that L is a perfect square. |
| | num_queries_vis_abstractors: Optional flat list of output query counts, one per group of grids. |
| | num_grids: Optional list of cumulative grid indices delimiting those groups along the batch dimension. |
| | """ |
| | if self.prenorm is not None: |
| | x = self.prenorm(x) |
| |
|
| | if self.pos_emb is not None: |
| | x = x + self.pos_emb |
| |
|
| | x = self._forward( |
| | x, |
| | num_queries_vis_abstractors=num_queries_vis_abstractors, |
| | num_grids=num_grids, |
| | ) |
| |
|
| | return x |
| |
|
| | def _forward( |
| | self, |
| | x: torch.Tensor, |
| | num_queries_vis_abstractors: Optional[List[int]] = None, |
| | num_grids: Optional[List[int]] = None, |
| | ) -> torch.Tensor: |
| | |
| | B, L, dim = x.shape |
| | hw = int(L**0.5) |
| | x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw) |
| |
|
| | if num_queries_vis_abstractors is not None: |
| | assert num_grids is not None |
| | return self._forward_adaptive_num_query(x, num_queries_vis_abstractors, num_grids) |
| |
|
| | x = self.net(x) |
| | x = rearrange(x, "b d h w -> b (h w) d") |
| | x = self.readout(x) |
| | return x |
| |
|
| | def _forward_adaptive_num_query( |
| | self, |
| | x: torch.Tensor, |
| | num_queries_vis_abstractors: Optional[List[int]] = None, |
| | num_grids: Optional[List[int]] = None, |
| | ) -> List[torch.Tensor]: |
| | |
| | assert len(self.net) == 3 |
| |
|
| | x = self.net[0](x) |
| | new_x = [] |
| | for i, num_queries in enumerate(num_queries_vis_abstractors): |
| | hw = int(num_queries**0.5) |
| | sampler = nn.AdaptiveAvgPool2d((hw, hw)) |
| | out = sampler(x[num_grids[i] : num_grids[i + 1], :]) |
| | out = self.net[2](out) |
| |
|
| | out = rearrange(out, "b d h w -> b (h w) d") |
| | out = self.readout(out) |
| |
|
| | new_x.append(out) |
| | return new_x |
| |
|
| | def build_net( |
| | self, |
| | n_queries: int, |
| | encoder_hidden_size: int, |
| | hidden_size: int, |
| | output_hidden_size: int, |
| | depth: int = 3, |
| | mlp_depth: int = 2, |
| | ): |
| | assert (n_queries**0.5).is_integer(), f"n_queries must be square number. n_queries: {n_queries}" |
| | hw = int(n_queries**0.5) |
| |
|
| | |
| | RegBlock = partial( |
| | RegStage, |
| | stride=1, |
| | dilation=1, |
| | act_layer=nn.SiLU, |
| | norm_layer=LayerNorm2d, |
| | ) |
| |
|
| | s1 = RegBlock( |
| | depth, |
| | encoder_hidden_size, |
| | hidden_size, |
| | ) |
| | sampler = nn.AdaptiveAvgPool2d((hw, hw)) |
| | s2 = RegBlock( |
| | depth, |
| | hidden_size, |
| | hidden_size, |
| | ) |
| |
|
| | self.net = nn.Sequential(s1, sampler, s2) |
| | self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size) |
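| | # Resulting projector: RegStage (s1) -> AdaptiveAvgPool2d(hw, hw) -> RegStage (s2) -> MLP readout, |
| | # mapping (B, encoder_hidden_size, H, W) feature maps to (B, n_queries, output_hidden_size) tokens |
| | # once `_forward` flattens the pooled map back into a sequence. |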
| |
|
| | def build_mlp( |
| | self, |
| | depth: int, |
| | hidden_size: int, |
| | output_hidden_size: int, |
| | ): |
| | layers = [nn.Linear(hidden_size, output_hidden_size)] |
| | for _ in range(1, depth): |
| | layers.append(nn.SiLU()) |
| | layers.append(nn.Linear(output_hidden_size, output_hidden_size)) |
| | return nn.Sequential(*layers) |
| |
|
| |
|