| | import math |
| | import types |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchvision.transforms import Compose, Resize, InterpolationMode |
| | import open_clip |
| | from open_clip.transformer import VisionTransformer |
| | from open_clip.timm_model import TimmModel |
| | from einops import rearrange |
| |
|
| | from .utils import ( |
| | hooked_attention_timm_forward, |
| | hooked_resblock_forward, |
| | hooked_attention_forward, |
| | hooked_resblock_timm_forward, |
| | hooked_attentional_pooler_timm_forward, |
| | vit_dynamic_size_forward, |
| | min_max, |
| | hooked_torch_multi_head_attention_forward, |
| | ) |
| |
|
| |
|
class LeWrapper(nn.Module):
    """
    Wrapper that equips an OpenCLIP model with LeGrad explainability while
    keeping all functionality of the original model available on the wrapper.
    """

    def __init__(self, model, layer_index=-2):
        super(LeWrapper, self).__init__()

        # Mirror every public attribute (weights, submodules, methods) of the
        # wrapped model onto this wrapper so it is a drop-in replacement.
        for attr_name in dir(model):
            if attr_name.startswith("__"):
                continue
            setattr(self, attr_name, getattr(model, attr_name))

        # Install the LeGrad hooks on the vision tower, starting at
        # `layer_index` (negative indices count from the last block).
        self._activate_hooks(layer_index=layer_index)
| |
|
    def _activate_hooks(self, layer_index):
        """
        Detect the vision-tower flavor and install the matching LeGrad hooks.

        Sets three attributes used by the compute_legrad_* methods:
        - ``self.patch_size``: side length of one ViT patch.
        - ``self.starting_depth``: absolute index of the first hooked block
          (a negative ``layer_index`` is converted relative to the number of
          transformer blocks).
        - ``self.model_type``: one of "clip", "coca", "timm_siglip", which
          drives dispatch in ``compute_legrad``.

        Raises:
            ValueError: if the vision tower is neither an OpenCLIP
                ``VisionTransformer`` nor a ``TimmModel``.
        """
        print("Activating necessary hooks and gradients ....")
        if isinstance(self.visual, VisionTransformer):
            # Swap in a forward that accepts arbitrary input resolutions.
            self.visual.forward = types.MethodType(
                vit_dynamic_size_forward, self.visual
            )
            self.patch_size = self.visual.patch_size[0]
            # Convert a negative layer index into an absolute block index.
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.transformer.resblocks) + layer_index
            )

            # OpenCLIP ViTs without an attentional pooler are plain CLIP;
            # with one they follow the CoCa architecture.
            if self.visual.attn_pool is None:
                self.model_type = "clip"
                self._activate_self_attention_hooks()
            else:
                self.model_type = "coca"
                self._activate_att_pool_hooks(layer_index=layer_index)

        elif isinstance(self.visual, TimmModel):
            # Relax timm's patch-embed constraints so the trunk accepts
            # dynamic image sizes (matching the OpenCLIP branch above).
            self.visual.trunk.dynamic_img_size = True
            self.visual.trunk.patch_embed.dynamic_img_size = True
            self.visual.trunk.patch_embed.strict_img_size = False
            self.visual.trunk.patch_embed.flatten = False
            self.visual.trunk.patch_embed.output_fmt = "NHWC"
            self.model_type = "timm_siglip"
            self.patch_size = self.visual.trunk.patch_embed.patch_size[0]
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.trunk.blocks) + layer_index
            )
            if (
                hasattr(self.visual.trunk, "attn_pool")
                and self.visual.trunk.attn_pool is not None
            ):
                self._activate_timm_attn_pool_hooks(layer_index=layer_index)
            else:
                self._activate_timm_self_attention_hooks()
        else:
            raise ValueError(
                "Model currently not supported, see legrad.list_pretrained() for a list of available models"
            )
        print("Hooks and gradients activated!")
