# Source: Hugging Face Hub repo Johnyquest7/thyroid-training-scripts,
# file cross_dataset_evaluation.py (upload commit 1393c62).
#!/usr/bin/env python3
"""
Comprehensive Cross-Dataset Evaluation for Thyroid Ultrasound Model
Computes: Accuracy, Sensitivity, Specificity, PPV, NPV, AUC-ROC, F1
Evaluates on:
1. BTX24 test split (same-dataset validation)
2. joooy94/thyroid_data (cross-dataset validation)
Results pushed to Hugging Face Hub.
"""
import os, sys, json, warnings, traceback
warnings.filterwarnings("ignore")
import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
roc_auc_score, confusion_matrix, precision_recall_fscore_support
)
import torch
import torch.nn.functional as F
from huggingface_hub import HfApi
# Hub identifiers: the fine-tuned classifier being evaluated, and the repo
# that receives the metrics JSON.
HF_USERNAME = "Johnyquest7"
MODEL_NAME = "/".join((HF_USERNAME, "ML-Inter_thyroid"))
REPO_ID = "/".join((HF_USERNAME, "thyroid-training-scripts"))

# Shared seed for NumPy, PyTorch and the dataset split; small batches so the
# script also runs on CPU-only machines.
SEED = 42
BATCH_SIZE = 8  # Smaller for CPU compatibility

# Seed the global NumPy and PyTorch RNGs so reruns are reproducible.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
# (MODEL_NAME, device string) -> (processor, model); filled by _load_model.
_MODEL_CACHE = {}


def _load_model(device):
    """Return ``(processor, model)`` for MODEL_NAME on ``device``, loading at most once.

    ``main`` calls ``evaluate_dataset`` once per dataset; caching here means the
    checkpoint is instantiated a single time per device for the whole run.
    """
    key = (MODEL_NAME, str(device))
    if key not in _MODEL_CACHE:
        print(f"Loading model: {MODEL_NAME}")
        processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
        model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
        _MODEL_CACHE[key] = (processor, model)
    return _MODEL_CACHE[key]


def _load_split(dataset_name, split_name, dataset_is_split):
    """Load one split of a Hub dataset.

    With ``dataset_is_split`` the split is requested directly; otherwise the
    whole DatasetDict is loaded and ``split_name`` is picked, falling back to
    the first available split.
    """
    if dataset_is_split:
        return load_dataset(dataset_name, split=split_name)
    splits = load_dataset(dataset_name)
    if split_name in splits:
        return splits[split_name]
    fallback = next(iter(splits))
    print(f"WARNING: split '{split_name}' not found; falling back to '{fallback}'")
    return splits[fallback]


def _to_rgb(img):
    """Return ``img`` as an RGB PIL-like image; objects without ``mode``/``convert`` pass through."""
    if hasattr(img, "mode"):
        return img if img.mode == "RGB" else img.convert("RGB")
    if hasattr(img, "convert"):
        return img.convert("RGB")
    return img


def _normalize_label_name(name):
    """Canonical form used to compare class names (case/whitespace-insensitive)."""
    return str(name).strip().lower()


def _make_label_mapper(label_feature, id2label):
    """Build a function mapping a raw dataset label to a model class id.

    Args:
        label_feature: the dataset's feature for the label column. If it has a
            ``names`` list (e.g. a ``ClassLabel``) whose entries all match model
            class names, integer labels are remapped by name. This handles datasets
            whose class order differs from the model's.
        id2label: ``model.config.id2label`` (class id -> class name).

    Returns:
        Callable ``raw_label -> int``. String labels are looked up by name
        (ValueError if unknown). Integer labels are remapped by name when possible,
        otherwise assumed to already be model class ids.
    """
    label2id = {_normalize_label_name(name): int(idx) for idx, name in id2label.items()}
    names = getattr(label_feature, "names", None)
    by_index = None
    if names:
        if all(_normalize_label_name(n) in label2id for n in names):
            by_index = {i: label2id[_normalize_label_name(n)] for i, n in enumerate(names)}
            if any(src != dst for src, dst in by_index.items()):
                print(f"Remapping dataset label ids to model ids by class name: {by_index}")
        else:
            print(
                f"WARNING: dataset class names {list(names)} do not match model labels "
                f"{dict(id2label)}; assuming dataset label ids equal model class ids"
            )

    def to_model_id(raw):
        if isinstance(raw, str):
            key = _normalize_label_name(raw)
            if key not in label2id:
                raise ValueError(f"Unknown label {raw!r}; model classes: {sorted(label2id)}")
            return label2id[key]
        raw = int(raw)
        return by_index.get(raw, raw) if by_index else raw

    return to_model_id


