| import sys |
|
|
| import cv2 |
| import numpy as np |
| import torch |
|
|
| from imagenet_class_indices import CLS2IDX |
|
|
| sys.path.append("Transformer-Explainability") |
|
|
|
|
| from baselines.ViT.ViT_explanation_generator import LRP, Baselines |
| from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP |
| from baselines.ViT.ViT_new import vit_base_patch16_224 as vit |
|
|
|
|
| |
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    Args:
        img: Image array; it is cast to ``float32`` and added element-wise to
            the colour-mapped mask, so it must broadcast against an
            ``(H, W, 3)`` array. Presumably already scaled to ``[0, 1]`` --
            confirm against callers.
        mask: 2-D saliency map. It is multiplied by 255 and cast to ``uint8``
            before colour mapping, so values are assumed to lie in ``[0, 1]``.

    Returns:
        The ``float32`` sum of heatmap and image, divided by its own maximum
        so that the largest value is 1.
    """
    # applyColorMap needs an 8-bit input, hence the scale-and-cast.
    mask_as_bytes = np.uint8(255 * mask)
    colored_mask = cv2.applyColorMap(mask_as_bytes, cv2.COLORMAP_JET)

    # Bring the heatmap back to the [0, 1] range before adding the image.
    overlay = np.float32(colored_mask) / 255 + np.float32(img)
    return overlay / np.max(overlay)
|
|
|
|
| |
# ViT variant from ``baselines.ViT.ViT_LRP``, wrapped by the LRP attribution
# generator. Switched to eval mode because it is only used for inference here.
model = vit_LRP(pretrained=True)
model.eval()
attribution_generator = LRP(model)

# ViT variant from ``baselines.ViT.ViT_new``, wrapped by the ``Baselines``
# generator that provides the non-LRP methods used in this file
# (``generate_cam_attn`` and ``generate_rollout``).
model_baseline = vit(pretrained=True)
model_baseline.eval()
baselines_generator = Baselines(model_baseline)
|
|
|
|
def _min_max_normalize(values):
    """Rescale a NumPy array linearly into ``[0, 1]``.

    A constant array has zero range; dividing by it would fill the result
    with NaN, so an all-zero array of the same shape and dtype is returned
    instead.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span


def generate_visualization(
    original_image, class_index=None, method="transformer_attribution", LRP=True
):
    """Render an attribution heatmap for ``original_image`` on top of the image.

    Args:
        original_image: Image tensor without a batch dimension. It is permuted
            with ``(1, 2, 0)`` for display, so it is assumed to be
            channels-first; the hard-coded 224x224 heatmap size below implies
            a 224x224 input -- confirm against callers.
        class_index: Target class forwarded as ``index`` to the LRP and
            ``"gradcam"`` generators. It is not passed to the rollout
            generator.
        method: Attribution method. With ``LRP=True`` it is forwarded to
            ``attribution_generator.generate_LRP``. With ``LRP=False``,
            ``"gradcam"`` selects ``generate_cam_attn`` and any other value
            selects ``generate_rollout``. ``"full"`` is treated as already
            being a 224x224 map; every other method is treated as a 14x14
            patch map.
        LRP: Use the LRP generator when true, the baseline generator
            otherwise. NOTE: inside this function the name shadows the
            imported ``LRP`` class; it is kept for caller compatibility.

    Returns:
        ``uint8`` NumPy array holding the overlay after an RGB-to-BGR channel
        swap.
    """
    batch = original_image.unsqueeze(0)
    if LRP:
        attribution = attribution_generator.generate_LRP(
            batch, method=method, index=class_index
        ).detach()
    elif method == "gradcam":
        attribution = baselines_generator.generate_cam_attn(
            batch, index=class_index
        ).detach()
    else:
        attribution = baselines_generator.generate_rollout(batch).detach()

    if method != "full":
        # Patch-level relevance: a 14x14 grid upsampled by 16 gives 224x224.
        attribution = attribution.reshape(1, 1, 14, 14)
        attribution = torch.nn.functional.interpolate(
            attribution, scale_factor=16, mode="bilinear"
        )
    else:
        attribution = attribution.reshape(1, 1, 224, 224)

    heatmap = _min_max_normalize(
        attribution.reshape(224, 224).detach().cpu().numpy()
    )

    # Channels-last layout for the NumPy/OpenCV overlay step.
    image = _min_max_normalize(
        original_image.permute(1, 2, 0).detach().cpu().numpy()
    )

    vis = show_cam_on_image(image, heatmap)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
|
|
|
|
def print_top_classes(predictions, **kwargs):
    """Print the highest-scoring classes of the first row of ``predictions``.

    For each class the index, the ``CLS2IDX`` name, the raw score and the
    softmax probability (in percent) are printed, with names padded to a
    common width so the columns line up.

    Args:
        predictions: 2-D tensor of class scores; softmax and top-k are taken
            over ``dim=1`` and only row 0 is reported.
        **kwargs: Accepted for call-site compatibility and ignored.
    """
    prob = torch.softmax(predictions, dim=1)

    # topk raises if k exceeds the size of the dimension, so clamp it for
    # models with fewer than five classes. The header text is unchanged.
    top_k = min(5, predictions.shape[1])
    class_indices = predictions.detach().topk(top_k, dim=1)[1][0].tolist()

    # Width of the longest class name, used to pad the name column.
    max_str_len = max((len(CLS2IDX[idx]) for idx in class_indices), default=0)

    print("Top 5 classes:")
    for cls_idx in class_indices:
        print(
            "\t{} : {}\t\tvalue = {:.3f}\t prob = {:.1f}%".format(
                cls_idx,
                CLS2IDX[cls_idx].ljust(max_str_len),
                predictions[0, cls_idx],
                100 * prob[0, cls_idx],
            )
        )
|
|