#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
from functools import partial
from types import SimpleNamespace
from typing import Optional, Sequence

import torch
import torch.nn as nn

from imagebind.models.helpers import (EinOpsRearrange, LearnableLogitScaling,
                                      Normalize, SelectElement,
                                      SelectEOSAndProject)
from imagebind.models.multimodal_preprocessors import (
    AudioPreprocessor, IMUPreprocessor, PadIm2Video, PatchEmbedGeneric,
    RGBDTPreprocessor, SpatioTemporalPosEmbeddingHelper, TextPreprocessor,
    ThermalPreprocessor)
from imagebind.models.transformer import MultiheadAttention, SimpleTransformer

# String keys identifying each supported modality. They are used as the keys
# of every per-modality ``nn.ModuleDict`` below and of the ``inputs`` /
# ``outputs`` dicts of ``ImageBindModel.forward``.
ModalityType = SimpleNamespace(
    VISION="vision",
    TEXT="text",
    AUDIO="audio",
    THERMAL="thermal",
    DEPTH="depth",
    IMU="imu",
)


class ImageBindModel(nn.Module):
    """Multi-modal encoder built from four per-modality stages.

    For every modality in ``ModalityType`` the model owns a preprocessor, a
    transformer trunk, a projection head ending in a linear layer of width
    ``out_embed_dim``, and a postprocessor (normalisation, optionally followed
    by logit scaling). ``forward`` runs the four stages in that order for each
    modality present in its input dict.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build the preprocessors, trunks, heads and postprocessors.

        The ``*_embed_dim`` / ``*_num_blocks`` / ``*_num_heads`` /
        ``*_drop_path`` arguments size the transformer trunk of the matching
        modality; the ``*_kernel_size`` / ``audio_stride`` arguments configure
        the patch-embedding stems; ``out_embed_dim`` is the output width of
        every projection head.
        """
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
            # Previously accepted by __init__ but silently ignored.
            imu_kernel_size=imu_kernel_size,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
        imu_kernel_size=8,
    ):
        """Create the input stem + preprocessor for every modality.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` value.
        """
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )
        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # NOTE(review): in_features is derived as channels * kernel_size, which
        # reproduces the previously hard-coded 48 (6 * 8) at the defaults;
        # confirm against IMUPreprocessor's patching before using other sizes.
        imu_channels = 6
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=imu_channels * imu_kernel_size,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )
        imu_preprocessor = IMUPreprocessor(
            img_size=[imu_channels, 2000],
            num_cls_tokens=1,
            kernel_size=imu_kernel_size,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            ModalityType.DEPTH: depth_preprocessor,
            ModalityType.THERMAL: thermal_preprocessor,
            ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Create one ``SimpleTransformer`` trunk per modality.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` value.
        """

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The trunk is wrapped in rearranges: "b l d" -> "l b d" on the way
            # in and back to "b l d" on the way out.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        # Vision and text use a fixed drop-path of 0.0 and no key/value bias;
        # only vision applies a LayerNorm ahead of the transformer.
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=depth_drop_path,
        )
        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=thermal_drop_path,
        )
        modality_trunks[ModalityType.IMU] = instantiate_trunk(
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=imu_drop_path,
        )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Create the projection head mapping each trunk to ``out_embed_dim``.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` value.
        """

        def first_token_head(embed_dim):
            # LayerNorm -> take element 0 -> bias-free linear projection.
            return nn.Sequential(
                nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6),
                SelectElement(index=0),
                nn.Linear(embed_dim, out_embed_dim, bias=False),
            )

        modality_heads = {}

        modality_heads[ModalityType.VISION] = first_token_head(vision_embed_dim)

        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = first_token_head(audio_embed_dim)
        modality_heads[ModalityType.DEPTH] = first_token_head(depth_embed_dim)
        modality_heads[ModalityType.THERMAL] = first_token_head(thermal_embed_dim)

        # The IMU head additionally applies dropout before the projection.
        modality_heads[ModalityType.IMU] = nn.Sequential(
            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Dropout(p=0.5),
            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Create the per-modality output postprocessors.

        Vision is only normalised along the last dimension; every other
        modality is normalised and then passed through
        ``LearnableLogitScaling`` (learnable for text, fixed for the rest).
        ``out_embed_dim`` is accepted for interface symmetry and is not used.

        Returns:
            ``nn.ModuleDict`` keyed by ``ModalityType`` value.
        """
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        )
        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )

        return nn.ModuleDict(modality_postprocessors)

    def set_modality_attention_capture(
        self,
        modality_key: str,
        enabled: bool,
        block_indices: Optional[Sequence[int]] = None,
    ) -> None:
        """
        Toggle attention-weight capture for one modality trunk (see
        ``SimpleTransformer.set_attention_capture`` in ``transformer.py``).

        After a forward pass with capture enabled on the last block, read weights via
        ``get_modality_attention_weights(modality_key)``.
        """
        self.modality_trunks[modality_key].set_attention_capture(
            enabled, block_indices=block_indices
        )

    def get_modality_attention_weights(
        self, modality_key: str, block_index: int = -1
    ) -> Optional[torch.Tensor]:
        """Return ``last_attn_weights`` from the given block's ``MultiheadAttention``, if any."""
        blk = self.modality_trunks[modality_key].blocks[block_index]
        attn = blk.attn
        if isinstance(attn, MultiheadAttention):
            return attn.last_attn_weights
        return None

    def forward(self, inputs):
        """Encode every modality present in ``inputs``.

        Args:
            inputs: dict mapping a ``ModalityType`` value to that modality's
                input tensor. Entries whose value is ``None`` are skipped.

        Returns:
            dict mapping each processed modality key to its postprocessed
            output. Inputs with ``ndim >= 5`` have their first two dimensions
            folded together before encoding; the result is then reshaped to
            ``(B, S, -1)`` and averaged over dimension 1.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Guard first: the original checked for None only *after* reading
            # ``.ndim``, so a None entry raised AttributeError instead of
            # being skipped.
            if modality_value is None:
                continue

            reduce_list = (
                modality_value.ndim >= 5
            )  # Audio and Video inputs consist of multiple clips
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            # The preprocessor returns a dict carrying keyword arguments for
            # the trunk ("trunk") and extra keyword arguments for the head
            # ("head").
            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )
            modality_value = self.modality_postprocessors[modality_key](
                modality_value
            )

            if reduce_list:
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs


def imagebind_huge(pretrained=False):
    """Build the "huge" ``ImageBindModel`` configuration.

    Args:
        pretrained: when true, load weights from
            ``.checkpoints/imagebind_huge.pth`` (relative to the current
            working directory), downloading the file first if it is missing.

    Returns:
        The constructed ``ImageBindModel``.
    """
    model = ImageBindModel(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )

    if pretrained:
        checkpoint_path = ".checkpoints/imagebind_huge.pth"
        if not os.path.exists(checkpoint_path):
            print(
                "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
            )
            os.makedirs(".checkpoints", exist_ok=True)
            torch.hub.download_url_to_file(
                "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
                checkpoint_path,
                progress=True,
            )

        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; load_state_dict copies the values into
        # the model's existing parameters.
        state_dict = torch.load(
            checkpoint_path, map_location="cpu", weights_only=True
        )
        model.load_state_dict(state_dict)

    return model