XAI / ImageBind /imagebind /models /imagebind_model.py
haiphamcse's picture
Upload folder using huggingface_hub
6a00010 verified
#!/usr/bin/env python3
# Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from functools import partial
from types import SimpleNamespace
from typing import Optional, Sequence
import torch
import torch.nn as nn
from imagebind.models.helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize,
SelectElement, SelectEOSAndProject)
from imagebind.models.multimodal_preprocessors import (AudioPreprocessor,
IMUPreprocessor, PadIm2Video,
PatchEmbedGeneric,
RGBDTPreprocessor,
SpatioTemporalPosEmbeddingHelper,
TextPreprocessor,
ThermalPreprocessor)
from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
# Namespace of the string keys that identify each supported modality.
# Every attribute is the upper-cased form of its own value, e.g.
# ``ModalityType.VISION == "vision"``.
ModalityType = SimpleNamespace(
    **{
        modality.upper(): modality
        for modality in ("vision", "text", "audio", "thermal", "depth", "imu")
    }
)
class ImageBindModel(nn.Module):
    """Multi-modal encoder with one processing pipeline per modality.

    For every modality in ``ModalityType`` the constructor builds four
    sub-modules, each stored in an ``nn.ModuleDict`` keyed by modality name:

    * ``modality_preprocessors`` -- turns the raw input into the keyword
      arguments consumed by the trunk and the head,
    * ``modality_trunks`` -- a ``SimpleTransformer`` encoder,
    * ``modality_heads`` -- layer norm, token selection and a bias-free linear
      projection to ``out_embed_dim``,
    * ``modality_postprocessors`` -- ``Normalize(dim=-1)`` followed, for every
      modality except vision, by ``LearnableLogitScaling``.

    ``forward`` applies these four stages in that order to each modality it
    receives.
    """

    def __init__(
        self,
        video_frames=2,
        kernel_size=(2, 14, 14),
        audio_kernel_size=16,
        audio_stride=10,
        out_embed_dim=768,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_num_mel_bins=128,
        audio_target_len=204,
        audio_drop_path=0.1,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        depth_embed_dim=384,
        depth_kernel_size=16,
        depth_num_blocks=12,
        depth_num_heads=8,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_kernel_size=8,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Build the per-modality preprocessors, trunks, heads and postprocessors.

        Args:
            video_frames: Temporal size in the vision preprocessor's ``img_size``.
            kernel_size: Kernel and stride of the vision ``nn.Conv3d`` patch stem.
            audio_kernel_size: Kernel size of the audio ``nn.Conv2d`` stem.
            audio_stride: Stride of the audio ``nn.Conv2d`` stem.
            out_embed_dim: Output width of every modality head.
            vision_embed_dim, text_embed_dim, audio_embed_dim, depth_embed_dim,
            thermal_embed_dim, imu_embed_dim: Token width used by that
                modality's stem, trunk and head input.
            *_num_blocks, *_num_heads: Number of transformer blocks and
                attention heads of that modality's trunk.
            audio_num_mel_bins, audio_target_len: Second and third entries of
                the audio preprocessor's ``img_size``.
            depth_kernel_size, thermal_kernel_size: Kernel and stride of the
                respective ``nn.Conv2d`` stem.
            audio_drop_path, depth_drop_path, thermal_drop_path, imu_drop_path:
                ``drop_path_rate`` of that trunk (vision and text always use 0.0).
            imu_kernel_size: NOTE(review): accepted but currently not forwarded
                anywhere; the IMU preprocessor is built with a fixed kernel
                size of 8 and a fixed stem input width of 48.
        """
        super().__init__()

        self.modality_preprocessors = self._create_modality_preprocessors(
            video_frames,
            vision_embed_dim,
            kernel_size,
            text_embed_dim,
            audio_embed_dim,
            audio_kernel_size,
            audio_stride,
            audio_num_mel_bins,
            audio_target_len,
            depth_embed_dim,
            depth_kernel_size,
            thermal_embed_dim,
            thermal_kernel_size,
            imu_embed_dim,
        )

        self.modality_trunks = self._create_modality_trunks(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            audio_drop_path,
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            depth_drop_path,
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            thermal_drop_path,
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            imu_drop_path,
        )

        self.modality_heads = self._create_modality_heads(
            out_embed_dim,
            vision_embed_dim,
            text_embed_dim,
            audio_embed_dim,
            depth_embed_dim,
            thermal_embed_dim,
            imu_embed_dim,
        )

        self.modality_postprocessors = self._create_modality_postprocessors(
            out_embed_dim
        )

    def _create_modality_preprocessors(
        self,
        video_frames=2,
        vision_embed_dim=1024,
        kernel_size=(2, 14, 14),
        text_embed_dim=768,
        audio_embed_dim=768,
        audio_kernel_size=16,
        audio_stride=10,
        audio_num_mel_bins=128,
        audio_target_len=204,
        depth_embed_dim=768,
        depth_kernel_size=16,
        thermal_embed_dim=768,
        thermal_kernel_size=16,
        imu_embed_dim=512,
    ):
        """Create the input preprocessor of every modality.

        Each preprocessor wraps a patch/token stem (``PatchEmbedGeneric`` for
        all modalities except text) and is given a learnable
        ``SpatioTemporalPosEmbeddingHelper`` factory where applicable.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its
            preprocessor.
        """
        # Vision: images are padded to a 2-frame clip, then patchified by a
        # non-overlapping 3D convolution (stride == kernel size).
        rgbt_stem = PatchEmbedGeneric(
            proj_stem=[
                PadIm2Video(pad_type="repeat", ntimes=2),
                nn.Conv3d(
                    in_channels=3,
                    kernel_size=kernel_size,
                    out_channels=vision_embed_dim,
                    stride=kernel_size,
                    bias=False,
                ),
            ]
        )
        rgbt_preprocessor = RGBDTPreprocessor(
            img_size=[3, video_frames, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=rgbt_stem,
            depth_stem=None,
        )

        text_preprocessor = TextPreprocessor(
            context_length=77,
            vocab_size=49408,
            embed_dim=text_embed_dim,
            causal_masking=True,
        )

        # Audio: single-channel input; stride may differ from the kernel size.
        audio_stem = PatchEmbedGeneric(
            proj_stem=[
                nn.Conv2d(
                    in_channels=1,
                    kernel_size=audio_kernel_size,
                    stride=audio_stride,
                    out_channels=audio_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
        )
        audio_preprocessor = AudioPreprocessor(
            img_size=[1, audio_num_mel_bins, audio_target_len],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            audio_stem=audio_stem,
        )

        # Depth: single-channel input, non-overlapping 2D patches.
        depth_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=depth_kernel_size,
                    in_channels=1,
                    out_channels=depth_embed_dim,
                    stride=depth_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
        )
        depth_preprocessor = RGBDTPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            rgbt_stem=None,
            depth_stem=depth_stem,
        )

        # Thermal: single-channel input, non-overlapping 2D patches.
        thermal_stem = PatchEmbedGeneric(
            [
                nn.Conv2d(
                    kernel_size=thermal_kernel_size,
                    in_channels=1,
                    out_channels=thermal_embed_dim,
                    stride=thermal_kernel_size,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
        )
        thermal_preprocessor = ThermalPreprocessor(
            img_size=[1, 224, 224],
            num_cls_tokens=1,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            thermal_stem=thermal_stem,
        )

        # IMU: linear stem. NOTE(review): in_features=48 and kernel_size=8 are
        # fixed here (48 == 6 * 8 for img_size [6, 2000]); presumably the two
        # must change together -- confirm against IMUPreprocessor.
        imu_stem = PatchEmbedGeneric(
            [
                nn.Linear(
                    in_features=48,
                    out_features=imu_embed_dim,
                    bias=False,
                ),
            ],
            norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
        )
        imu_preprocessor = IMUPreprocessor(
            img_size=[6, 2000],
            num_cls_tokens=1,
            kernel_size=8,
            embed_dim=imu_embed_dim,
            pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
            imu_stem=imu_stem,
        )

        modality_preprocessors = {
            ModalityType.VISION: rgbt_preprocessor,
            ModalityType.TEXT: text_preprocessor,
            ModalityType.AUDIO: audio_preprocessor,
            ModalityType.DEPTH: depth_preprocessor,
            ModalityType.THERMAL: thermal_preprocessor,
            ModalityType.IMU: imu_preprocessor,
        }

        return nn.ModuleDict(modality_preprocessors)

    def _create_modality_trunks(
        self,
        vision_embed_dim=1024,
        vision_num_blocks=24,
        vision_num_heads=16,
        text_embed_dim=768,
        text_num_blocks=12,
        text_num_heads=12,
        audio_embed_dim=768,
        audio_num_blocks=12,
        audio_num_heads=12,
        audio_drop_path=0.0,
        depth_embed_dim=768,
        depth_num_blocks=12,
        depth_num_heads=12,
        depth_drop_path=0.0,
        thermal_embed_dim=768,
        thermal_num_blocks=12,
        thermal_num_heads=12,
        thermal_drop_path=0.0,
        imu_embed_dim=512,
        imu_num_blocks=6,
        imu_num_heads=8,
        imu_drop_path=0.7,
    ):
        """Create the transformer trunk of every modality.

        Vision is the only trunk with a leading ``nn.LayerNorm``; vision and
        text use ``add_bias_kv=False`` and a drop-path rate of 0.0, while the
        remaining modalities use ``add_bias_kv=True`` and their own rate.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to a
            ``SimpleTransformer``.
        """

        def instantiate_trunk(
            embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
        ):
            # The rearranges switch between "b l d" outside the transformer
            # and "l b d" inside it.
            return SimpleTransformer(
                embed_dim=embed_dim,
                num_blocks=num_blocks,
                ffn_dropout_rate=0.0,
                drop_path_rate=drop_path,
                attn_target=partial(
                    MultiheadAttention,
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    bias=True,
                    add_bias_kv=add_bias_kv,
                ),
                pre_transformer_layer=nn.Sequential(
                    nn.LayerNorm(embed_dim, eps=1e-6)
                    if pre_transformer_ln
                    else nn.Identity(),
                    EinOpsRearrange("b l d -> l b d"),
                ),
                post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
            )

        modality_trunks = {}
        modality_trunks[ModalityType.VISION] = instantiate_trunk(
            vision_embed_dim,
            vision_num_blocks,
            vision_num_heads,
            pre_transformer_ln=True,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.TEXT] = instantiate_trunk(
            text_embed_dim,
            text_num_blocks,
            text_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=False,
            drop_path=0.0,
        )
        modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
            audio_embed_dim,
            audio_num_blocks,
            audio_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=audio_drop_path,
        )
        modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
            depth_embed_dim,
            depth_num_blocks,
            depth_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=depth_drop_path,
        )
        modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
            thermal_embed_dim,
            thermal_num_blocks,
            thermal_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=thermal_drop_path,
        )
        modality_trunks[ModalityType.IMU] = instantiate_trunk(
            imu_embed_dim,
            imu_num_blocks,
            imu_num_heads,
            pre_transformer_ln=False,
            add_bias_kv=True,
            drop_path=imu_drop_path,
        )

        return nn.ModuleDict(modality_trunks)

    def _create_modality_heads(
        self,
        out_embed_dim,
        vision_embed_dim,
        text_embed_dim,
        audio_embed_dim,
        depth_embed_dim,
        thermal_embed_dim,
        imu_embed_dim,
    ):
        """Create the projection head of every modality.

        All heads apply ``nn.LayerNorm`` and a bias-free ``nn.Linear`` to
        ``out_embed_dim``. Text selects its token through
        ``SelectEOSAndProject``; the others take element 0 via
        ``SelectElement``. The IMU head additionally applies
        ``nn.Dropout(p=0.5)`` before the projection.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its head.
        """
        modality_heads = {}

        modality_heads[ModalityType.VISION] = nn.Sequential(
            nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
            proj=nn.Sequential(
                nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
                nn.Linear(text_embed_dim, out_embed_dim, bias=False),
            )
        )

        modality_heads[ModalityType.AUDIO] = nn.Sequential(
            nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.DEPTH] = nn.Sequential(
            nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.THERMAL] = nn.Sequential(
            nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
        )

        modality_heads[ModalityType.IMU] = nn.Sequential(
            nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
            SelectElement(index=0),
            nn.Dropout(p=0.5),
            nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
        )

        return nn.ModuleDict(modality_heads)

    def _create_modality_postprocessors(self, out_embed_dim):
        """Create the embedding postprocessor of every modality.

        Every modality is passed through ``Normalize(dim=-1)``. Text adds a
        learnable ``LearnableLogitScaling``; audio, depth, thermal and IMU add
        a fixed one initialised to 20.0, 5.0, 10.0 and 5.0 respectively.
        ``out_embed_dim`` is accepted but not used here.

        Returns:
            ``nn.ModuleDict`` mapping each ``ModalityType`` value to its
            postprocessor.
        """
        modality_postprocessors = {}

        modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
        modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
            Normalize(dim=-1), LearnableLogitScaling(learnable=True)
        )
        modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
        )
        modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )
        modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
        )
        modality_postprocessors[ModalityType.IMU] = nn.Sequential(
            Normalize(dim=-1),
            LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
        )

        return nn.ModuleDict(modality_postprocessors)

    def set_modality_attention_capture(
        self,
        modality_key: str,
        enabled: bool,
        block_indices: Optional[Sequence[int]] = None,
    ) -> None:
        """Toggle attention-weight capture for one modality trunk.

        Delegates to ``set_attention_capture`` on
        ``self.modality_trunks[modality_key]`` (see
        ``SimpleTransformer.set_attention_capture`` in ``transformer.py``).
        After a forward pass with capture enabled on the last block, read the
        weights via ``get_modality_attention_weights(modality_key)``.

        Args:
            modality_key: Key of the trunk in ``self.modality_trunks``.
            enabled: Whether capture should be switched on or off.
            block_indices: Forwarded unchanged as ``block_indices``.
        """
        self.modality_trunks[modality_key].set_attention_capture(
            enabled, block_indices=block_indices
        )

    def get_modality_attention_weights(
        self, modality_key: str, block_index: int = -1
    ) -> Optional[torch.Tensor]:
        """Return ``last_attn_weights`` from the given block's ``MultiheadAttention``, if any.

        Args:
            modality_key: Key of the trunk in ``self.modality_trunks``.
            block_index: Index into the trunk's ``blocks``; defaults to the
                last block.

        Returns:
            The block's ``attn.last_attn_weights`` when ``attn`` is a
            ``MultiheadAttention``, otherwise ``None``.
        """
        blk = self.modality_trunks[modality_key].blocks[block_index]
        attn = blk.attn
        if isinstance(attn, MultiheadAttention):
            return attn.last_attn_weights
        return None

    def forward(self, inputs):
        """Embed every modality present in ``inputs``.

        Args:
            inputs: Mapping from modality key to that modality's input tensor.
                Entries whose value is ``None`` are skipped.

        Returns:
            Dict mapping each non-``None`` modality key to the output of its
            postprocessor. Inputs with five or more dimensions are treated as
            ``(B, S, ...)`` stacks of clips: the clips are folded into the
            batch dimension for encoding and the ``S`` resulting embeddings
            are averaged per sample.
        """
        outputs = {}
        for modality_key, modality_value in inputs.items():
            # Skip absent modalities *before* touching tensor attributes;
            # checking after `.ndim` would raise AttributeError on None.
            if modality_value is None:
                continue

            reduce_list = (
                modality_value.ndim >= 5
            )  # Audio and Video inputs consist of multiple clips
            if reduce_list:
                B, S = modality_value.shape[:2]
                modality_value = modality_value.reshape(
                    B * S, *modality_value.shape[2:]
                )

            # The preprocessor is called with the modality name as keyword and
            # returns separate kwargs for the trunk and for the head.
            modality_value = self.modality_preprocessors[modality_key](
                **{modality_key: modality_value}
            )
            trunk_inputs = modality_value["trunk"]
            head_inputs = modality_value["head"]
            modality_value = self.modality_trunks[modality_key](**trunk_inputs)
            modality_value = self.modality_heads[modality_key](
                modality_value, **head_inputs
            )
            modality_value = self.modality_postprocessors[modality_key](
                modality_value
            )

            if reduce_list:
                # Undo the clip folding and average the per-clip embeddings.
                modality_value = modality_value.reshape(B, S, -1)
                modality_value = modality_value.mean(dim=1)

            outputs[modality_key] = modality_value

        return outputs
def imagebind_huge(pretrained=False):
    """Construct the "huge" ``ImageBindModel`` configuration.

    Args:
        pretrained: When true, load weights from
            ``.checkpoints/imagebind_huge.pth`` (relative to the current
            working directory), downloading the file first if it does not
            exist yet.

    Returns:
        The constructed model, with the checkpoint loaded when requested.
    """
    checkpoint_dir = ".checkpoints"
    checkpoint_path = ".checkpoints/imagebind_huge.pth"
    checkpoint_url = "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth"

    # Overrides relative to the ImageBindModel defaults.
    huge_config = dict(
        vision_embed_dim=1280,
        vision_num_blocks=32,
        vision_num_heads=16,
        text_embed_dim=1024,
        text_num_blocks=24,
        text_num_heads=16,
        out_embed_dim=1024,
        audio_drop_path=0.1,
        imu_drop_path=0.7,
    )
    model = ImageBindModel(**huge_config)

    if not pretrained:
        return model

    checkpoint_missing = not os.path.exists(checkpoint_path)
    if checkpoint_missing:
        print(
            "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
        )
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.hub.download_url_to_file(
            checkpoint_url,
            checkpoint_path,
            progress=True,
        )

    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(checkpoint_path, weights_only=True)
    model.load_state_dict(state_dict)
    return model