| |
|
| | def _activate_self_attention_hooks(self): |
| | |
| | if isinstance(self.visual, VisionTransformer): |
| | blocks = self.visual.transformer.resblocks |
| | elif isinstance(self.visual, TimmModel): |
| | blocks = self.visual.trunk.blocks |
| | else: |
| | raise ValueError("Unsupported model type for self-attention hooks") |
| |
|
| | |
| | |
| | for name, param in self.named_parameters(): |
| | param.requires_grad = False |
| | if name.startswith("visual.trunk.blocks"): |
| | depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0]) |
| | if depth >= self.starting_depth: |
| | param.requires_grad = True |
| |
|
| | |
| | for layer in range(self.starting_depth, len(blocks)): |
| | blocks[layer].attn.forward = types.MethodType( |
| | hooked_attention_forward, blocks[layer].attn |
| | ) |
| | blocks[layer].forward = types.MethodType( |
| | hooked_resblock_forward, blocks[layer] |
| | ) |
| |
|
| | def _activate_timm_self_attention_hooks(self): |
| | |
| | blocks = self.visual.trunk.blocks |
| |
|
| | |
| | |
| | for name, param in self.named_parameters(): |
| | param.requires_grad = False |
| | if name.startswith("visual.trunk.blocks"): |
| | depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0]) |
| | if depth >= self.starting_depth: |
| | param.requires_grad = True |
| |
|
| | |
| | for layer in range(self.starting_depth, len(blocks)): |
| | blocks[layer].attn.forward = types.MethodType( |
| | hooked_attention_timm_forward, blocks[layer].attn |
| | ) |
| | blocks[layer].forward = types.MethodType( |
| | hooked_resblock_timm_forward, blocks[layer] |
| | ) |
| |
|
| | def _activate_att_pool_hooks(self, layer_index): |
| | |
| | |
| | for name, param in self.named_parameters(): |
| | param.requires_grad = False |
| | if name.startswith("visual.transformer.resblocks"): |
| | |
| | depth = int( |
| | name.split("visual.transformer.resblocks.")[-1].split(".")[0] |
| | ) |
| | if depth >= self.starting_depth: |
| | param.requires_grad = True |
| |
|
| | |
| | for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)): |
| | self.visual.transformer.resblocks[layer].forward = types.MethodType( |
| | hooked_resblock_forward, self.visual.transformer.resblocks[layer] |
| | ) |
| | |
| | self.visual.attn_pool.attn.forward = types.MethodType( |
| | hooked_torch_multi_head_attention_forward, self.visual.attn_pool.attn |
| | ) |
| |
|
| | def _activate_timm_attn_pool_hooks(self, layer_index): |
| | |
| | if ( |
| | not hasattr(self.visual.trunk, "attn_pool") |
| | or self.visual.trunk.attn_pool is None |
| | ): |
| | raise ValueError("Attentional pooling not found in TimmModel") |
| |
|
| | self.visual.trunk.attn_pool.forward = types.MethodType( |
| | hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool |
| | ) |
| | for block in self.visual.trunk.blocks: |
| | if hasattr(block, "attn"): |
| | block.attn.forward = types.MethodType( |
| | hooked_attention_forward, block.attn |
| | ) |
| |
|
| | |
| | for name, param in self.named_parameters(): |
| | param.requires_grad = False |
| | if name.startswith("visual.trunk.attn_pool"): |
| | param.requires_grad = True |
| | if name.startswith("visual.trunk.blocks"): |
| | |
| | depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0]) |
| | if depth >= self.starting_depth: |
| | param.requires_grad = True |
| |
|
| | |
| | for layer in range(self.starting_depth, len(self.visual.trunk.blocks)): |
| | self.visual.trunk.blocks[layer].forward = types.MethodType( |
| | hooked_resblock_timm_forward, self.visual.trunk.blocks[layer] |
| | ) |
| |
|
| | self.visual.trunk.attn_pool.forward = types.MethodType( |
| | hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool |
| | ) |
| |
|
| | def compute_legrad(self, text_embedding, image=None, apply_correction=True): |
| | if "clip" in self.model_type: |
| | return self.compute_legrad_clip(text_embedding, image) |
| | elif "siglip" in self.model_type: |
| | return self.compute_legrad_siglip( |
| | text_embedding, image, apply_correction=apply_correction |
| | ) |
| | elif "coca" in self.model_type: |
| | return self.compute_legrad_coca(text_embedding, image) |
| |
|
    def compute_legrad_clip(self, text_embedding, image=None):
        """
        LeGrad map for the plain-CLIP path: for each hooked block, project its
        intermediate features to the joint space, backpropagate the text-image
        similarity onto that block's attention map, and accumulate the
        ReLU'd gradients into a min-max normalized heatmap.

        NOTE(review): this method reads ``self.visual.trunk.blocks``,
        ``self.visual.trunk.norm`` and ``self.visual.head`` — timm-style
        attributes — yet model_type "clip" is only set for an OpenCLIP
        ``VisionTransformer`` (which exposes ``transformer.resblocks`` /
        ``ln_post`` / ``proj`` instead). This looks adapted from the timm
        variant and would fail on a real OpenCLIP ViT — verify against the
        upstream implementation.
        """
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # Forward pass populates feat_post_mlp / attention_map on the
            # hooked blocks.
            _ = self.encode_image(image)

        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())

        image_features_list = []

        for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
            # Intermediate block output captured by the hooked forward.
            intermediate_feat = blocks_list[layer].feat_post_mlp
            # Pool over tokens, then project into the joint embedding space.
            intermediate_feat = intermediate_feat.mean(dim=1)
            intermediate_feat = self.visual.head(
                self.visual.trunk.norm(intermediate_feat)
            )
            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
            image_features_list.append(intermediate_feat)

        # "-1" drops one prepended token (CLS), matching the [:, 1:] slice
        # below; assumes a square patch grid.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[1] - 1
        w = h = int(math.sqrt(num_tokens))

        # Accumulate one explanation map per hooked layer.
        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)
            # Diagonal selector so each prompt is scored against its own
            # similarity row; note .to(device) after requires_grad_ makes a
            # non-leaf copy, but gradients here are taken w.r.t. attn_map.
            one_hot = (
                F.one_hot(torch.arange(0, num_prompts))
                .float()
                .requires_grad_(True)
                .to(text_embedding.device)
            )
            one_hot = torch.sum(one_hot * sim)

            # Attention map recorded by the hooked attention forward.
            attn_map = blocks_list[self.starting_depth + layer].attn.attention_map

            # create_graph keeps the graph differentiable / reusable across
            # the remaining layers of the loop.
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            # Keep only positive influence.
            grad = torch.clamp(grad, min=0.0)

            # Average over heads and queries, then drop the CLS column.
            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)
            # Upsample from patch grid to pixel resolution.
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        # Normalize the accumulated map to [0, 1].
        accum_expl_map = min_max(accum_expl_map)
        return accum_expl_map
