| | import argparse |
| | import os |
| |
|
| | import h5py |
| | |
| | from misc_functions import * |
| | from torchvision.datasets import ImageNet |
| | from tqdm import tqdm |
| | from ViT_explanation_generator import LRP, Baselines |
| | from ViT_LRP import vit_base_patch16_224 as vit_LRP |
| | from ViT_new import vit_base_patch16_224 |
| | from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP |
| |
|
| |
|
def normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Normalize the channels of an image tensor in place: ``(x - mean) / std``.

    Args:
        tensor: floating-point tensor laid out as ``(N, C, H, W)`` or
            ``(C, H, W)``. It is modified in place.
        mean: per-channel means (one entry per channel, or a single scalar).
        std: per-channel standard deviations (one entry per channel, or a
            single scalar).

    Returns:
        The same ``tensor`` object, after normalization.
    """
    dtype = tensor.dtype
    # Shape the statistics as (C, 1, 1): they then broadcast over H and W,
    # and also over a leading batch axis when one is present.
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device).view(-1, 1, 1)
    tensor.sub_(mean).div_(std)
    return tensor
| |
|
| |
|
def _explain(method, data, index, is_ablation):
    """Return the raw saliency map that explanation ``method`` produces for ``data``.

    Uses the module-level explainers ``baselines``, ``lrp`` and ``orig_lrp``
    created in the ``__main__`` section. ``full_lrp`` is reshaped to
    ``(B, 1, 224, 224)``; every other method is reshaped to ``(B, 1, 14, 14)``.

    Raises:
        ValueError: if ``method`` is not a supported method name.
    """
    batch = data.shape[0]
    if method == "rollout":
        res = baselines.generate_rollout(data, start_layer=1)
    elif method == "lrp":
        res = lrp.generate_LRP(data, start_layer=1, index=index)
    elif method == "transformer_attribution":
        res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index)
    elif method == "full_lrp":
        return orig_lrp.generate_LRP(data, method="full", index=index).reshape(
            batch, 1, 224, 224
        )
    elif method == "lrp_last_layer":
        res = orig_lrp.generate_LRP(
            data, method="last_layer", is_ablation=is_ablation, index=index
        )
    elif method == "attn_last_layer":
        # NOTE(review): no class index is passed here (as in the original
        # script); presumably the last-layer attention map is class-agnostic.
        res = lrp.generate_LRP(
            data, method="last_layer_attn", is_ablation=is_ablation
        )
    elif method == "attn_gradcam":
        res = baselines.generate_cam_attn(data, index=index)
    else:
        raise ValueError(f"Unsupported method: {method!r}")
    return res.reshape(batch, 1, 14, 14)


def _min_max_per_sample(res):
    """Rescale each map along axis 0 independently to ``[0, 1]``.

    A constant map (max == min) becomes all zeros instead of NaN.
    """
    flat = res.reshape(res.shape[0], -1)
    lo = flat.min(dim=1, keepdim=True).values
    span = flat.max(dim=1, keepdim=True).values - lo
    span = torch.where(span > 0, span, torch.ones_like(span))
    return ((flat - lo) / span).reshape(res.shape)


def compute_saliency_and_save(args):
    """Compute a saliency map for every loader batch and append it to HDF5.

    Writes ``<args.method_dir>/results.hdf5`` with three datasets. Each grows
    along axis 0 by one row per image:
      * ``vis``    -- float32 ``(N, 1, 224, 224)`` saliency maps, each rescaled
                      independently to ``[0, 1]``;
      * ``image``  -- float32 ``(N, 3, 224, 224)`` input images as loaded,
                      i.e. before ``normalize`` is applied;
      * ``target`` -- int32 ``(N,)`` labels from the loader.

    Reads the module-level ``sample_loader`` and ``device`` set up in the
    ``__main__`` section.

    Args:
        args: parsed CLI namespace. Reads ``method_dir``, ``method``,
            ``vis_class``, ``class_id`` and ``is_ablation``.

    Raises:
        ValueError: if ``args.method`` is not a supported method.
    """
    # h5py's create_dataset fails if these datasets already exist, so the
    # caller is expected to remove stale results before calling this.
    with h5py.File(os.path.join(args.method_dir, "results.hdf5"), "a") as f:
        data_cam = f.create_dataset(
            "vis",
            (1, 1, 224, 224),
            maxshape=(None, 1, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_image = f.create_dataset(
            "image",
            (1, 3, 224, 224),
            maxshape=(None, 3, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_target = f.create_dataset(
            "target", (1,), maxshape=(None,), dtype=np.int32, compression="gzip"
        )
        datasets = (data_cam, data_image, data_target)

        first = True
        for data, target in tqdm(sample_loader):
            batch = data.shape[0]
            # Datasets start with one placeholder row, so the first batch only
            # needs to grow them by batch - 1 rows.
            grow = batch - 1 if first else batch
            first = False
            for dset in datasets:
                dset.resize(dset.shape[0] + grow, axis=0)

            # Store the images before normalize() modifies ``data`` in place.
            data_image[-batch:] = data.data.cpu().numpy()
            data_target[-batch:] = target.data.cpu().numpy()

            target = target.to(device)
            data = normalize(data).to(device)
            data.requires_grad_()

            # Class to explain: None lets the explainer choose (the "top"
            # mode), otherwise the ground-truth labels or a fixed class id.
            index = None
            if args.vis_class == "target":
                index = target
            elif args.vis_class == "index":
                # Assumes the explainers accept a plain int class index.
                index = args.class_id

            res = _explain(args.method, data, index, args.is_ablation)
            if args.method != "full_lrp":
                # Upsample 14x14 maps by 16 to the 224x224 input resolution.
                res = torch.nn.functional.interpolate(
                    res, scale_factor=16, mode="bilinear"
                )
            res = _min_max_per_sample(res)

            data_cam[-batch:] = res.data.cpu().numpy()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the output directory,
    # load the three ViT variants and the ImageNet validation loader, then
    # compute and save saliency maps.

    def _str2bool(value):
        """Parse a command-line boolean such as 'True', 'false', '1' or 'no'."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    parser = argparse.ArgumentParser(
        description="Generate ViT saliency maps for the ImageNet validation set"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="")
    parser.add_argument(
        "--method",
        type=str,
        # argparse does not check defaults against ``choices``; the default
        # must be a real choice or compute_saliency_and_save cannot run.
        default="transformer_attribution",
        choices=[
            "rollout",
            "lrp",
            "transformer_attribution",
            "full_lrp",
            "lrp_last_layer",
            "attn_last_layer",
            "attn_gradcam",
        ],
        help="",
    )
    parser.add_argument("--lmd", type=float, default=10, help="")
    parser.add_argument(
        "--vis-class",
        type=str,
        default="top",
        choices=["top", "target", "index"],
        help="",
    )
    parser.add_argument("--class-id", type=int, default=0, help="")
    parser.add_argument("--cls-agn", action="store_true", default=False, help="")
    parser.add_argument("--no-ia", action="store_true", default=False, help="")
    parser.add_argument("--no-fx", action="store_true", default=False, help="")
    parser.add_argument("--no-fgx", action="store_true", default=False, help="")
    parser.add_argument("--no-m", action="store_true", default=False, help="")
    parser.add_argument("--no-reg", action="store_true", default=False, help="")
    # type=bool would turn the string "False" into True; parse it explicitly.
    parser.add_argument(
        "--is-ablation",
        type=_str2bool,
        default=False,
        help="boolean value, e.g. True or False",
    )
    parser.add_argument("--imagenet-validation-path", type=str, required=True, help="")
    args = parser.parse_args()

    # Output layout: visualizations/<method>/index_<class_id>            (index mode)
    #                visualizations/<method>/<vis_class>/<ablation_fold>  (otherwise)
    PATH = os.path.dirname(os.path.abspath(__file__))
    if args.vis_class == "index":
        sub_dir = "{}_{}".format(args.vis_class, args.class_id)
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        sub_dir = os.path.join(args.vis_class, ablation_fold)
    args.method_dir = os.path.join(PATH, "visualizations", args.method, sub_dir)
    os.makedirs(args.method_dir, exist_ok=True)

    # compute_saliency_and_save creates fresh datasets in results.hdf5, which
    # fails if a previous run left them behind, so drop stale results first.
    try:
        os.remove(os.path.join(args.method_dir, "results.hdf5"))
    except FileNotFoundError:
        pass

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Plain ViT, used by the baseline explainers (rollout, attention Grad-CAM).
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()
    baselines = Baselines(model)

    # LRP-enabled ViT (lrp, transformer_attribution, attn_last_layer).
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    lrp = LRP(model_LRP)

    # Second LRP-enabled ViT variant (full_lrp, lrp_last_layer).
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    orig_lrp = LRP(model_orig_LRP)

    # Images are only resized and converted to tensors here;
    # compute_saliency_and_save applies normalize() after storing them.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    imagenet_ds = ImageNet(
        args.imagenet_validation_path, split="val", download=False, transform=transform
    )
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, shuffle=False, num_workers=4
    )

    compute_saliency_and_save(args)
| |
|