| """Evaluate a trained checkpoint on PTB-XL AF + downstream probes. |
| |
| Loads the model from `--ckpt`, fetches PTB-XL via HF, extracts pooled latents |
| from the ECG encoder, runs a logistic-regression linear probe, and writes |
| results JSON. |
| |
| Used at epoch 25 (K-gate eval) and epoch 100 (final eval). |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from dotenv import load_dotenv |
|
|
load_dotenv()
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN (empty string if unset) so the
# Hugging Face stack can authenticate; a pre-existing HF_TOKEN wins.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))


# Make the repo's `src/` package importable when this file is run as a
# script; the physiojepa imports below depend on this path insertion.
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
|
|
| from physiojepa.models import MODEL_REGISTRY, ModelConfig |
| from physiojepa.probe import linear_probe_auroc, pooled_features |
|
|
|
|
def get_ecg_encoder(model_letter: str, model: torch.nn.Module) -> torch.nn.Module:
    """Return the ECG encoder submodule for the given model variant.

    Variants "A" and "C" expose the encoder directly as ``model.ecg``;
    all other variants nest it under a backbone as ``model.bb.ecg``.
    """
    if model_letter in ("A", "C"):
        return model.ecg
    return model.bb.ecg
|
|
|
|
def main() -> None:
    """CLI entry point: evaluate a checkpoint's ECG encoder on PTB-XL AF.

    Loads the checkpoint, rebuilds the model from the saved config (falling
    back to defaults for missing fields), extracts pooled ECG features from
    a cached PTB-XL npz, fits a linear probe on a fixed 80/20 split, and
    writes an AUROC summary JSON to --out.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--model", required=True, choices=["A", "B", "C", "F"])
    ap.add_argument("--ptbxl_npz", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # checkpoint; only load checkpoints produced by this project's own runs.
    sd = torch.load(args.ckpt, map_location=device, weights_only=False)
    saved_cfg = sd.get("cfg", {})

    # Rebuild the training-time config; defaults cover older checkpoints
    # whose saved cfg lacks these keys.
    cfg = ModelConfig(
        pred_depth=saved_cfg.get("pred_depth", 4),
        query_mode=saved_cfg.get("query_mode", "learned"),
        mask_ratio=saved_cfg.get("mask_ratio", 0.50),
    )
    print(f"[eval] model cfg: pred_depth={cfg.pred_depth} query_mode={cfg.query_mode} mask_ratio={cfg.mask_ratio}")
    model = MODEL_REGISTRY[args.model](cfg)
    model.load_state_dict(sd["model"])
    model.to(device)
    model.eval()  # idiomatic equivalent of model.train(False)
    enc = get_ecg_encoder(args.model, model)

    print(f"[eval] loading PTB-XL cache from {args.ptbxl_npz}")
    arr = np.load(args.ptbxl_npz)
    X, y = arr["X"], arr["y"]
    print(f"[eval] X={X.shape} y_pos={int(y.sum())} y_neg={int((1 - y).sum())}")
    X_t = torch.from_numpy(X)
    feats = pooled_features(enc, X_t, device=device, batch_size=64)

    # Fixed-seed 80/20 split so the epoch-25 (K-gate) and epoch-100 (final)
    # evaluations are comparable across runs.
    rng = np.random.default_rng(0)
    idx = rng.permutation(len(y))
    cut = int(len(idx) * 0.8)
    train_idx, test_idx = idx[:cut], idx[cut:]
    auroc = linear_probe_auroc(feats[train_idx], y[train_idx], feats[test_idx], y[test_idx])
    print(f"[eval] AF AUROC = {auroc:.4f}")
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    Path(args.out).write_text(json.dumps({
        # float()/int() casts guard against numpy scalars (e.g. np.float32),
        # which json.dumps cannot serialize.
        "ckpt": args.ckpt, "model": args.model, "auroc": float(auroc),
        "n_train": int(cut), "n_test": int(len(idx) - cut),
        "n_pos": int(y.sum()), "n_neg": int((1 - y).sum()),
    }, indent=2))
|
|
|
|
# Script entry point; no side effects when imported as a module.
if __name__ == "__main__":
    main()
|
|