"""
LeGrad for ImageBind — same spirit as ``legrad.wrapper.LeWrapper`` / ``legrad_api.ipynb``:
hook residual blocks + PyTorch ``nn.MultiheadAttention``, then
``grad(sum(text · vision))`` w.r.t. attention probabilities.
Requires the ``legrad`` package (``pip install -e /path/to/LeGrad`` or PYTHONPATH).
Vision: heatmap over image patches (CLS query → patch keys).
Text: relevance vector over context positions (EOS query row → all keys).
"""
from __future__ import annotations
import math
import types
from typing import List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from imagebind.models.imagebind_model import ImageBindModel, ModalityType
def _import_legrad_utils():
try:
from legrad.utils import hooked_torch_multi_head_attention_forward, min_max
return hooked_torch_multi_head_attention_forward, min_max
except ImportError as e: # pragma: no cover
raise ImportError(
"ImageBind LeGrad needs the `legrad` package. Install with "
"`pip install -e <path-to-LeGrad>` or add LeGrad to PYTHONPATH."
) from e
def hooked_imagebind_block_forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
"""Store features after attention and after MLP (ImageBind ``BlockWithMasking``)."""
if self.layer_scale_type is None:
x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
self.feat_post_attn = x
x = x + self.drop_path(self.mlp(self.norm_2(x)))
self.feat_post_mlp = x
else:
x = (
x
+ self.drop_path(self.attn(self.norm_1(x), attn_mask))
* self.layer_scale_gamma1
)
self.feat_post_attn = x
x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
self.feat_post_mlp = x
return x
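# Note: the `feat_post_mlp` cached above is what `_vision_embed_from_layer` /
# `_text_embed_from_layer` later push through the modality head, so intermediate
# blocks can be scored against the other modality's embedding without re-running
# the trunk.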
def _make_hooked_imagebind_mha_forward(hooked_torch_mha_forward):
def hooked_imagebind_mha_forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
"""Adapter: LeGrad hooked MHA expects ``(q,k,v,...)``; ImageBind calls ``(x, attn_mask)``."""
out, _ = hooked_torch_mha_forward(
self,
x,
x,
x,
key_padding_mask=None,
need_weights=True,
attn_mask=attn_mask,
)
return out
return hooked_imagebind_mha_forward
class ImageBindLeWrapper(nn.Module):
"""
Thin wrapper around ``ImageBindModel`` for LeGrad (vision and/or text branches).
Mirrors ``LeWrapper`` from ``legrad/wrapper.py``: copies public attributes/methods from
the base model, patches transformer blocks and attention with hooks, and provides
``compute_legrad_*`` helpers similar to ``compute_legrad_coca`` / ``compute_legrad_clip``.
"""
def __init__(
self,
model: ImageBindModel,
layer_index: int = -2,
trunk_key: str = ModalityType.VISION,
):
super().__init__()
for attr in dir(model):
if not attr.startswith("__"):
setattr(self, attr, getattr(model, attr))
self._legrad_trunk_key = trunk_key
hooked_torch_mha_forward, self._min_max = _import_legrad_utils()
self._hooked_mha_fn = _make_hooked_imagebind_mha_forward(hooked_torch_mha_forward)
self._activate_hooks(layer_index=layer_index, trunk_key=trunk_key)
def _trunk(self, key: Optional[str] = None):
key = key or self._legrad_trunk_key
return self.modality_trunks[key]
def _activate_hooks(self, layer_index: int, trunk_key: str) -> None:
trunk = self._trunk(trunk_key)
n_blocks = len(trunk.blocks)
self.starting_depth = (
layer_index if layer_index >= 0 else n_blocks + layer_index
)
self.starting_depth = max(0, min(self.starting_depth, n_blocks - 1))
prefix = f"modality_trunks.{trunk_key}.blocks"
for name, param in self.named_parameters():
param.requires_grad = False
if name.startswith(prefix):
depth = int(name.split(f"{prefix}.")[-1].split(".")[0])
if depth >= self.starting_depth:
param.requires_grad = True
for layer in range(self.starting_depth, n_blocks):
blk = trunk.blocks[layer]
blk.forward = types.MethodType(hooked_imagebind_block_forward, blk)
blk.attn.forward = types.MethodType(self._hooked_mha_fn, blk.attn)
print(
f"LeGrad (ImageBind): hooks on `{trunk_key}` blocks "
f"[{self.starting_depth}, {n_blocks - 1}] — gradients enabled from block "
f"{self.starting_depth} onward."
)
def _encode_vision_trunk(self, vision: torch.Tensor) -> torch.Tensor:
p = self.modality_preprocessors[ModalityType.VISION](vision=vision)
return self.modality_trunks[ModalityType.VISION](**p["trunk"])
def _encode_text_trunk(self, text: torch.Tensor) -> torch.Tensor:
p = self.modality_preprocessors[ModalityType.TEXT](text=text)
self._text_head_kwargs = dict(p.get("head", {}))
return self.modality_trunks[ModalityType.TEXT](**p["trunk"])
def _vision_embed_from_layer(self, layer_idx: int) -> torch.Tensor:
x_bld = (
self._trunk(ModalityType.VISION).blocks[layer_idx].feat_post_mlp.permute(
1, 0, 2
)
)
h = self.modality_heads[ModalityType.VISION](x_bld)
return self.modality_postprocessors[ModalityType.VISION](h)
def _text_embed_from_layer(self, layer_idx: int) -> torch.Tensor:
x_bld = (
self._trunk(ModalityType.TEXT).blocks[layer_idx].feat_post_mlp.permute(
1, 0, 2
)
)
h = self.modality_heads[ModalityType.TEXT](x_bld, **self._text_head_kwargs)
return self.modality_postprocessors[ModalityType.TEXT](h)
@staticmethod
def _cls_to_patch_relevance(
attn_grad: torch.Tensor, batch_size: int, num_heads: int
) -> torch.Tensor:
"""attn_grad: (B*H, L, L) -> (B, num_patches) CLS row, patch columns."""
L = attn_grad.shape[-1]
g = attn_grad.view(batch_size, num_heads, L, L).clamp(min=0.0)
g = g.mean(dim=1)
return g[:, 0, 1:]
@staticmethod
def _relevance_to_spatial_map(
relevance: torch.Tensor, patches_layout: Sequence[int], out_hw: tuple = (224, 224)
) -> torch.Tensor:
"""relevance: (num_patches,) → (1,1,H,W) upsampled."""
pl = tuple(patches_layout)
if len(pl) == 3:
t, h, w = pl
g = relevance.reshape(t, h, w).float()
g = g.mean(dim=0) if t > 1 else g[0]
elif len(pl) == 2:
g = relevance.reshape(pl[0], pl[1]).float()
else:
side = int(math.sqrt(relevance.numel()))
g = relevance.reshape(side, side).float()
m = g.unsqueeze(0).unsqueeze(0)
return F.interpolate(m, size=out_hw, mode="bilinear", align_corners=False)
def compute_legrad_imagebind(
self,
text_embedding: torch.Tensor,
vision: Optional[torch.Tensor] = None,
normalize: bool = True,
) -> torch.Tensor:
"""
Accumulate LeGrad maps over vision blocks ``[starting_depth, n_blocks)`` (CLIP-style).
``text_embedding``: (B, D) same ordering as ``vision`` batch, L2-normalized like
``model({TEXT: ...})`` outputs.
"""
if vision is not None:
_ = self._encode_vision_trunk(vision)
trunk = self._trunk(ModalityType.VISION)
blocks: List = list(trunk.blocks)
layout = self.modality_preprocessors[ModalityType.VISION].patches_layout
num_heads = blocks[0].attn.num_heads
bsz = text_embedding.shape[0]
accum = 0.0
for layer in range(self.starting_depth, len(blocks)):
self.zero_grad(set_to_none=True)
vision_emb = self._vision_embed_from_layer(layer)
one_hot = (text_embedding * vision_emb).sum()
attn_map = blocks[layer].attn.attention_maps
grad = torch.autograd.grad(
one_hot, [attn_map], retain_graph=True, create_graph=True
)[0]
rel = self._cls_to_patch_relevance(grad, bsz, num_heads)
expl = self._relevance_to_spatial_map(rel[0], layout)
accum = accum + expl
if normalize:
accum = self._min_max(accum)
return accum
def compute_legrad_imagebind_one_layer(
self,
text_embedding: torch.Tensor,
vision: Optional[torch.Tensor] = None,
layer_idx: Optional[int] = None,
normalize: bool = True,
) -> torch.Tensor:
"""Single vision block (``legrad_api.compute_legrad_coca_one_layer`` style)."""
if vision is not None:
_ = self._encode_vision_trunk(vision)
trunk = self._trunk(ModalityType.VISION)
blocks = trunk.blocks
n_blocks = len(blocks)
if layer_idx is None:
layer_idx = n_blocks - 1
if layer_idx < self.starting_depth or layer_idx >= n_blocks:
raise ValueError(
f"layer_idx must be in [{self.starting_depth}, {n_blocks - 1}], got {layer_idx}"
)
layout = self.modality_preprocessors[ModalityType.VISION].patches_layout
num_heads = blocks[layer_idx].attn.num_heads
bsz = text_embedding.shape[0]
self.zero_grad(set_to_none=True)
vision_emb = self._vision_embed_from_layer(layer_idx)
one_hot = (text_embedding * vision_emb).sum()
attn_map = blocks[layer_idx].attn.attention_maps
grad = torch.autograd.grad(
one_hot, [attn_map], retain_graph=True, create_graph=True
)[0]
rel = self._cls_to_patch_relevance(grad, bsz, num_heads)
expl = self._relevance_to_spatial_map(rel[0], layout)
if normalize:
expl = (expl - expl.min()) / (expl.max() - expl.min() + 1e-8)
return expl
def compute_legrad_text_imagebind(
self,
vision_embedding: torch.Tensor,
text: torch.Tensor,
layer_idx: Optional[int] = None,
normalize: bool = True,
) -> torch.Tensor:
"""
Text-branch LeGrad: gradient of ``sum(vision · text)`` w.r.t. attention at one layer.
``vision_embedding``: (B, D) detached reference (e.g. from ``model({VISION})``).
``text``: token ids (B, L). Returns (B, L_ctx) relevance over token positions for EOS
query row (uses ``seq_len`` from the text preprocessor).
"""
if self._legrad_trunk_key != ModalityType.TEXT:
raise RuntimeError(
"compute_legrad_text_imagebind requires wrapping with trunk_key=TEXT. "
"Instantiate ImageBindLeWrapper(model, layer_index=..., trunk_key=ModalityType.TEXT)."
)
_ = self._encode_text_trunk(text)
trunk = self._trunk(ModalityType.TEXT)
blocks = trunk.blocks
n_blocks = len(blocks)
if layer_idx is None:
layer_idx = n_blocks - 1
seq_len = self._text_head_kwargs["seq_len"]
num_heads = blocks[layer_idx].attn.num_heads
bsz = vision_embedding.shape[0]
self.zero_grad(set_to_none=True)
text_emb = self._text_embed_from_layer(layer_idx)
one_hot = (vision_embedding.detach() * text_emb).sum()
attn_map = blocks[layer_idx].attn.attention_maps
grad = torch.autograd.grad(
one_hot, [attn_map], retain_graph=True, create_graph=True
)[0]
# (B*H, L, L) → EOS query → key importances
L = grad.shape[-1]
g = grad.view(bsz, num_heads, L, L).clamp(min=0.0).mean(dim=1)
idx = torch.arange(bsz, device=g.device)
eos_rel = g[idx, seq_len, :]
if normalize:
eos_rel = self._min_max(eos_rel)
return eos_rel
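    # Hedged text-branch sketch: wrap a second instance with `trunk_key=TEXT`;
    # `model`, `images`, `tokens` are the same illustrative names as in the
    # module-level sketch at the top of this file.
    #   text_wrapper = ImageBindLeWrapper(
    #       model, layer_index=-2, trunk_key=ModalityType.TEXT
    #   )
    #   with torch.no_grad():
    #       vis_emb = model({ModalityType.VISION: images})[ModalityType.VISION]
    #   token_rel = text_wrapper.compute_legrad_text_imagebind(vis_emb, tokens)  # (B, context_len)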
def compute_legrad(
self,
text_embedding: torch.Tensor,
vision: Optional[torch.Tensor] = None,
trunk: str = "vision",
) -> torch.Tensor:
"""Dispatch: ``trunk=='vision'`` → ``compute_legrad_imagebind`` (multi-layer sum)."""
if trunk in ("vision", ModalityType.VISION):
return self.compute_legrad_imagebind(text_embedding, vision=vision)
raise ValueError(f"Unknown trunk {trunk!r}; use compute_legrad_* methods directly.")