| import argparse |
| import os |
|
|
| import h5py |
| |
| from misc_functions import * |
| from torchvision.datasets import ImageNet |
| from tqdm import tqdm |
| from ViT_explanation_generator import LRP, Baselines |
| from ViT_LRP import vit_base_patch16_224 as vit_LRP |
| from ViT_new import vit_base_patch16_224 |
| from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP |
|
|
|
|
def normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize an image tensor channel-wise, in place.

    Computes ``(tensor - mean) / std`` for every channel, where the channel
    axis is the third one from the end. This covers a single ``(C, H, W)``
    image as well as batched ``(N, C, H, W)`` input.

    Args:
        tensor: Floating-point tensor with at least three dimensions. It is
            modified in place.
        mean: Per-channel means (one value per channel).
        std: Per-channel standard deviations (one value per channel).

    Returns:
        The same ``tensor`` object, after normalization.

    Raises:
        TypeError: If ``tensor`` is not a floating-point tensor.
        ValueError: If ``tensor`` has fewer than three dimensions or any
            entry of ``std`` is zero.
    """
    if not tensor.is_floating_point():
        # Integer tensors would truncate the statistics (0.5 -> 0) before use.
        raise TypeError(
            "normalize expects a floating-point tensor, got {}".format(tensor.dtype)
        )
    if tensor.dim() < 3:
        raise ValueError(
            "normalize expects a tensor shaped (..., C, H, W), got shape {}".format(
                tuple(tensor.shape)
            )
        )

    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError("std contains a zero; normalization would divide by zero")

    # Shape (C, 1, 1) broadcasts against (C, H, W), (N, C, H, W), ... alike.
    tensor.sub_(mean.view(-1, 1, 1)).div_(std.view(-1, 1, 1))
    return tensor
|
|
|
|
# Methods that have a generation branch in _generate_relevance below.
_SUPPORTED_METHODS = (
    "rollout",
    "lrp",
    "transformer_attribution",
    "full_lrp",
    "lrp_last_layer",
    "attn_last_layer",
    "attn_gradcam",
)


def _generate_relevance(args, data, index):
    """Run the explanation method named by ``args.method`` on one batch.

    Uses the module-level ``baselines``, ``lrp`` and ``orig_lrp`` generators.
    Returns a map reshaped to ``(batch, 1, 14, 14)``, or ``(batch, 1, 224, 224)``
    for ``full_lrp``.
    """
    batch = data.shape[0]
    method = args.method

    if method == "rollout":
        return baselines.generate_rollout(data, start_layer=1).reshape(
            batch, 1, 14, 14
        )
    if method == "lrp":
        return lrp.generate_LRP(data, start_layer=1, index=index).reshape(
            batch, 1, 14, 14
        )
    if method == "transformer_attribution":
        return lrp.generate_LRP(
            data, start_layer=1, method="grad", index=index
        ).reshape(batch, 1, 14, 14)
    if method == "full_lrp":
        return orig_lrp.generate_LRP(data, method="full", index=index).reshape(
            batch, 1, 224, 224
        )
    if method == "lrp_last_layer":
        return orig_lrp.generate_LRP(
            data, method="last_layer", is_ablation=args.is_ablation, index=index
        ).reshape(batch, 1, 14, 14)
    if method == "attn_last_layer":
        return lrp.generate_LRP(
            data, method="last_layer_attn", is_ablation=args.is_ablation
        ).reshape(batch, 1, 14, 14)
    if method == "attn_gradcam":
        return baselines.generate_cam_attn(data, index=index).reshape(
            batch, 1, 14, 14
        )
    raise ValueError("Unsupported method: {!r}".format(method))


def compute_saliency_and_save(args):
    """Compute saliency maps for every batch and append them to an HDF5 file.

    Iterates over the module-level ``sample_loader`` and writes three
    resizable datasets to ``<args.method_dir>/results.hdf5``:

    * ``vis``    -- min-max scaled saliency maps, ``(N, 1, 224, 224)`` float32
    * ``image``  -- the input images as loaded (before normalization),
      ``(N, 3, 224, 224)`` float32
    * ``target`` -- the labels yielded by the loader, ``(N,)`` int32

    Besides ``sample_loader`` this relies on the module-level names
    ``device``, ``baselines``, ``lrp`` and ``orig_lrp`` set up by the script
    entry point.

    Args:
        args: Namespace providing ``method_dir``, ``method``, ``vis_class``
            and ``is_ablation``.

    Raises:
        ValueError: If ``args.method`` has no generation branch. Previously
            this surfaced as a ``NameError`` after the output file had
            already been created.
    """
    if args.method not in _SUPPORTED_METHODS:
        raise ValueError(
            "Unsupported method {!r}; expected one of {}".format(
                args.method, ", ".join(_SUPPORTED_METHODS)
            )
        )

    with h5py.File(os.path.join(args.method_dir, "results.hdf5"), "a") as f:
        data_cam = f.create_dataset(
            "vis",
            (1, 1, 224, 224),
            maxshape=(None, 1, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_image = f.create_dataset(
            "image",
            (1, 3, 224, 224),
            maxshape=(None, 3, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_target = f.create_dataset(
            "target", (1,), maxshape=(None,), dtype=np.int32, compression="gzip"
        )

        # Number of samples stored so far. The datasets start with one
        # placeholder row, which the first batch overwrites.
        n_written = 0
        for data, target in tqdm(sample_loader):
            end = n_written + data.shape[0]
            rows = slice(n_written, end)

            data_cam.resize(end, axis=0)
            data_image.resize(end, axis=0)
            data_target.resize(end, axis=0)

            # Store the raw image first: normalize() below works in place.
            data_image[rows] = data.data.cpu().numpy()
            data_target[rows] = target.data.cpu().numpy()

            target = target.to(device)

            data = normalize(data)
            data = data.to(device)
            data.requires_grad_()

            # NOTE(review): vis_class == "index" (with args.class_id) is not
            # handled here, so it behaves like "top" -- confirm whether
            # class_id should be passed as `index`.
            index = target if args.vis_class == "target" else None

            Res = _generate_relevance(args, data, index)

            if args.method not in ("full_lrp", "input_grads"):
                # 14x14 patch maps are upsampled by 16 to the 224x224 image
                # grid. Use the script's device rather than assuming CUDA.
                Res = torch.nn.functional.interpolate(
                    Res, scale_factor=16, mode="bilinear"
                ).to(device)
            # Min-max scaling over the whole batch; a constant map gives NaN.
            Res = (Res - Res.min()) / (Res.max() - Res.min())

            data_cam[rows] = Res.data.cpu().numpy()
            n_written = end
|
|
|
|
def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used with argparse for this: ``bool("False")`` is
    ``True`` because any non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate ViT saliency maps for ImageNet validation images "
        "and save them to HDF5"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="")
    parser.add_argument(
        "--method",
        type=str,
        # The old default "grad_rollout" was not a valid choice and matched no
        # generation branch, so running without --method always crashed.
        default="rollout",
        choices=[
            "rollout",
            "lrp",
            "transformer_attribution",
            "full_lrp",
            "lrp_last_layer",
            "attn_last_layer",
            "attn_gradcam",
        ],
        help="",
    )
    parser.add_argument("--lmd", type=float, default=10, help="")
    parser.add_argument(
        "--vis-class",
        type=str,
        default="top",
        choices=["top", "target", "index"],
        help="",
    )
    parser.add_argument("--class-id", type=int, default=0, help="")
    parser.add_argument("--cls-agn", action="store_true", default=False, help="")
    parser.add_argument("--no-ia", action="store_true", default=False, help="")
    parser.add_argument("--no-fx", action="store_true", default=False, help="")
    parser.add_argument("--no-fgx", action="store_true", default=False, help="")
    parser.add_argument("--no-m", action="store_true", default=False, help="")
    parser.add_argument("--no-reg", action="store_true", default=False, help="")
    parser.add_argument("--is-ablation", type=str2bool, default=False, help="")
    parser.add_argument("--imagenet-validation-path", type=str, required=True, help="")
    args = parser.parse_args()

    # All outputs live under <script dir>/visualizations/<method>/...
    PATH = os.path.dirname(os.path.abspath(__file__)) + "/"
    if args.vis_class == "index":
        args.method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}_{}".format(
                args.method, args.vis_class, args.class_id
            ),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        args.method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}/{}".format(
                args.method, args.vis_class, ablation_fold
            ),
        )
    # Creates the intermediate "visualizations/<method>" directories as well.
    os.makedirs(args.method_dir, exist_ok=True)

    # Remove results from a previous run at the location that is actually
    # written to; the file is opened in append mode, so leftover datasets
    # would make create_dataset fail.
    try:
        os.remove(os.path.join(args.method_dir, "results.hdf5"))
    except FileNotFoundError:
        pass

    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Model used by the attention-based baselines (rollout, attention CAM).
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()
    baselines = Baselines(model)

    # LRP-enabled model.
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    lrp = LRP(model_LRP)

    # Second LRP model, used for the "full_lrp" and "lrp_last_layer" methods.
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    orig_lrp = LRP(model_orig_LRP)

    # Images are only resized and converted here; normalization is applied
    # later so the un-normalized image can be stored alongside its map.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    imagenet_ds = ImageNet(
        args.imagenet_validation_path, split="val", download=False, transform=transform
    )
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, shuffle=False, num_workers=4
    )

    compute_saliency_and_save(args)
|
|