def _safe_div(num, den):
    """Return ``num / den`` as float, or 0.0 when ``den`` is 0 (sklearn's zero_division=0 convention)."""
    return float(num) / float(den) if den else 0.0


def _compute_metrics(y_true, y_logits, class_ids):
    """Compute classification metrics from ground truth and raw model logits.

    Args:
        y_true: 1-D int array of labels, already in model class-id space.
        y_logits: 2-D array of logits, one row per sample, one column per class.
        class_ids: sorted model class ids (keys of ``model.config.id2label``).

    Returns:
        Dict of metric name -> JSON-serializable value. Binary-only metrics are
        None when the task is not binary; ``auc_roc`` is None when undefined.
    """
    y_pred = np.argmax(y_logits, axis=1)
    probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()
    true_ids = {int(v) for v in np.unique(y_true)}

    acc = accuracy_score(y_true, y_pred)
    prec_w, rec_w, f1_w, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    prec_m, rec_m, f1_m, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )

    # Pin the label order so the matrix is always KxK. Without it, a split
    # where only one class occurs gives a 1x1 matrix and every binary metric
    # silently becomes 0. Dataset labels unknown to the model stay visible.
    cm_labels = sorted(set(class_ids) | true_ids)
    cm = confusion_matrix(y_true, y_pred, labels=cm_labels)
    print(f"\nConfusion Matrix (rows=true, cols=pred, labels={cm_labels}):\n{cm}")

    if cm.shape == (2, 2):
        # The larger label id is the positive class. NOTE(review): the file
        # assumes 0 = benign, 1 = malignant; confirm against model.config.id2label.
        tn, fp, fn, tp = (int(x) for x in cm.ravel())
        sensitivity = _safe_div(tp, tp + fn)
        specificity = _safe_div(tn, tn + fp)
        ppv = _safe_div(tp, tp + fp)
        npv = _safe_div(tn, tn + fn)
    else:
        print("WARNING: not a binary problem; sensitivity/specificity/PPV/NPV set to None")
        sensitivity = specificity = ppv = npv = None

    # AUC is undefined with a single class present, or with labels outside the
    # model's classes. Report None rather than a misleading 0.0.
    try:
        if not true_ids <= set(class_ids):
            raise ValueError(f"dataset labels {sorted(true_ids)} not in model classes {class_ids}")
        if probs.shape[1] == 2:
            auc = float(roc_auc_score(y_true, probs[:, 1]))
        else:
            auc = float(roc_auc_score(y_true, probs, multi_class="ovr", labels=class_ids))
    except Exception as e:
        print(f"AUC calculation failed: {e}")
        auc = None

    return {
        "accuracy": float(acc),
        "weighted_precision": float(prec_w),
        "weighted_recall": float(rec_w),
        "weighted_f1": float(f1_w),
        "macro_precision": float(prec_m),
        "macro_recall": float(rec_m),
        "macro_f1": float(f1_m),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "ppv": ppv,
        "npv": npv,
        "auc_roc": auc,
        "confusion_matrix": cm.tolist(),
        "confusion_matrix_labels": cm_labels,
    }


