"""Grad-CAM explanations for a trained image-classifier checkpoint.

Loads a checkpoint, preprocesses one input image with the normalisation
stored in the checkpoint's ``meta`` dict, renders a Grad-CAM overlay for each
of the top-k predicted classes and writes the PNGs plus a ``summary.json``.
"""

import argparse
import json
import re
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from PIL import Image
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
from torchvision import models as tvm
from torchvision import transforms

from src.train import SmallCNN, get_device

# Fallbacks used when the checkpoint ``meta`` lacks the corresponding key.
# The mean/std pair is the Fashion-MNIST fallback of the original script.
_FALLBACK_MEAN = [0.2860]
_FALLBACK_STD = [0.3530]
_FALLBACK_TARGET_LAYER = "conv2"


def build_argparser():
    """Build the command-line parser for the Grad-CAM script.

    ``--target-layer`` defaults to ``None`` so that ``main`` can tell "not
    given" apart from an explicit value and fall back to the checkpoint's
    ``default_target_layer`` (and finally to ``'conv2'``).
    """
    p = argparse.ArgumentParser(description="Grad-CAM explanations")
    p.add_argument("--ckpt", type=str, required=True, help="Path to best.ckpt")
    p.add_argument("--image", type=str, required=True, help="Path to an input image")
    # NOTE(review): this option is parsed but not read anywhere in this file;
    # normalization actually comes from the checkpoint meta.
    p.add_argument(
        "--dataset",
        choices=["fashion-mnist", "mnist", "cifar10"],
        default="fashion-mnist",
        help="Used to apply the right normalization and class names",
    )
    p.add_argument(
        "--target-layer",
        type=str,
        default=None,
        help=(
            "Layer to attach CAMs (e.g., 'conv2' for SmallCNN, 'layer4' for ResNet); "
            "defaults to the checkpoint's 'default_target_layer', else 'conv2'"
        ),
    )
    p.add_argument(
        "--outdir",
        type=str,
        default=None,
        help="Where to store results; defaults near the checkpoint",
    )
    p.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
    p.add_argument("--topk", type=int, default=3, help="How many top classes to render")
    return p