| |
|
    def compute_legrad_coca(self, text_embedding, image=None):
        """
        LeGrad map for the CoCa path: re-pool each hooked resblock's
        intermediate features through the attentional pooler, backpropagate
        the text-image similarity onto the pooler's attention maps, and
        accumulate the ReLU'd gradients into a min-max normalized heatmap.
        """
        if image is not None:
            # Forward pass populates feat_post_mlp on the hooked resblocks.
            _ = self.encode_image(image)

        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )

        image_features_list = []

        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            intermediate_feat = self.visual.transformer.resblocks[layer].feat_post_mlp
            # Sequence-first to batch-first; assumes feat_post_mlp is
            # (tokens, batch, dim) — TODO confirm against the hooked forward.
            intermediate_feat = intermediate_feat.permute(1, 0, 2)
            image_features_list.append(intermediate_feat)

        # shape[0] here (vs shape[1] in the clip path) matches the
        # sequence-first layout; "-1" drops the CLS token.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))

        # Accumulate one explanation map per hooked layer.
        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()
            # Re-pool the intermediate features; token 0 of the pooler output
            # is used as the image embedding.
            image_embedding = self.visual.attn_pool(img_feat)[:, 0]
            image_embedding = image_embedding @ self.visual.proj

            sim = text_embedding @ image_embedding.transpose(-1, -2)
            one_hot = torch.sum(sim)

            # Attention maps recorded by the hooked multi-head attention of
            # the pooler during the attn_pool call above.
            attn_map = self.visual.attn_pool.attn.attention_maps

            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            # Keep only positive influence.
            grad = torch.clamp(grad, min=0.0)

            # Average over the first dim (presumably heads — TODO confirm),
            # take query 0, drop the CLS column.
            image_relevance = grad.mean(dim=0)[0, 1:]
            expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
            # Upsample from patch grid to pixel resolution.
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        # Inline global min-max normalization; NOTE(review): the other paths
        # use the min_max() helper — for the 1x1xHxW map produced here the
        # result should be identical, confirm before unifying.
        accum_expl_map = (accum_expl_map - accum_expl_map.min()) / (
            accum_expl_map.max() - accum_expl_map.min()
        )
        return accum_expl_map