def evaluate_dataset(dataset_name, split_name, label_column, dataset_is_split=True, dataset=None):
    """Evaluate MODEL_NAME on one dataset split and return a metrics dict.

    Args:
        dataset_name: Hub dataset id. Used for loading (when ``dataset`` is
            None) and for labelling the results.
        split_name: split to load, and the split name recorded in the results.
        label_column: column holding the ground-truth label.
        dataset_is_split: if True, load ``split_name`` directly; otherwise load
            the DatasetDict and pick ``split_name`` (or its first split).
        dataset: optional already-loaded ``datasets.Dataset``. When given it is
            evaluated as-is and nothing is downloaded. Use this for derived
            partitions such as a held-out test split.

    Returns:
        Dict with dataset/split/sample counts plus the metrics from
        ``_compute_metrics``. On failure, a dict with a single ``"error"`` key.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating on: {dataset_name} | split: {split_name}")
    print(f"{'='*60}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    processor, model = _load_model(device)
    id2label = model.config.id2label
    print(f"Model classes: {id2label}")
    class_ids = sorted(int(k) for k in id2label)

    if dataset is not None:
        ds = dataset
    else:
        print(f"Loading dataset: {dataset_name}")
        try:
            ds = _load_split(dataset_name, split_name, dataset_is_split)
        except Exception as e:
            print(f"ERROR loading dataset: {e}")
            return {"error": str(e)}
    print(f"Total samples: {len(ds)}")

    if "image" not in ds.column_names:
        print(f"ERROR: Dataset missing 'image' column. Available: {ds.column_names}")
        return {"error": "Missing image column"}
    if label_column not in ds.column_names:
        print(f"ERROR: Dataset missing '{label_column}' column. Available: {ds.column_names}")
        return {"error": f"Missing {label_column} column"}

    # Read only the label column so previewing does not decode any images.
    preview = ds.select(range(min(100, len(ds))))[label_column]
    print(f"Label values (first 100): {sorted(set(preview))}")
    to_model_id = _make_label_mapper(ds.features.get(label_column), id2label)

    all_logits, all_labels = [], []
    n_failed = 0
    for start in range(0, len(ds), BATCH_SIZE):
        batch_idx = start // BATCH_SIZE
        try:
            # Slicing decodes the images; keep it inside the try so one corrupt
            # file skips its batch instead of aborting the whole evaluation.
            batch = ds[start:start + BATCH_SIZE]
            images = [_to_rgb(img) for img in batch["image"]]
            labels = [to_model_id(lbl) for lbl in batch[label_column]]
            inputs = processor(images, return_tensors="pt")
            with torch.no_grad():
                logits = model(pixel_values=inputs["pixel_values"].to(device)).logits
            all_logits.extend(logits.float().cpu().numpy())
            all_labels.extend(labels)
        except Exception as e:
            n_failed += min(BATCH_SIZE, len(ds) - start)
            print(f"  Error in batch {batch_idx}: {e}")
            continue
        if batch_idx % 10 == 0:
            print(f"  Processed {start}/{len(ds)} samples")

    print(f"\nTotal evaluated: {len(all_labels)}")
    if n_failed:
        print(f"WARNING: {n_failed} samples skipped because of errors")
    if not all_labels:
        return {"error": "No samples evaluated"}

    metrics = {
        "dataset": dataset_name,
        "split": split_name,
        "n_samples": len(all_labels),
        "n_failed": n_failed,
    }
    metrics.update(_compute_metrics(np.array(all_labels), np.array(all_logits), class_ids))

    print(f"\n{'='*60}")
    print("RESULTS")
    print(f"{'='*60}")
    for key, value in metrics.items():
        if key not in ("confusion_matrix", "confusion_matrix_labels"):
            print(f"  {key}: {value}")
    return metrics


def main():
    """Evaluate on the BTX24 held-out split and joooy94/thyroid_data, save and upload metrics."""
    import tempfile

    print("=" * 60)
    print("Cross-Dataset Thyroid Model Evaluation")
    print("=" * 60)
    all_results = {}

    # 1. BTX24 held-out split (same-dataset validation).
    # NOTE(review): this shuffle + stratified split must reproduce the training
    # script's split exactly, or test images may have been seen in training.
    # TODO: confirm against the training code.
    btx24_name = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
    test_size = 0.2
    try:
        ds_full = load_dataset(btx24_name, split="train")
        ds_full = ds_full.shuffle(seed=SEED)
        train_test = ds_full.train_test_split(
            test_size=test_size, stratify_by_column="label", seed=SEED
        )
        test_ds = train_test["test"]
        print(f"\nBTX24 Test Split: {len(test_ds)} samples")
        # Pass test_ds itself so only held-out images are scored, not the full
        # train split that includes training images.
        all_results["BTX24_test_split"] = evaluate_dataset(
            btx24_name,
            f"test ({test_size:.0%} stratified holdout of train, seed={SEED})",
            "label",
            dataset=test_ds,
        )
    except Exception as e:
        print(f"BTX24 evaluation failed: {e}")
        traceback.print_exc()
        all_results["BTX24_test_split"] = {"error": str(e)}

    # 2. joooy94/thyroid_data (cross-dataset validation).
    try:
        all_results["joooy94_thyroid_data"] = evaluate_dataset(
            "joooy94/thyroid_data",
            "train",
            "label",
            dataset_is_split=True,
        )
    except Exception as e:
        print(f"joooy94 evaluation failed: {e}")
        traceback.print_exc()
        all_results["joooy94_thyroid_data"] = {"error": str(e)}

    print(f"\n{'='*60}")
    print("SAVING RESULTS")
    print(f"{'='*60}")
    results_json = json.dumps(all_results, indent=2)
    print(results_json)

    output_path = os.path.join(tempfile.gettempdir(), "cross_dataset_metrics.json")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(results_json)
    print(f"\nSaved to {output_path}")

    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=output_path,
            path_in_repo="cross_dataset_metrics.json",
            repo_id=REPO_ID,
            repo_type="model",
        )
        print(f"Uploaded to https://huggingface.co/{REPO_ID}/blob/main/cross_dataset_metrics.json")
    except Exception as e:
        print(f"Upload failed: {e}")
        traceback.print_exc()
    print("\nDone!")


if __name__ == "__main__":
    main()