def get_transforms_from_meta(meta):
    """Return the preprocessing pipeline described by the checkpoint ``meta``.

    Reads ``img_size`` (default 28), ``mean`` and ``std`` (Fashion-MNIST
    fallback). A single-element ``mean`` is treated as grayscale and adds a
    ``Grayscale`` step; otherwise the image is left in its input channels.
    """
    img_size = int(meta.get("img_size", 28))
    mean = meta.get("mean", _FALLBACK_MEAN)
    std = meta.get("std", _FALLBACK_STD)

    steps = []
    if len(mean) == 1:
        # One normalisation statistic -> the model was trained on 1 channel.
        steps.append(transforms.Grayscale(num_output_channels=1))
    steps.extend(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
    return transforms.Compose(steps)


def denorm_to_pil(x: torch.Tensor, mean, std) -> Image.Image:
    """Undo ``Normalize(mean, std)`` and return an RGB PIL image for overlay.

    x: normalized tensor CxHxW
    mean/std: list(s) from meta (length 1 for grayscale, else per channel)
    """
    x = x.detach().cpu().clone()
    # Shape the statistics as [C,1,1] so they broadcast over H and W; a
    # single-element list broadcasts over every channel.
    mean_t = torch.tensor(mean, dtype=x.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=x.dtype).view(-1, 1, 1)
    x = (x * std_t + mean_t).clamp(0, 1)
    # Grayscale tensors come back as mode "L"; convert so the result is
    # always RGB as the overlay step expects.
    return transforms.ToPILImage()(x).convert("RGB")


def load_model(ckpt_path, device):
    """Rebuild the model stored in a checkpoint and load its weights.

    The checkpoint is expected to be a dict with ``model_state`` and,
    optionally, ``classes`` and ``meta`` (whose ``model_name`` selects the
    architecture; default ``'smallcnn'``).

    Returns:
        ``(model, classes, meta)`` with the model on ``device`` in eval mode.

    Raises:
        ValueError: if ``meta['model_name']`` is not a known architecture.
    """
    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    classes = ckpt.get("classes", None)
    meta = ckpt.get("meta") or {}
    num_classes = len(classes) if classes else 10
    model_name = meta.get("model_name", "smallcnn")

    if model_name == "smallcnn":
        model = SmallCNN(num_classes=num_classes)
    elif model_name == "resnet18_cifar":
        model = tvm.resnet18(weights=None)
        # CIFAR-style stem: 3x3 stride-1 conv and no max-pool.
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif model_name == "resnet18_imagenet":
        # Every parameter and buffer is overwritten by the strict
        # load_state_dict below, so fetching pretrained ImageNet weights here
        # would only add a needless download.
        model = tvm.resnet18(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model in ckpt: {model_name}")

    model = model.to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model, classes, meta


def _safe_filename_component(name) -> str:
    """Return ``name`` with characters unsafe in a file name replaced by '_'."""
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", str(name)).strip("._")
    return cleaned or "unnamed"


def run_gradcam(
    model,
    target_layer,
    img_tensor,
    device,
    classes,
    outdir: Path,
    topk=3,
    base_pil_rgb: Optional[Image.Image] = None,
):
    """Render Grad-CAM overlays for the top-k predicted classes.

    img_tensor: CxHxW normalized (not batched)
    base_pil_rgb: PIL image already denormalized & RGB for overlay (optional).
                  If None, will min-max scale from img_tensor (last-resort).

    Writes one ``gradcam_top<rank>_class<idx>_<name>.png`` per class and a
    ``summary.json`` into ``outdir`` and returns the list of per-class result
    dicts (``rank``, ``class_index``, ``class_name``, ``prob``, ``file``).
    ``topk`` is capped at the number of model outputs.
    """
    outdir = Path(outdir)
    model.eval()
    x = img_tensor.to(device).unsqueeze(0)  # [1,C,H,W]
    H, W = img_tensor.shape[-2:]

    if base_pil_rgb is None:
        # Fallback: simple min-max scaling (works but less faithful than denorm)
        xx = img_tensor.detach().cpu()
        xx = (xx - xx.min()) / (xx.max() - xx.min() + 1e-8)
        base_pil_rgb = transforms.ToPILImage()(xx)
        if xx.shape[0] == 1:
            base_pil_rgb = base_pil_rgb.convert("RGB")

    results = []
    # The extractor registers hooks on the model at construction time; they
    # are removed in the finally block so repeated calls do not stack hooks.
    cam_extractor = GradCAM(model, target_layer=target_layer)
    try:
        # forward once to get top-k
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0].detach().cpu()
        # Asking for more classes than the model outputs would make topk raise.
        k = min(topk, probs.numel())
        top_vals, top_idxs = probs.topk(k)

        ranked = zip(top_vals.tolist(), top_idxs.tolist())
        for rank, (score, cls_idx) in enumerate(ranked, start=1):
            # Keep the autograd graph alive for every class except the last.
            retain = rank < k
            cams = cam_extractor(int(cls_idx), logits, retain_graph=retain)
            cam = cams[0].detach().cpu()
            # Collapse to [1,h,w]; presumably the map is [h,w] or [1,h,w]
            # depending on the torchcam version - the batch size here is 1.
            cam = cam.reshape(1, *cam.shape[-2:])
            cam_up = TF.resize(cam, size=[H, W]).clamp(0, 1)  # upsample to input size
            # overlay_mask squares the mask and feeds it to a colormap, so it
            # must be a float ("F") image, not an 8-bit one.
            heat = TF.to_pil_image(cam_up, mode="F")
            overlay = overlay_mask(base_pil_rgb, heat, alpha=0.6)

            class_name = classes[cls_idx] if classes else str(cls_idx)
            # Class names may contain path separators (e.g. "T-shirt/top").
            label = _safe_filename_component(class_name)
            out_png = outdir / f"gradcam_top{rank}_class{cls_idx}_{label}.png"
            overlay.save(out_png)

            results.append(
                {
                    "rank": rank,
                    "class_index": int(cls_idx),
                    "class_name": class_name,
                    "prob": float(score),
                    "file": str(out_png),
                }
            )
    finally:
        cam_extractor.remove_hooks()

    with open(outdir / "summary.json", "w", encoding="utf-8") as f:
        json.dump({"topk": results}, f, indent=2)

    print("Saved:", outdir)
    return results


def main():
    """CLI entry point: load the checkpoint, explain one image, print a summary."""
    args = build_argparser().parse_args()
    device = get_device(args.device)

    ckpt_path = Path(args.ckpt)
    # outdir default
    if args.outdir is None:
        run_id = ckpt_path.parent.name
        outdir = ckpt_path.parent.parent.parent / "reports" / run_id / "explain"
    else:
        outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # 1) load model+meta first
    model, classes, meta = load_model(str(ckpt_path), device)

    # 2) build tf from meta
    tf = get_transforms_from_meta(meta)

    # 3) load and transform image
    with Image.open(args.image) as im:
        pil = im.convert("RGB")
    x = tf(pil)  # CxHxW normalized

    # 4) make a denormalized RGB base image for overlay
    base_pil = denorm_to_pil(
        x,
        meta.get("mean", _FALLBACK_MEAN),
        meta.get("std", _FALLBACK_STD),
    )

    # 5) target layer (CLI overrides meta default)
    target_layer = args.target_layer or meta.get(
        "default_target_layer", _FALLBACK_TARGET_LAYER
    )

    # 6) run Grad-CAM
    results = run_gradcam(
        model,
        target_layer,
        x,
        device,
        classes,
        outdir,
        topk=args.topk,
        base_pil_rgb=base_pil,
    )

    # 7) print summary
    for r in results:
        print(f"Top{r['rank']}: {r['class_name']} ({r['prob']:.3f}) -> {r['file']}")


if __name__ == "__main__":
    main()