| |
|
| | def _init_empty_embedding(self): |
| | if not hasattr(self, "empty_embedding"): |
| | |
| | _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP") |
| | empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device) |
| | empty_embedding = self.encode_text(empty_text) |
| | empty_embedding = F.normalize(empty_embedding, dim=-1) |
| | self.empty_embedding = empty_embedding.t() |
| |
|
    def compute_legrad_siglip(
        self,
        text_embedding,
        image=None,
        apply_correction=True,
        correction_threshold=0.8,
    ):
        """
        LeGrad map for the timm-SigLIP path: re-pool each hooked trunk
        block's features through the attentional pooler, backpropagate the
        text-image similarity onto the pooler's attention probabilities, and
        accumulate the ReLU'd gradients into a heatmap.

        When ``apply_correction`` is set, the same procedure is run against a
        cached "empty" prompt embedding and regions that the empty prompt
        also lights up (above ``correction_threshold`` after min-max) are
        zeroed out before the final normalization.
        """
        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
        if image is not None:
            # Forward pass populates feat_post_mlp on the hooked blocks.
            _ = self.encode_image(image)

        image_features_list = []
        for blk in blocks_list[self.starting_depth :]:
            intermediate_feat = blk.feat_post_mlp
            image_features_list.append(intermediate_feat)

        # Unlike the clip/coca paths there is no "-1": SigLIP ViTs carry no
        # CLS token; assumes a square patch grid.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
        w = h = int(math.sqrt(num_tokens))

        if apply_correction:
            # Ensure self.empty_embedding exists (cached after first call).
            self._init_empty_embedding()
            accum_expl_map_empty = 0

        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.zero_grad()
            # Re-pool the intermediate features through the hooked pooler,
            # which records attn_probs during this call.
            pooled_feat = self.visual.trunk.attn_pool(img_feat)
            pooled_feat = F.normalize(pooled_feat, dim=-1)

            sim = text_embedding @ pooled_feat.transpose(-1, -2)
            one_hot = torch.sum(sim)
            grad = torch.autograd.grad(
                one_hot,
                [self.visual.trunk.attn_pool.attn_probs],
                retain_graph=True,
                create_graph=True,
            )[0]
            # Keep only positive influence.
            grad = torch.clamp(grad, min=0.0)

            # Average over dim 1 (presumably heads — TODO confirm), then take
            # index 0 (presumably the single pooler query).
            image_relevance = grad.mean(dim=1)[:, 0]
            expl_map = rearrange(image_relevance, "b (w h) -> b 1 w h", w=w, h=h)
            accum_expl_map += expl_map

            if apply_correction:
                # Same gradient computation against the cached empty-prompt
                # embedding (stored transposed by _init_empty_embedding).
                sim_empty = pooled_feat @ self.empty_embedding
                one_hot_empty = torch.sum(sim_empty)
                grad_empty = torch.autograd.grad(
                    one_hot_empty,
                    [self.visual.trunk.attn_pool.attn_probs],
                    retain_graph=True,
                    create_graph=True,
                )[0]
                grad_empty = torch.clamp(grad_empty, min=0.0)

                image_relevance_empty = grad_empty.mean(dim=1)[:, 0]
                expl_map_empty = rearrange(
                    image_relevance_empty, "b (w h) -> b 1 w h", w=w, h=h
                )
                accum_expl_map_empty += expl_map_empty

        if apply_correction:
            # Suppress regions that the empty prompt also activates strongly.
            heatmap_empty = min_max(accum_expl_map_empty)
            accum_expl_map[heatmap_empty > correction_threshold] = 0

        # Normalize, then upsample from patch grid to pixel resolution.
        Res = min_max(accum_expl_map)
        Res = F.interpolate(Res, scale_factor=self.patch_size, mode="bilinear")

        return Res
| |
|
| |
|
class LePreprocess(nn.Module):
    """
    Drop-in replacement for an OpenCLIP preprocessing pipeline that resizes
    inputs to an arbitrary square ``image_size`` before applying the last
    three transforms of the original pipeline.
    """

    def __init__(self, preprocess, image_size):
        super(LePreprocess, self).__init__()
        # Force a square resize, then reuse the tail of the original pipeline
        # (conversion / normalization steps), skipping its original crop.
        resize = Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        self.transform = Compose([resize, *preprocess.transforms[-3:]])

    def forward(self, image):
        """Apply the adjusted preprocessing pipeline to ``image``."""
        return self.transform(image)
| |
|