"""
Comprehensive Cross-Dataset Evaluation for Thyroid Ultrasound Model
Computes: Accuracy, Sensitivity, Specificity, PPV, NPV, AUC-ROC, F1
Evaluates on:
1. BTX24 test split (same-dataset validation)
2. joooy94/thyroid_data (cross-dataset validation)
Results pushed to Hugging Face Hub.
"""
import json, warnings, traceback
warnings.filterwarnings("ignore")

import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, precision_recall_fscore_support
)
import torch
import torch.nn.functional as F
from huggingface_hub import HfApi

# Configuration
HF_USERNAME = "Johnyquest7"
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"
REPO_ID = f"{HF_USERNAME}/thyroid-training-scripts"
SEED = 42
BATCH_SIZE = 8

# Seed RNGs so the shuffle / train-test split below is reproducible
np.random.seed(SEED)
torch.manual_seed(SEED)

def evaluate_dataset(dataset_name, split_name, label_column, dataset_is_split=True, preloaded_ds=None):
    """Evaluate the model on a dataset and return a metrics dict.

    If `preloaded_ds` is given, it is evaluated directly and nothing is
    downloaded; otherwise the dataset is loaded from the Hub.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating on: {dataset_name} | split: {split_name}")
    print(f"{'='*60}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load the fine-tuned model and its image processor
    print(f"Loading model: {MODEL_NAME}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    id2label = model.config.id2label
    print(f"Model classes: {id2label}")
    # Load the evaluation dataset (skipped when a pre-built split is passed in)
    print(f"Loading dataset: {dataset_name}")
    try:
        if preloaded_ds is not None:
            ds = preloaded_ds
        elif dataset_is_split:
            ds = load_dataset(dataset_name, split=split_name)
        else:
            ds = load_dataset(dataset_name)
            if split_name in ds:
                ds = ds[split_name]
            else:
                ds = ds[list(ds.keys())[0]]
    except Exception as e:
        print(f"ERROR loading dataset: {e}")
        return {"error": str(e)}

    print(f"Total samples: {len(ds)}")

    # Sanity-check that the required columns exist
    if "image" not in ds.column_names:
        print(f"ERROR: Dataset missing 'image' column. Available: {ds.column_names}")
        return {"error": "Missing image column"}
    if label_column not in ds.column_names:
        print(f"ERROR: Dataset missing '{label_column}' column. Available: {ds.column_names}")
        return {"error": f"Missing {label_column} column"}

    # Peek at the label values. Integer labels are assumed to follow the same
    # ordering as the model's id2label mapping printed above; no remapping is done.
    labels = [ds[i][label_column] for i in range(min(100, len(ds)))]
    unique_labels = sorted(set(labels))
    print(f"Label values (first 100): {unique_labels}")

    # Batched inference: collect logits and ground-truth labels
    all_logits, all_labels = [], []
    for i in range(0, len(ds), BATCH_SIZE):
        batch_items = [ds[j] for j in range(i, min(i + BATCH_SIZE, len(ds)))]
        try:
            images = []
            valid_labels = []
            for item in batch_items:
                img = item["image"]
                # Ensure 3-channel RGB input (ultrasound images are often grayscale)
                if hasattr(img, 'mode'):
                    img = img.convert("RGB") if img.mode != "RGB" else img
                elif hasattr(img, 'convert'):
                    img = img.convert("RGB")
                images.append(img)
                valid_labels.append(item[label_column])

            inputs = processor(images, return_tensors="pt")
            with torch.no_grad():
                outputs = model(pixel_values=inputs["pixel_values"].to(device))
            all_logits.extend(outputs.logits.cpu().numpy())
            all_labels.extend(valid_labels)
        except Exception as e:
            print(f"  Error in batch {i // BATCH_SIZE}: {e}")
            continue

        if (i // BATCH_SIZE) % 10 == 0:
            print(f"  Processed {i}/{len(ds)} samples")

| print(f"\nTotal evaluated: {len(all_labels)}") |
| if len(all_labels) == 0: |
| return {"error": "No samples evaluated"} |
|
|
| y_true = np.array(all_labels) |
| y_logits = np.array(all_logits) |
| y_pred = np.argmax(y_logits, axis=1) |
| probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy() |
|
|
    # Overall accuracy and weighted (support-adjusted) precision/recall/F1
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    print(f"\nConfusion Matrix:\n{cm}")

    # Sensitivity/specificity/PPV/NPV are defined here for the binary case only;
    # when the confusion matrix is not 2x2 they fall back to 0.
    tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0

    # AUC-ROC: positive-class probability for binary, one-vs-rest for multiclass
    try:
        if probs.shape[1] == 2:
            auc = roc_auc_score(y_true, probs[:, 1])
        else:
            auc = roc_auc_score(y_true, probs, multi_class="ovr")
    except Exception as e:
        print(f"AUC calculation failed: {e}")
        auc = 0.0

    # Macro-averaged metrics (each class weighted equally)
    prec_macro = precision_score(y_true, y_pred, average="macro", zero_division=0)
    rec_macro = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=0)

    metrics = {
        "dataset": dataset_name,
        "split": split_name,
        "n_samples": int(len(y_true)),
        "accuracy": float(acc),
        "weighted_precision": float(prec),
        "weighted_recall": float(rec),
        "weighted_f1": float(f1),
        "macro_precision": float(prec_macro),
        "macro_recall": float(rec_macro),
        "macro_f1": float(f1_macro),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "ppv": float(ppv),
        "npv": float(npv),
        "auc_roc": float(auc),
        "confusion_matrix": cm.tolist(),
    }

    print(f"\n{'='*60}")
    print("RESULTS")
    print(f"{'='*60}")
    for k, v in metrics.items():
        if k != "confusion_matrix":
            print(f"  {k}: {v}")

    return metrics

def main():
    print("=" * 60)
    print("Cross-Dataset Thyroid Model Evaluation")
    print("=" * 60)

    all_results = {}

    # 1) Same-dataset validation: evaluate on a held-out 20% split of BTX24
    try:
        ds_full = load_dataset("BTX24/thyroid-cancer-classification-ultrasound-dataset", split="train")
        ds_full = ds_full.shuffle(seed=SEED)
        train_test = ds_full.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
        test_ds = train_test["test"]

        print(f"\nBTX24 Test Split: {len(test_ds)} samples")
        # Pass the held-out split directly so training samples are not re-used
        metrics_btx24 = evaluate_dataset(
            "BTX24/thyroid-cancer-classification-ultrasound-dataset",
            "test (20% held-out)",
            "label",
            preloaded_ds=test_ds,
        )
        all_results["BTX24_test_split"] = metrics_btx24
    except Exception as e:
        print(f"BTX24 evaluation failed: {e}")
        traceback.print_exc()
        all_results["BTX24_test_split"] = {"error": str(e)}
    # 2) Cross-dataset validation: external joooy94/thyroid_data
    try:
        metrics_cross = evaluate_dataset(
            "joooy94/thyroid_data",
            "train",
            "label",
            dataset_is_split=True
        )
        all_results["joooy94_thyroid_data"] = metrics_cross
    except Exception as e:
        print(f"joooy94 evaluation failed: {e}")
        traceback.print_exc()
        all_results["joooy94_thyroid_data"] = {"error": str(e)}

    # Print and save the combined results
    print(f"\n{'='*60}")
    print("SAVING RESULTS")
    print(f"{'='*60}")

    results_json = json.dumps(all_results, indent=2)
    print(results_json)

    # Write metrics to a local JSON file
    output_path = "/tmp/cross_dataset_metrics.json"
    with open(output_path, "w") as f:
        f.write(results_json)
    print(f"\nSaved to {output_path}")

    # Push the metrics JSON to the Hub
    try:
        api = HfApi()
        api.upload_file(
            path_or_fileobj=output_path,
            path_in_repo="cross_dataset_metrics.json",
            repo_id=REPO_ID,
            repo_type="model"
        )
        print(f"Uploaded to https://huggingface.co/{REPO_ID}/blob/main/cross_dataset_metrics.json")
    except Exception as e:
        print(f"Upload failed: {e}")
        traceback.print_exc()

    print("\nDone!")

if __name__ == "__main__":
    main()
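
# Example invocation (hypothetical filename; the Hub upload step requires
# credentials with write access to REPO_ID, e.g. from `huggingface-cli login`):
#   python cross_dataset_evaluation.py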