Johnyquest7 commited on
Commit
1393c62
·
verified ·
1 Parent(s): 0a9bbd2

Upload cross_dataset_evaluation.py

Browse files
Files changed (1) hide show
  1. cross_dataset_evaluation.py +252 -0
cross_dataset_evaluation.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive Cross-Dataset Evaluation for Thyroid Ultrasound Model
4
+ Computes: Accuracy, Sensitivity, Specificity, PPV, NPV, AUC-ROC, F1
5
+ Evaluates on:
6
+ 1. BTX24 test split (same-dataset validation)
7
+ 2. joooy94/thyroid_data (cross-dataset validation)
8
+ Results pushed to Hugging Face Hub.
9
+ """
10
+ import os, sys, json, warnings, traceback
11
+ warnings.filterwarnings("ignore")
12
+
13
+ import numpy as np
14
+ from datasets import load_dataset
15
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
16
+ from sklearn.metrics import (
17
+ accuracy_score, precision_score, recall_score, f1_score,
18
+ roc_auc_score, confusion_matrix, precision_recall_fscore_support
19
+ )
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from huggingface_hub import HfApi
23
+
24
+ HF_USERNAME = "Johnyquest7"
25
+ MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"
26
+ REPO_ID = f"{HF_USERNAME}/thyroid-training-scripts"
27
+ SEED = 42
28
+ BATCH_SIZE = 8 # Smaller for CPU compatibility
29
+
30
+ np.random.seed(SEED)
31
+ torch.manual_seed(SEED)
32
+
33
+ def evaluate_dataset(dataset_name, split_name, label_column, dataset_is_split=True):
34
+ """Evaluate model on a dataset. Returns metrics dict."""
35
+ print(f"\n{'='*60}")
36
+ print(f"Evaluating on: {dataset_name} | split: {split_name}")
37
+ print(f"{'='*60}")
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ print(f"Device: {device}")
41
+
42
+ # Load model once
43
+ print(f"Loading model: {MODEL_NAME}")
44
+ processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
45
+ model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
46
+ id2label = model.config.id2label
47
+ print(f"Model classes: {id2label}")
48
+
49
+ # Load dataset
50
+ print(f"Loading dataset: {dataset_name}")
51
+ try:
52
+ if dataset_is_split:
53
+ ds = load_dataset(dataset_name, split=split_name)
54
+ else:
55
+ ds = load_dataset(dataset_name)
56
+ if split_name in ds:
57
+ ds = ds[split_name]
58
+ else:
59
+ ds = ds[list(ds.keys())[0]]
60
+ except Exception as e:
61
+ print(f"ERROR loading dataset: {e}")
62
+ return {"error": str(e)}
63
+
64
+ print(f"Total samples: {len(ds)}")
65
+
66
+ # Check if dataset has required columns
67
+ if "image" not in ds.column_names:
68
+ print(f"ERROR: Dataset missing 'image' column. Available: {ds.column_names}")
69
+ return {"error": "Missing image column"}
70
+ if label_column not in ds.column_names:
71
+ print(f"ERROR: Dataset missing '{label_column}' column. Available: {ds.column_names}")
72
+ return {"error": f"Missing {label_column} column"}
73
+
74
+ # Count labels
75
+ labels = [ds[i][label_column] for i in range(min(100, len(ds)))]
76
+ unique_labels = sorted(set(labels))
77
+ print(f"Label values (first 100): {unique_labels}")
78
+
79
+ # Map dataset labels to model labels if needed
80
+ # Assume 0 = benign, 1 = malignant (standard convention)
81
+ # If labels are different, we may need mapping
82
+
83
+ all_logits, all_labels = [], []
84
+ for i in range(0, len(ds), BATCH_SIZE):
85
+ batch_items = [ds[j] for j in range(i, min(i+BATCH_SIZE, len(ds)))]
86
+ try:
87
+ images = []
88
+ valid_labels = []
89
+ for item in batch_items:
90
+ img = item["image"]
91
+ if hasattr(img, 'mode'):
92
+ img = img.convert("RGB") if img.mode != "RGB" else img
93
+ elif hasattr(img, 'convert'):
94
+ img = img.convert("RGB")
95
+ images.append(img)
96
+ valid_labels.append(item[label_column])
97
+
98
+ inputs = processor(images, return_tensors="pt")
99
+ with torch.no_grad():
100
+ outputs = model(pixel_values=inputs["pixel_values"].to(device))
101
+ all_logits.extend(outputs.logits.cpu().numpy())
102
+ all_labels.extend(valid_labels)
103
+ except Exception as e:
104
+ print(f" Error in batch {i//BATCH_SIZE}: {e}")
105
+ continue
106
+
107
+ if (i // BATCH_SIZE) % 10 == 0:
108
+ print(f" Processed {i}/{len(ds)} samples")
109
+
110
+ print(f"\nTotal evaluated: {len(all_labels)}")
111
+ if len(all_labels) == 0:
112
+ return {"error": "No samples evaluated"}
113
+
114
+ y_true = np.array(all_labels)
115
+ y_logits = np.array(all_logits)
116
+ y_pred = np.argmax(y_logits, axis=1)
117
+ probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()
118
+
119
+ # Compute all metrics
120
+ acc = accuracy_score(y_true, y_pred)
121
+ prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)
122
+
123
+ # Binary metrics
124
+ cm = confusion_matrix(y_true, y_pred)
125
+ print(f"\nConfusion Matrix:\n{cm}")
126
+
127
+ # Handle different label conventions
128
+ # If dataset uses 0=benign, 1=malignant (same as model)
129
+ tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
130
+
131
+ sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
132
+ specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
133
+ ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
134
+ npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
135
+
136
+ # AUC-ROC
137
+ try:
138
+ if probs.shape[1] >= 2:
139
+ auc = roc_auc_score(y_true, probs[:, 1])
140
+ else:
141
+ auc = roc_auc_score(y_true, probs[:, 0])
142
+ except Exception as e:
143
+ print(f"AUC calculation failed: {e}")
144
+ auc = 0.0
145
+
146
+ # Per-class metrics
147
+ prec_macro = precision_score(y_true, y_pred, average="macro", zero_division=0)
148
+ rec_macro = recall_score(y_true, y_pred, average="macro", zero_division=0)
149
+ f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
150
+
151
+ metrics = {
152
+ "dataset": dataset_name,
153
+ "split": split_name,
154
+ "n_samples": int(len(y_true)),
155
+ "accuracy": float(acc),
156
+ "weighted_precision": float(prec),
157
+ "weighted_recall": float(rec),
158
+ "weighted_f1": float(f1),
159
+ "macro_precision": float(prec_macro),
160
+ "macro_recall": float(rec_macro),
161
+ "macro_f1": float(f1_macro),
162
+ "sensitivity": float(sensitivity),
163
+ "specificity": float(specificity),
164
+ "ppv": float(ppv),
165
+ "npv": float(npv),
166
+ "auc_roc": float(auc),
167
+ "confusion_matrix": cm.tolist(),
168
+ }
169
+
170
+ print(f"\n{'='*60}")
171
+ print("RESULTS")
172
+ print(f"{'='*60}")
173
+ for k, v in metrics.items():
174
+ if k != "confusion_matrix":
175
+ print(f" {k}: {v}")
176
+
177
+ return metrics
178
+
179
+ def main():
180
+ print("=" * 60)
181
+ print("Cross-Dataset Thyroid Model Evaluation")
182
+ print("=" * 60)
183
+
184
+ all_results = {}
185
+
186
+ # 1. Evaluate on BTX24 test split (our own held-out data)
187
+ try:
188
+ ds_full = load_dataset("BTX24/thyroid-cancer-classification-ultrasound-dataset", split="train")
189
+ ds_full = ds_full.shuffle(seed=SEED)
190
+ train_test = ds_full.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
191
+ test_ds = train_test["test"]
192
+
193
+ # Save test_ds as temporary and evaluate
194
+ print(f"\nBTX24 Test Split: {len(test_ds)} samples")
195
+ metrics_btx24 = evaluate_dataset(
196
+ "BTX24/thyroid-cancer-classification-ultrasound-dataset",
197
+ "train",
198
+ "label",
199
+ dataset_is_split=True
200
+ )
201
+ all_results["BTX24_test_split"] = metrics_btx24
202
+ except Exception as e:
203
+ print(f"BTX24 evaluation failed: {e}")
204
+ traceback.print_exc()
205
+ all_results["BTX24_test_split"] = {"error": str(e)}
206
+
207
+ # 2. Evaluate on joooy94/thyroid_data (cross-dataset)
208
+ try:
209
+ metrics_cross = evaluate_dataset(
210
+ "joooy94/thyroid_data",
211
+ "train",
212
+ "label",
213
+ dataset_is_split=True
214
+ )
215
+ all_results["joooy94_thyroid_data"] = metrics_cross
216
+ except Exception as e:
217
+ print(f"joooy94 evaluation failed: {e}")
218
+ traceback.print_exc()
219
+ all_results["joooy94_thyroid_data"] = {"error": str(e)}
220
+
221
+ # Save results
222
+ print(f"\n{'='*60}")
223
+ print("SAVING RESULTS")
224
+ print(f"{'='*60}")
225
+
226
+ results_json = json.dumps(all_results, indent=2)
227
+ print(results_json)
228
+
229
+ # Write to local file
230
+ output_path = "/tmp/cross_dataset_metrics.json"
231
+ with open(output_path, "w") as f:
232
+ f.write(results_json)
233
+ print(f"\nSaved to {output_path}")
234
+
235
+ # Upload to Hub
236
+ try:
237
+ api = HfApi()
238
+ api.upload_file(
239
+ path_or_fileobj=output_path,
240
+ path_in_file="cross_dataset_metrics.json",
241
+ repo_id=REPO_ID,
242
+ repo_type="model"
243
+ )
244
+ print(f"Uploaded to https://huggingface.co/{REPO_ID}/blob/main/cross_dataset_metrics.json")
245
+ except Exception as e:
246
+ print(f"Upload failed: {e}")
247
+ traceback.print_exc()
248
+
249
+ print("\nDone!")
250
+
251
+ if __name__ == "__main__":
252
+ main()