| """ |
| LeGrad for ImageBind — same spirit as ``legrad.wrapper.LeWrapper`` / ``legrad_api.ipynb``: |
| hook residual blocks + PyTorch ``nn.MultiheadAttention``, then |
| ``grad(sum(text · vision))`` w.r.t. attention probabilities. |
| |
| Requires the ``legrad`` package (``pip install -e /path/to/LeGrad`` or PYTHONPATH). |
| |
| Vision: heatmap over image patches (CLS query → patch keys). |
| Text: relevance vector over context positions (EOS query row → all keys). |
| """ |
|
|
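# Typical flow (sketch only; assumes the `imagebind` package from
# facebookresearch/ImageBind and a pretrained checkpoint, see the __main__ demo below):
#   base = imagebind_model.imagebind_huge(pretrained=True).eval()
#   model = ImageBindLeWrapper(base, layer_index=-2)
#   heatmap = model.compute_legrad(text_embedding, vision=vision_tensor)  # (1, 1, 224, 224)
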
from __future__ import annotations

import math
import types
from typing import List, Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from imagebind.models.imagebind_model import ImageBindModel, ModalityType


def _import_legrad_utils():
    try:
        from legrad.utils import hooked_torch_multi_head_attention_forward, min_max

        return hooked_torch_multi_head_attention_forward, min_max
    except ImportError as e:
        raise ImportError(
            "ImageBind LeGrad needs the `legrad` package. Install with "
            "`pip install -e <path-to-LeGrad>` or add LeGrad to PYTHONPATH."
        ) from e


def hooked_imagebind_block_forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
    """Store features after attention and after the MLP (ImageBind ``BlockWithMasking``)."""
    if self.layer_scale_type is None:
        x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
        self.feat_post_attn = x
        x = x + self.drop_path(self.mlp(self.norm_2(x)))
        self.feat_post_mlp = x
    else:
        x = (
            x
            + self.drop_path(self.attn(self.norm_1(x), attn_mask))
            * self.layer_scale_gamma1
        )
        self.feat_post_attn = x
        x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
        self.feat_post_mlp = x
    return x


def _make_hooked_imagebind_mha_forward(hooked_torch_mha_forward):
    def hooked_imagebind_mha_forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Adapter: LeGrad's hooked MHA expects ``(q, k, v, ...)``; ImageBind calls ``(x, attn_mask)``."""
        # ``need_weights=True`` makes the hooked forward keep the attention
        # probabilities around (as ``self.attention_maps``) for autograd.
        out, _ = hooked_torch_mha_forward(
            self,
            x,
            x,
            x,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=attn_mask,
        )
        return out

    return hooked_imagebind_mha_forward


class ImageBindLeWrapper(nn.Module):
    """
    Thin wrapper around ``ImageBindModel`` for LeGrad (vision and/or text branches).

    Mirrors ``LeWrapper`` from ``legrad/wrapper.py``: copies public attributes/methods from
    the base model, patches transformer blocks and attention with hooks, and provides
    ``compute_legrad_*`` helpers similar to ``compute_legrad_coca`` / ``compute_legrad_clip``.
    """

    def __init__(
        self,
        model: ImageBindModel,
        layer_index: int = -2,
        trunk_key: str = ModalityType.VISION,
    ):
        super().__init__()
        # Mirror LeWrapper: expose the base model's public attributes/methods
        # on the wrapper so it can stand in for the original model.
        for attr in dir(model):
            if not attr.startswith("__"):
                setattr(self, attr, getattr(model, attr))

        self._legrad_trunk_key = trunk_key
        hooked_torch_mha_forward, self._min_max = _import_legrad_utils()
        self._hooked_mha_fn = _make_hooked_imagebind_mha_forward(hooked_torch_mha_forward)
        self._activate_hooks(layer_index=layer_index, trunk_key=trunk_key)

    def _trunk(self, key: Optional[str] = None):
        key = key or self._legrad_trunk_key
        return self.modality_trunks[key]

    def _activate_hooks(self, layer_index: int, trunk_key: str) -> None:
        trunk = self._trunk(trunk_key)
        n_blocks = len(trunk.blocks)
        self.starting_depth = (
            layer_index if layer_index >= 0 else n_blocks + layer_index
        )
        self.starting_depth = max(0, min(self.starting_depth, n_blocks - 1))

        # Freeze everything, then re-enable gradients only for the hooked blocks.
        prefix = f"modality_trunks.{trunk_key}.blocks"
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith(prefix):
                depth = int(name.split(f"{prefix}.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # Patch block and attention forwards in place so intermediate features
        # and attention probabilities are stored during the forward pass.
        for layer in range(self.starting_depth, n_blocks):
            blk = trunk.blocks[layer]
            blk.forward = types.MethodType(hooked_imagebind_block_forward, blk)
            blk.attn.forward = types.MethodType(self._hooked_mha_fn, blk.attn)

        print(
            f"LeGrad (ImageBind): hooks on `{trunk_key}` blocks "
            f"[{self.starting_depth}, {n_blocks - 1}] — gradients enabled from block "
            f"{self.starting_depth} onward."
        )

    def _encode_vision_trunk(self, vision: torch.Tensor) -> torch.Tensor:
        p = self.modality_preprocessors[ModalityType.VISION](vision=vision)
        return self.modality_trunks[ModalityType.VISION](**p["trunk"])

    def _encode_text_trunk(self, text: torch.Tensor) -> torch.Tensor:
        p = self.modality_preprocessors[ModalityType.TEXT](text=text)
        # Keep the head kwargs (notably ``seq_len``, the per-sample EOS position)
        # so ``_text_embed_from_layer`` can replay the head on hooked features.
        self._text_head_kwargs = dict(p.get("head", {}))
        return self.modality_trunks[ModalityType.TEXT](**p["trunk"])

    def _vision_embed_from_layer(self, layer_idx: int) -> torch.Tensor:
        # Inside the trunk, features are sequence-first (L, B, D); permute back
        # to (B, L, D) before replaying the head and postprocessor.
        feats = self._trunk(ModalityType.VISION).blocks[layer_idx].feat_post_mlp
        x_bld = feats.permute(1, 0, 2)
        h = self.modality_heads[ModalityType.VISION](x_bld)
        return self.modality_postprocessors[ModalityType.VISION](h)

    def _text_embed_from_layer(self, layer_idx: int) -> torch.Tensor:
        feats = self._trunk(ModalityType.TEXT).blocks[layer_idx].feat_post_mlp
        x_bld = feats.permute(1, 0, 2)
        h = self.modality_heads[ModalityType.TEXT](x_bld, **self._text_head_kwargs)
        return self.modality_postprocessors[ModalityType.TEXT](h)

    @staticmethod
    def _cls_to_patch_relevance(
        attn_grad: torch.Tensor, batch_size: int, num_heads: int
    ) -> torch.Tensor:
        """attn_grad: (B*H, L, L) -> (B, num_patches): CLS query row, patch key columns."""
        L = attn_grad.shape[-1]
        g = attn_grad.view(batch_size, num_heads, L, L).clamp(min=0.0)
        g = g.mean(dim=1)  # average over heads
        return g[:, 0, 1:]  # row 0 = CLS query; drop the CLS key column

    @staticmethod
    def _relevance_to_spatial_map(
        relevance: torch.Tensor, patches_layout: Sequence[int], out_hw: tuple = (224, 224)
    ) -> torch.Tensor:
        """relevance: (num_patches,) → (1, 1, H, W), bilinearly upsampled."""
        pl = tuple(patches_layout)
        if len(pl) == 3:
            t, h, w = pl
            g = relevance.reshape(t, h, w).float()
            g = g.mean(dim=0) if t > 1 else g[0]
        elif len(pl) == 2:
            g = relevance.reshape(pl[0], pl[1]).float()
        else:
            side = int(math.sqrt(relevance.numel()))
            g = relevance.reshape(side, side).float()
        m = g.unsqueeze(0).unsqueeze(0)
        return F.interpolate(m, size=out_hw, mode="bilinear", align_corners=False)

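    # Shape walkthrough for the two helpers above (illustrative numbers, not tied
    # to a specific ImageBind config; e.g. B=1, H=16 heads, L=257 tokens):
    #   attention_maps / grad         : (B*H, L, L)
    #   _cls_to_patch_relevance(...)  : (B, L - 1)        # heads averaged, CLS query row
    #   _relevance_to_spatial_map(..) : (1, 1, 224, 224)  # patch grid, bilinear upsample
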
    def compute_legrad_imagebind(
        self,
        text_embedding: torch.Tensor,
        vision: Optional[torch.Tensor] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Accumulate LeGrad maps over vision blocks ``[starting_depth, n_blocks)`` (CLIP-style).

        ``text_embedding``: (B, D), same ordering as the ``vision`` batch, L2-normalized
        like the ``model({TEXT: ...})`` outputs. Only the first sample's map is returned.
        """
        if vision is not None:
            _ = self._encode_vision_trunk(vision)

        trunk = self._trunk(ModalityType.VISION)
        blocks: List = list(trunk.blocks)
        layout = self.modality_preprocessors[ModalityType.VISION].patches_layout
        num_heads = blocks[0].attn.num_heads
        bsz = text_embedding.shape[0]

        accum = 0.0
        for layer in range(self.starting_depth, len(blocks)):
            self.zero_grad(set_to_none=True)
            vision_emb = self._vision_embed_from_layer(layer)
            one_hot = (text_embedding * vision_emb).sum()
            attn_map = blocks[layer].attn.attention_maps
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            rel = self._cls_to_patch_relevance(grad, bsz, num_heads)
            # Note: only batch element 0 is turned into a spatial map.
            expl = self._relevance_to_spatial_map(rel[0], layout)
            accum = accum + expl

        if normalize:
            accum = self._min_max(accum)
        return accum

    def compute_legrad_imagebind_one_layer(
        self,
        text_embedding: torch.Tensor,
        vision: Optional[torch.Tensor] = None,
        layer_idx: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Single vision block (``legrad_api.compute_legrad_coca_one_layer`` style)."""
        if vision is not None:
            _ = self._encode_vision_trunk(vision)

        trunk = self._trunk(ModalityType.VISION)
        blocks = trunk.blocks
        n_blocks = len(blocks)
        if layer_idx is None:
            layer_idx = n_blocks - 1
        if layer_idx < self.starting_depth or layer_idx >= n_blocks:
            raise ValueError(
                f"layer_idx must be in [{self.starting_depth}, {n_blocks - 1}], got {layer_idx}"
            )

        layout = self.modality_preprocessors[ModalityType.VISION].patches_layout
        num_heads = blocks[layer_idx].attn.num_heads
        bsz = text_embedding.shape[0]

        self.zero_grad(set_to_none=True)
        vision_emb = self._vision_embed_from_layer(layer_idx)
        one_hot = (text_embedding * vision_emb).sum()
        attn_map = blocks[layer_idx].attn.attention_maps
        grad = torch.autograd.grad(
            one_hot, [attn_map], retain_graph=True, create_graph=True
        )[0]
        rel = self._cls_to_patch_relevance(grad, bsz, num_heads)
        expl = self._relevance_to_spatial_map(rel[0], layout)
        if normalize:
            expl = (expl - expl.min()) / (expl.max() - expl.min() + 1e-8)
        return expl

    def compute_legrad_text_imagebind(
        self,
        vision_embedding: torch.Tensor,
        text: torch.Tensor,
        layer_idx: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Text-branch LeGrad: gradient of ``sum(vision · text)`` w.r.t. attention at one layer.

        ``vision_embedding``: (B, D) detached reference (e.g. from ``model({VISION: ...})``).
        ``text``: token ids (B, L). Returns (B, L_ctx) relevance over token positions for the
        EOS query row (uses ``seq_len`` from the text preprocessor).
        """
        if self._legrad_trunk_key != ModalityType.TEXT:
            raise RuntimeError(
                "compute_legrad_text_imagebind requires wrapping with trunk_key=TEXT. "
                "Instantiate ImageBindLeWrapper(model, layer_index=..., trunk_key=ModalityType.TEXT)."
            )

        _ = self._encode_text_trunk(text)

        trunk = self._trunk(ModalityType.TEXT)
        blocks = trunk.blocks
        n_blocks = len(blocks)
        if layer_idx is None:
            layer_idx = n_blocks - 1
        # Attention maps only exist on hooked blocks, so validate the layer index.
        if layer_idx < self.starting_depth or layer_idx >= n_blocks:
            raise ValueError(
                f"layer_idx must be in [{self.starting_depth}, {n_blocks - 1}], got {layer_idx}"
            )
        seq_len = self._text_head_kwargs["seq_len"]
        num_heads = blocks[layer_idx].attn.num_heads
        # The attention maps come from the text trunk, so take the batch size from
        # ``text`` (it must match ``vision_embedding``'s batch in any case).
        bsz = text.shape[0]

        self.zero_grad(set_to_none=True)
        text_emb = self._text_embed_from_layer(layer_idx)
        one_hot = (vision_embedding.detach() * text_emb).sum()
        attn_map = blocks[layer_idx].attn.attention_maps
        grad = torch.autograd.grad(
            one_hot, [attn_map], retain_graph=True, create_graph=True
        )[0]

        L = grad.shape[-1]
        g = grad.view(bsz, num_heads, L, L).clamp(min=0.0).mean(dim=1)
        idx = torch.arange(bsz, device=g.device)
        eos_rel = g[idx, seq_len, :]  # per-sample EOS query row
        if normalize:
            eos_rel = self._min_max(eos_rel)
        return eos_rel

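    # Sketch (hypothetical tensors): with a TEXT-hooked wrapper ``w`` and token ids
    # ``tokens`` of shape (B, 77) (ImageBind tokenizes text into a CLIP-style
    # 77-position context), ``w.compute_legrad_text_imagebind(img_emb, tokens)``
    # returns (B, 77); entry i scores how much key position i feeds the EOS query.
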
    def compute_legrad(
        self,
        text_embedding: torch.Tensor,
        vision: Optional[torch.Tensor] = None,
        trunk: str = "vision",
    ) -> torch.Tensor:
        """Dispatch: ``trunk == 'vision'`` → ``compute_legrad_imagebind`` (multi-layer sum)."""
        if trunk in ("vision", ModalityType.VISION):
            return self.compute_legrad_imagebind(text_embedding, vision=vision)
        raise ValueError(f"Unknown trunk {trunk!r}; use the compute_legrad_* methods directly.")
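

# Minimal end-to-end sketch, guarded so importing this module stays side-effect free.
# Assumptions (not pinned by this file): the `imagebind` package from
# facebookresearch/ImageBind with its `imagebind.data` helpers, a downloadable
# imagebind_huge checkpoint, and a local image at ``assets/dog.jpg`` (hypothetical path).
if __name__ == "__main__":
    from imagebind import data
    from imagebind.models import imagebind_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    base = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(["a dog"], device),
        ModalityType.VISION: data.load_and_transform_vision_data(["assets/dog.jpg"], device),
    }
    with torch.no_grad():
        embeddings = base(inputs)

    # Vision branch: heatmap over image patches for the text query.
    model = ImageBindLeWrapper(base, layer_index=-2)
    heatmap = model.compute_legrad(
        text_embedding=embeddings[ModalityType.TEXT],
        vision=inputs[ModalityType.VISION],
    )
    print("vision heatmap:", tuple(heatmap.shape))  # (1, 1, 224, 224)

    # Text branch: relevance over token positions (re-hooks the TEXT trunk).
    text_wrapper = ImageBindLeWrapper(base, layer_index=-2, trunk_key=ModalityType.TEXT)
    rel = text_wrapper.compute_legrad_text_imagebind(
        vision_embedding=embeddings[ModalityType.VISION],
        text=inputs[ModalityType.TEXT],
    )
    print("text relevance:", tuple(rel.shape))  # (B, context_length)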