| from typing import List |
| import warnings |
|
|
| import torch |
| from torch import nn, Tensor |
| from torchvision import transforms |
|
|
| from torchtune.models.llama3 import lora_llama3_8b, llama3_8b |
| from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear |
| from torchtune.modules import TransformerDecoder |
|
|
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore", UserWarning) |
| from imagebind.models import imagebind_model |
| from models.imagebind_wrapper import get_imagebind_v2, V2_PATH |
| from models.imagebind_wrapper import ImageBind |
|
|
| IMAGEBIND_DIM = 1024 |
| CLIP_DIM = 768 |
|
|
|
|
class MMEmbedding(nn.Embedding):
    """Token embedding that splices projected multimodal features into the
    token-embedding output at caller-specified sequence positions.

    Wraps an existing ``nn.Embedding``, copying its configuration only.
    NOTE(review): ``super().__init__`` re-initializes the weight matrix — the
    wrapped embedding's trained weights are NOT copied here; presumably a
    checkpoint load restores them afterwards — confirm with the training recipe.
    """

    def __init__(self, e, perception_tokens=1, use_clip=False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        self._perception_tokens = perception_tokens
        self._context = []
        self._use_clip = use_clip

        in_dim = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        out_dim = e.embedding_dim * perception_tokens

        # Two-layer MLP mapping perception features into the llama embedding
        # space; output covers `perception_tokens` embedding slots at once.
        self.proj_to_llama = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.LayerNorm(out_dim),
            nn.Linear(out_dim, out_dim),
        )

    def set_context(self, context):
        """Register per-batch-item dicts mapping a sequence position to its
        perception embeddings ({"ib_embed": ..., "clip_embed": ...}); applied
        on the next forward pass."""
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        embedded = super().forward(input)
        # Overwrite the embedding rows reserved for perception tokens.
        for batch_idx, positions in enumerate(self._context):
            for start, feats in positions.items():
                parts = [feats["ib_embed"]]
                if self._use_clip:
                    parts.append(feats["clip_embed"])
                projected = self.proj_to_llama(torch.cat(parts))
                end = start + self._perception_tokens
                embedded[batch_idx, start:end] = projected.view(
                    self._perception_tokens, -1
                )
        return embedded
|
|
|
|
class MMLinear(nn.Linear):
    """Output (LM-head) linear layer augmented with an auxiliary projection
    from llama hidden space into CLIP embedding space.

    Wraps an existing ``nn.Linear``, copying its dimensions and bias setting.
    NOTE(review): weights are re-initialized by ``super().__init__``, not
    copied from ``o`` — presumably restored by a later checkpoint load.
    ``proj_from_llama`` is constructed but not yet applied in ``forward``;
    ``_clip_projections`` is reset per call as a hook point.
    """

    def __init__(self, o):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            # BUG FIX: the original `o.bias != None` is an element-wise tensor
            # comparison when bias exists; nn.Linear then evaluates that
            # multi-element tensor in boolean context and raises
            # "Boolean value of Tensor ... is ambiguous". Use identity check.
            bias=o.bias is not None,
        )
        self._context = []

        dim_out = CLIP_DIM
        dim_in = o.in_features
        # Two-layer MLP mapping llama hidden states to CLIP embedding space.
        self.proj_from_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        """Store context for a future forward pass (not consumed here yet)."""
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        # Reset the per-call CLIP projection cache; population of this list
        # was left unimplemented in the original (stub preserved).
        self._clip_projections = []
        return super().forward(input_bsd)
|
|
|
|
|
|
def lora_mmllama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a LoRA llama3-8b and swap in the multimodal embedding and
    output layers (MMEmbedding / MMLinear).

    Args mirror torchtune's ``lora_llama3_8b`` plus ``perception_tokens``
    (embedding slots per perception input) and ``use_clip`` (concatenate
    CLIP features to ImageBind features in MMEmbedding).
    """
    llama3 = lora_llama3_8b(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        # BUG FIX: pass by keyword — torchtune versions declare lora_dropout
        # (and use_dora) between lora_alpha and quantize_base, so passing
        # quantize_base positionally can silently bind to lora_dropout.
        quantize_base=quantize_base,
    )
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3
|
|
|
|
def mmllama3_8b(
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a stock llama3-8b and replace its token embedding and output
    head with the multimodal MMEmbedding / MMLinear wrappers.

    Args:
        perception_tokens: embedding slots reserved per perception input.
        use_clip: whether MMEmbedding also consumes CLIP features.
    """
    model = llama3_8b()
    model.tok_embeddings = MMEmbedding(model.tok_embeddings, perception_tokens, use_clip)
    model.output = MMLinear(model.output)
    return model
|
|
|
|
def imagebind_huge(use_v2: bool = True):
    """Load an ImageBind-huge model and attach a PIL preprocessing pipeline
    as ``transform_from_pil``.

    Args:
        use_v2: load the project's ImageBind v2 wrapper; otherwise fall back
            to the pretrained v1 model from the imagebind package.
    """
    model = ImageBind(v2=True) if use_v2 else imagebind_model.imagebind_huge(pretrained=True)

    # Resize + 224x224 center crop, normalized with the CLIP mean/std.
    model.transform_from_pil = transforms.Compose([
        transforms.Resize(
            224, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    return model
|
|
|
|