Johnyquest7 committed
Commit 29ab4a8 · verified · 1 Parent(s): 5dd37b3

Upload evaluate_and_gradcam.py

Files changed (1):
  evaluate_and_gradcam.py ADDED (+198, -0)
"""
Thyroid Ultrasound Evaluation + Grad-CAM Visualization
Evaluates the model on the test set and generates attention visualizations.
"""
import os, sys, math, json, random, warnings, traceback
warnings.filterwarnings("ignore")

import numpy as np
from PIL import Image
import matplotlib
matplotlib.use("Agg")  # headless backend: render figures without a display
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoImageProcessor, AutoModelForImageClassification,
    Trainer, TrainingArguments, DefaultDataCollator
)
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
)

# Disable Trackio experiment tracking for this run.
os.environ["TRACKIO_SPACE_ID"] = ""
os.environ["TRACKIO_PROJECT"] = ""

HF_USERNAME = "Johnyquest7"
DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"
OUTPUT_DIR = "./eval_outputs"
SEED = 42
MAX_SAMPLES_GRADCAM = 20

# Seed every RNG in play so the split, shuffles, and sample selection are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

def main():
    print("=" * 60)
    print("Thyroid Ultrasound Model Evaluation + Grad-CAM")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")
    print(f"Loading model: {MODEL_NAME}")

    try:
        processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    except Exception as e:
        print(f"Model loading failed: {e}")
        sys.exit(1)

    print(f"\nLoading dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    # Stratified 80/20 split; stratify_by_column requires "label" to be a ClassLabel feature.
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    print(f"Test samples: {len(test_ds)}")

    id2label = model.config.id2label
    label2id = model.config.label2id

    def transform(examples):
        images = [img.convert("RGB") for img in examples["image"]]
        inputs = processor(images, return_tensors="pt")
        # set_transform replaces each example with this function's output, so the
        # label column must be passed through explicitly; otherwise batch["label"]
        # and sample["label"] below would raise a KeyError.
        inputs["label"] = examples["label"]
        return inputs

    test_ds.set_transform(transform)
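    # Indexing note (a general observation about datasets' set_transform, not
    # original to this script): integer access unbatches the transform output,
    # while slice access keeps it batched -- test_ds[0]["pixel_values"] is
    # (C, H, W) but test_ds[0:16]["pixel_values"] is (B, C, H, W). Both the
    # prediction loop and the Grad-CAM loop below rely on this.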

    # Evaluate
    print("\nRunning evaluation...")
    args = TrainingArguments(
        output_dir="/tmp/eval", per_device_eval_batch_size=16,
        remove_unused_columns=False,  # keep raw columns for the on-the-fly transform
        disable_tqdm=True,
        logging_strategy="steps", logging_first_step=True,
        report_to=[],
    )
    # DefaultDataCollator stacks the batch and renames "label" to "labels",
    # which is what lets trainer.evaluate() compute eval_loss.
    trainer = Trainer(model=model, args=args, data_collator=DefaultDataCollator(),
                      eval_dataset=test_ds)
    metrics = trainer.evaluate()
    print(f"\nRaw metrics: {metrics}")
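    # Note: without a compute_metrics function, trainer.evaluate() only reports
    # eval_loss and runtime stats, so the detailed test metrics are computed
    # manually from raw logits below.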

    # Collect predictions
    all_logits, all_labels = [], []
    for i in range(0, len(test_ds), 16):
        batch = test_ds[i:i + 16]
        # Slicing the transformed dataset yields an already-stacked (B, C, H, W) tensor.
        pixel_values = batch["pixel_values"].to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        all_logits.extend(outputs.logits.cpu().numpy())
        all_labels.extend(batch["label"])

    y_true = np.array(all_labels)
    y_logits = np.array(all_logits)
    y_pred = np.argmax(y_logits, axis=1)
    probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()

    acc = accuracy_score(y_true, y_pred)
    # "weighted" averages per-class scores by support, so class imbalance in the
    # test split is reflected in precision/recall/F1.
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    # Binary case: AUC from the positive-class probability; otherwise one-vs-rest.
    if probs.shape[1] == 2:
        auc = roc_auc_score(y_true, probs[:, 1])
    else:
        auc = roc_auc_score(y_true, probs, multi_class="ovr")
    cm = confusion_matrix(y_true, y_pred)

    final = {
        "test_accuracy": float(acc),
        "test_weighted_f1": float(f1),
        "test_weighted_precision": float(prec),
        "test_weighted_recall": float(rec),
        "test_roc_auc": float(auc),
        "test_confusion_matrix": cm.tolist(),
        "eval_loss": float(metrics.get("eval_loss", 0)),
    }
    print(f"\n{'=' * 60}")
    print("FINAL TEST METRICS")
    print(f"{'=' * 60}")
    for k, v in final.items():
        print(f"  {k}: {v}")
    with open(f"{OUTPUT_DIR}/test_metrics.json", "w") as f:
        json.dump(final, f, indent=2)
    print(f"\nSaved to {OUTPUT_DIR}/test_metrics.json")

    # Grad-CAM: sample both correct and misclassified predictions
    correct_idx = [i for i in range(len(y_true)) if y_true[i] == y_pred[i]]
    incorrect_idx = [i for i in range(len(y_true)) if y_true[i] != y_pred[i]]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    selected = correct_idx[:5] + incorrect_idx[:5]  # slicing already caps short lists
    print(f"\nGenerating Grad-CAM for {len(selected)} samples "
          f"({len(correct_idx[:5])} correct, {len(incorrect_idx[:5])} incorrect)...")

    # Hook the post-attention LayerNorm of the last block in the final Swin stage:
    # the forward hook caches its activations, the backward hook its gradients.
    gradcam_data = {}

    def fwd_hook(module, input, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    fwd_handle = target_layer.register_forward_hook(fwd_hook)
    bwd_handle = target_layer.register_full_backward_hook(bwd_hook)
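    # Grad-CAM recap (reference sketch, not part of the original file): with
    # activations A of shape (tokens, channels) and gradients G = dScore/dA from
    # the hooks, the channel weights are the token-averaged gradients
    # w = G.mean(0), and the map is ReLU(A @ w) reshaped to the
    # sqrt(tokens) x sqrt(tokens) grid. The loop below implements exactly this
    # for Swin's (H*W, C) token layout.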

    for idx in selected[:MAX_SAMPLES_GRADCAM]:
        try:
            sample = test_ds[idx]
            label = sample["label"]
            img_tensor = sample["pixel_values"].unsqueeze(0).to(device).requires_grad_(True)
            model.zero_grad()
            outputs = model(pixel_values=img_tensor)
            # Backprop the logit of the predicted class to get its attention map.
            target_class = int(y_pred[idx])
            score = outputs.logits[0, target_class]
            score.backward()

            feat = gradcam_data["feat"][0]   # drop the batch dim
            grads = gradcam_data["grad"][0]
            if feat.dim() == 2:  # Swin token output: (H*W, C)
                weights = grads.mean(dim=0, keepdim=True)
                cam = torch.matmul(feat, weights.t()).squeeze()
                H = W = int(math.sqrt(cam.shape[0]))
                cam = cam.reshape(H, W)
            else:
                weights = grads.mean(dim=(0, 1), keepdim=True)
                cam = (feat * weights).sum(dim=-1).squeeze()

            cam = F.relu(cam)
            cam = cam - cam.min()
            cam = cam / (cam.max() + 1e-8)
            # Upsample the token-grid map to the model's input resolution.
            cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=img_tensor.shape[-2:],
                                mode="bilinear", align_corners=False)
            cam = cam.squeeze().cpu().numpy()

            # Overlay: min-max rescale the normalized input back to [0, 1] for display.
            img_np = img_tensor.squeeze().detach().cpu().permute(1, 2, 0).numpy()
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
            plt.figure(figsize=(6, 6))
            plt.imshow(img_np)
            plt.imshow(cam, cmap="jet", alpha=0.5)
            plt.title(f"Pred: {id2label[target_class]} | True: {id2label[label]}")
            plt.axis("off")
            fname = (f"{OUTPUT_DIR}/gradcam_sample_{idx}"
                     f"_pred{id2label[target_class]}_true{id2label[label]}.png")
            plt.savefig(fname, bbox_inches="tight", dpi=150)
            plt.close()
            print(f"  Saved {fname}")
        except Exception as e:
            print(f"  Skipped sample {idx}: {e}")
            traceback.print_exc()

    fwd_handle.remove()
    bwd_handle.remove()

    # TODO: optionally push the outputs to the Hub as a dataset or files.
    print("\nEvaluation complete.")
    print(f"Results saved to {OUTPUT_DIR}/")

if __name__ == "__main__":
    main()
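
# One possible implementation of the TODO above (a hedged sketch using
# huggingface_hub's upload_folder, not part of the committed script; the
# repo_id is hypothetical):
#
#   from huggingface_hub import upload_folder
#   upload_folder(repo_id=f"{HF_USERNAME}/thyroid-eval-outputs",  # hypothetical repo
#                 folder_path=OUTPUT_DIR, repo_type="dataset")
#
# This would upload everything in OUTPUT_DIR (metrics JSON plus Grad-CAM PNGs)
# to a dataset repo on the Hub.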