""" LeGrad for ImageBind — same spirit as ``legrad.wrapper.LeWrapper`` / ``legrad_api.ipynb``: hook residual blocks + PyTorch ``nn.MultiheadAttention``, then ``grad(sum(text · vision))`` w.r.t. attention probabilities. Requires the ``legrad`` package (``pip install -e /path/to/LeGrad`` or PYTHONPATH). Vision: heatmap over image patches (CLS query → patch keys). Text: relevance vector over context positions (EOS query row → all keys). """ from __future__ import annotations import math import types from typing import List, Optional, Sequence import torch import torch.nn as nn import torch.nn.functional as F from imagebind.models.imagebind_model import ImageBindModel, ModalityType def _import_legrad_utils(): try: from legrad.utils import hooked_torch_multi_head_attention_forward, min_max return hooked_torch_multi_head_attention_forward, min_max except ImportError as e: # pragma: no cover raise ImportError( "ImageBind LeGrad needs the `legrad` package. Install with " "`pip install -e ` or add LeGrad to PYTHONPATH." ) from e def hooked_imagebind_block_forward(self, x: torch.Tensor, attn_mask: torch.Tensor): """Store features after attention and after MLP (ImageBind ``BlockWithMasking``).""" if self.layer_scale_type is None: x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) self.feat_post_attn = x x = x + self.drop_path(self.mlp(self.norm_2(x))) self.feat_post_mlp = x else: x = ( x + self.drop_path(self.attn(self.norm_1(x), attn_mask)) * self.layer_scale_gamma1 ) self.feat_post_attn = x x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2 self.feat_post_mlp = x return x def _make_hooked_imagebind_mha_forward(hooked_torch_mha_forward): def hooked_imagebind_mha_forward(self, x: torch.Tensor, attn_mask: torch.Tensor): """Adapter: LeGrad hooked MHA expects ``(q,k,v,...)``; ImageBind calls ``(x, attn_mask)``.""" out, _ = hooked_torch_mha_forward( self, x, x, x, key_padding_mask=None, need_weights=True, attn_mask=attn_mask, ) return out return hooked_imagebind_mha_forward class ImageBindLeWrapper(nn.Module): """ Thin wrapper around ``ImageBindModel`` for LeGrad (vision and/or text branches). Mirrors ``LeWrapper`` from ``legrad/wrapper.py``: copies public attributes/methods from the base model, patches transformer blocks and attention with hooks, and provides ``compute_legrad_*`` helpers similar to ``compute_legrad_coca`` / ``compute_legrad_clip``. 
""" def __init__( self, model: ImageBindModel, layer_index: int = -2, trunk_key: str = ModalityType.VISION, ): super().__init__() for attr in dir(model): if not attr.startswith("__"): setattr(self, attr, getattr(model, attr)) self._legrad_trunk_key = trunk_key hooked_torch_mha_forward, self._min_max = _import_legrad_utils() self._hooked_mha_fn = _make_hooked_imagebind_mha_forward(hooked_torch_mha_forward) self._activate_hooks(layer_index=layer_index, trunk_key=trunk_key) def _trunk(self, key: Optional[str] = None): key = key or self._legrad_trunk_key return self.modality_trunks[key] def _activate_hooks(self, layer_index: int, trunk_key: str) -> None: trunk = self._trunk(trunk_key) n_blocks = len(trunk.blocks) self.starting_depth = ( layer_index if layer_index >= 0 else n_blocks + layer_index ) self.starting_depth = max(0, min(self.starting_depth, n_blocks - 1)) prefix = f"modality_trunks.{trunk_key}.blocks" for name, param in self.named_parameters(): param.requires_grad = False if name.startswith(prefix): depth = int(name.split(f"{prefix}.")[-1].split(".")[0]) if depth >= self.starting_depth: param.requires_grad = True for layer in range(self.starting_depth, n_blocks): blk = trunk.blocks[layer] blk.forward = types.MethodType(hooked_imagebind_block_forward, blk) blk.attn.forward = types.MethodType(self._hooked_mha_fn, blk.attn) print( f"LeGrad (ImageBind): hooks on `{trunk_key}` blocks " f"[{self.starting_depth}, {n_blocks - 1}] — gradients enabled from block " f"{self.starting_depth} onward." ) def _encode_vision_trunk(self, vision: torch.Tensor) -> torch.Tensor: p = self.modality_preprocessors[ModalityType.VISION](vision=vision) return self.modality_trunks[ModalityType.VISION](**p["trunk"]) def _encode_text_trunk(self, text: torch.Tensor) -> torch.Tensor: p = self.modality_preprocessors[ModalityType.TEXT](text=text) self._text_head_kwargs = dict(p.get("head", {})) return self.modality_trunks[ModalityType.TEXT](**p["trunk"]) def _vision_embed_from_layer(self, layer_idx: int) -> torch.Tensor: x_bld = ( self._trunk(ModalityType.VISION).blocks[layer_idx].feat_post_mlp.permute( 1, 0, 2 ) ) h = self.modality_heads[ModalityType.VISION](x_bld) return self.modality_postprocessors[ModalityType.VISION](h) def _text_embed_from_layer(self, layer_idx: int) -> torch.Tensor: x_bld = ( self._trunk(ModalityType.TEXT).blocks[layer_idx].feat_post_mlp.permute( 1, 0, 2 ) ) h = self.modality_heads[ModalityType.TEXT](x_bld, **self._text_head_kwargs) return self.modality_postprocessors[ModalityType.TEXT](h) @staticmethod def _cls_to_patch_relevance( attn_grad: torch.Tensor, batch_size: int, num_heads: int ) -> torch.Tensor: """attn_grad: (B*H, L, L) -> (B, num_patches) CLS row, patch columns.""" L = attn_grad.shape[-1] g = attn_grad.view(batch_size, num_heads, L, L).clamp(min=0.0) g = g.mean(dim=1) return g[:, 0, 1:] @staticmethod def _relevance_to_spatial_map( relevance: torch.Tensor, patches_layout: Sequence[int], out_hw: tuple = (224, 224) ) -> torch.Tensor: """relevance: (num_patches,) → (1,1,H,W) upsampled.""" pl = tuple(patches_layout) if len(pl) == 3: t, h, w = pl g = relevance.reshape(t, h, w).float() g = g.mean(dim=0) if t > 1 else g[0] elif len(pl) == 2: g = relevance.reshape(pl[0], pl[1]).float() else: side = int(math.sqrt(relevance.numel())) g = relevance.reshape(side, side).float() m = g.unsqueeze(0).unsqueeze(0) return F.interpolate(m, size=out_hw, mode="bilinear", align_corners=False) def compute_legrad_imagebind( self, text_embedding: torch.Tensor, vision: Optional[torch.Tensor] = None, normalize: 
    def compute_legrad_imagebind(
        self,
        text_embedding: torch.Tensor,
        vision: Optional[torch.Tensor] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Accumulate LeGrad maps over vision blocks ``[starting_depth, n_blocks)`` (CLIP-style).

        ``text_embedding``: (B, D), same ordering as the ``vision`` batch, L2-normalized
        like ``model({TEXT: ...})`` outputs.
        """
        if vision is not None:
            _ = self._encode_vision_trunk(vision)

        trunk = self._trunk(ModalityType.VISION)
        blocks: List = list(trunk.blocks)
        layout = self.modality_preprocessors[ModalityType.VISION].patches_layout
        num_heads = blocks[0].attn.num_heads
        bsz = text_embedding.shape[0]

        accum = 0.0
        for layer in range(self.starting_depth, len(blocks)):
            self.zero_grad(set_to_none=True)
            vision_emb = self._vision_embed_from_layer(layer)
            one_hot = (text_embedding * vision_emb).sum()
            attn_map = blocks[layer].attn.attention_maps
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            rel = self._cls_to_patch_relevance(grad, bsz, num_heads)
            expl = self._relevance_to_spatial_map(rel[0], layout)
            accum = accum + expl

        if normalize:
            accum = self._min_max(accum)
        return accum

    def compute_legrad_imagebind_one_layer(
        self,
        text_embedding: torch.Tensor,
        vision: Optional[torch.Tensor] = None,
        layer_idx: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """Single vision block (``legrad_api.compute_legrad_coca_one_layer`` style)."""
        if vision is not None:
            _ = self._encode_vision_trunk(vision)

        trunk = self._trunk(ModalityType.VISION)
        blocks = trunk.blocks
        n_blocks = len(blocks)
        if layer_idx is None:
            layer_idx = n_blocks - 1
        if layer_idx < self.starting_depth or layer_idx >= n_blocks:
            raise ValueError(
                f"layer_idx must be in [{self.starting_depth}, {n_blocks - 1}], got {layer_idx}"
            )

        layout = self.modality_preprocessors[ModalityType.VISION].patches_layout
        num_heads = blocks[layer_idx].attn.num_heads
        bsz = text_embedding.shape[0]

        self.zero_grad(set_to_none=True)
        vision_emb = self._vision_embed_from_layer(layer_idx)
        one_hot = (text_embedding * vision_emb).sum()
        attn_map = blocks[layer_idx].attn.attention_maps
        grad = torch.autograd.grad(
            one_hot, [attn_map], retain_graph=True, create_graph=True
        )[0]
        rel = self._cls_to_patch_relevance(grad, bsz, num_heads)
        expl = self._relevance_to_spatial_map(rel[0], layout)
        if normalize:
            expl = (expl - expl.min()) / (expl.max() - expl.min() + 1e-8)
        return expl

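    # Note on the text branch below: ImageBind's text preprocessor passes a per-sample
    # ``seq_len`` (index of the EOS token) to the head, and the wrapper reuses it here,
    # so ``g[arange(B), seq_len, :]`` reads each sample's EOS query row, i.e. the same
    # row the ``SelectEOSAndProject`` head pools into the text embedding.
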
    def compute_legrad_text_imagebind(
        self,
        vision_embedding: torch.Tensor,
        text: torch.Tensor,
        layer_idx: Optional[int] = None,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Text-branch LeGrad: gradient of ``sum(vision · text)`` w.r.t. attention at one layer.

        ``vision_embedding``: (B, D) detached reference (e.g. from ``model({VISION: ...})``).
        ``text``: token ids (B, L). Returns (B, L_ctx) relevance over token positions
        for the EOS query row (uses ``seq_len`` from the text preprocessor).
        """
        if self._legrad_trunk_key != ModalityType.TEXT:
            raise RuntimeError(
                "compute_legrad_text_imagebind requires wrapping with trunk_key=TEXT. "
                "Instantiate ImageBindLeWrapper(model, layer_index=..., trunk_key=ModalityType.TEXT)."
            )

        _ = self._encode_text_trunk(text)
        trunk = self._trunk(ModalityType.TEXT)
        blocks = trunk.blocks
        n_blocks = len(blocks)
        if layer_idx is None:
            layer_idx = n_blocks - 1

        seq_len = self._text_head_kwargs["seq_len"]
        num_heads = blocks[layer_idx].attn.num_heads
        bsz = vision_embedding.shape[0]

        self.zero_grad(set_to_none=True)
        text_emb = self._text_embed_from_layer(layer_idx)
        one_hot = (vision_embedding.detach() * text_emb).sum()
        attn_map = blocks[layer_idx].attn.attention_maps
        grad = torch.autograd.grad(
            one_hot, [attn_map], retain_graph=True, create_graph=True
        )[0]

        # (B*H, L, L) → EOS query row → key importances
        L = grad.shape[-1]
        g = grad.view(bsz, num_heads, L, L).clamp(min=0.0).mean(dim=1)
        idx = torch.arange(bsz, device=g.device)
        eos_rel = g[idx, seq_len, :]

        if normalize:
            eos_rel = self._min_max(eos_rel)
        return eos_rel

    def compute_legrad(
        self,
        text_embedding: torch.Tensor,
        vision: Optional[torch.Tensor] = None,
        trunk: str = "vision",
    ) -> torch.Tensor:
        """Dispatch: ``trunk == 'vision'`` → ``compute_legrad_imagebind`` (multi-layer sum)."""
        if trunk in ("vision", ModalityType.VISION):
            return self.compute_legrad_imagebind(text_embedding, vision=vision)
        raise ValueError(f"Unknown trunk {trunk!r}; use the compute_legrad_* methods directly.")


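# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so importing this module has
# no side effects). It assumes the upstream ImageBind API
# (``imagebind_model.imagebind_huge``, ``data.load_and_transform_*``); the image
# path and prompt are placeholders.
#
#   from imagebind import data
#   from imagebind.models import imagebind_model
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = imagebind_model.imagebind_huge(pretrained=True).to(device).eval()
#
#   vision = data.load_and_transform_vision_data(["image.jpg"], device)
#   text = data.load_and_transform_text(["a dog"], device)
#
#   # Vision heatmap: wrap the vision trunk, use a frozen text embedding as target.
#   wrapper = ImageBindLeWrapper(model, layer_index=-2, trunk_key=ModalityType.VISION)
#   with torch.no_grad():
#       text_emb = model({ModalityType.TEXT: text})[ModalityType.TEXT]
#   heatmap = wrapper.compute_legrad_imagebind(text_emb, vision=vision)   # (1, 1, 224, 224)
#
#   # Token relevance: wrap the text trunk, use a frozen vision embedding as target.
#   text_wrapper = ImageBindLeWrapper(model, layer_index=-2, trunk_key=ModalityType.TEXT)
#   with torch.no_grad():
#       vision_emb = model({ModalityType.VISION: vision})[ModalityType.VISION]
#   token_rel = text_wrapper.compute_legrad_text_imagebind(vision_emb, text)   # (1, L_ctx)
# ---------------------------------------------------------------------------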