"""Evaluate a trained checkpoint with an AF linear probe on PTB-XL.
Loads the model from `--ckpt`, reads the cached PTB-XL AF arrays (`X`, `y`)
from `--ptbxl_npz`, extracts pooled latents from the ECG encoder, runs a
logistic-regression linear probe, and writes the results as JSON to `--out`.
Used at epoch 25 (K-gate eval) and epoch 100 (final eval).
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
import numpy as np
import torch
from dotenv import load_dotenv
# Pull variables from a local .env file (if one is found) into os.environ.
load_dotenv()
# Expose the key under the HF_TOKEN name that Hugging Face tooling reads.
# setdefault: an HF_TOKEN already in the environment wins; otherwise fall back
# to HUGGINGFACE_API_KEY, or "" when neither is set.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
import sys
# Prepend <parent of this script's directory>/src to the import path so the
# `physiojepa` package imported below resolves from the source tree.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
from physiojepa.models import MODEL_REGISTRY, ModelConfig
from physiojepa.probe import linear_probe_auroc, pooled_features
def get_ecg_encoder(model_letter: str, model: torch.nn.Module) -> torch.nn.Module:
    """Return the ECG encoder submodule of `model`.

    Variants "A" and "C" hold the encoder directly as `model.ecg`; every other
    variant letter reaches it through the `model.bb` attribute (`model.bb.ecg`).
    """
    owner = model if model_letter in ("A", "C") else model.bb
    return owner.ecg
def _build_model(ckpt_path: str, model_letter: str, device: torch.device) -> torch.nn.Module:
    """Rebuild the `model_letter` model from the checkpoint at `ckpt_path`, in eval mode."""
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects (the checkpoint carries a config alongside the weights). Only
    # point --ckpt at checkpoints you produced or otherwise trust.
    sd = torch.load(ckpt_path, map_location=device, weights_only=False)
    # `or {}` also covers a checkpoint that stored cfg=None explicitly.
    saved_cfg = sd.get("cfg") or {}
    # Respect ablation knobs saved in the TrainConfig; the fallback values are
    # used when a key is absent from the saved cfg.
    cfg = ModelConfig(
        pred_depth=saved_cfg.get("pred_depth", 4),
        query_mode=saved_cfg.get("query_mode", "learned"),
        mask_ratio=saved_cfg.get("mask_ratio", 0.50),
    )
    print(f"[eval] model cfg: pred_depth={cfg.pred_depth} query_mode={cfg.query_mode} mask_ratio={cfg.mask_ratio}")
    model = MODEL_REGISTRY[model_letter](cfg)
    model.load_state_dict(sd["model"])
    model.to(device)
    model.train(False)  # equivalent to model.eval()
    # `sd` (already mapped onto `device`) goes out of scope on return, so the
    # duplicate copy of the weights is not held during feature extraction.
    return model


def _load_ptbxl(npz_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Read the `X` and `y` arrays from the cached PTB-XL AF .npz file."""
    # NpzFile keeps the underlying zip open until closed; indexing it reads the
    # member fully into an ndarray, so the arrays stay valid after the `with`.
    with np.load(npz_path) as arr:
        return arr["X"], arr["y"]


def _split_indices(n: int, train_frac: float = 0.8, seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    """Return (train_idx, test_idx) from a seeded permutation of range(n) cut at `train_frac`."""
    idx = np.random.default_rng(seed).permutation(n)
    cut = int(n * train_frac)
    return idx[:cut], idx[cut:]


def main() -> None:
    """CLI entry point: load --ckpt, linear-probe AF on the PTB-XL cache, write JSON to --out.

    The optional --seed / --train_frac / --batch_size flags default to the
    values this script always used (0 / 0.8 / 64), so results stay comparable
    across evaluations (e.g. epoch 25 vs epoch 100).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--model", required=True, choices=["A", "B", "C", "F"])
    ap.add_argument("--ptbxl_npz", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--out", required=True)
    ap.add_argument("--seed", type=int, default=0, help="seed for the train/test permutation")
    ap.add_argument("--train_frac", type=float, default=0.8, help="fraction of samples used to fit the probe")
    ap.add_argument("--batch_size", type=int, default=64, help="batch size for feature extraction")
    args = ap.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = _build_model(args.ckpt, args.model, device)
    enc = get_ecg_encoder(args.model, model)

    print(f"[eval] loading PTB-XL cache from {args.ptbxl_npz}")
    X, y = _load_ptbxl(args.ptbxl_npz)
    n_pos = int(y.sum())
    n_neg = int((1 - y).sum())  # assumes binary 0/1 labels
    print(f"[eval] X={X.shape} y_pos={n_pos} y_neg={n_neg}")

    # NOTE(review): X is passed with whatever dtype the cache stored; confirm
    # pooled_features casts to the model's dtype if the cache is not float32.
    feats = pooled_features(enc, torch.from_numpy(X), device=device, batch_size=args.batch_size)
    train_idx, test_idx = _split_indices(len(y), args.train_frac, args.seed)
    # float(): json cannot encode numpy float32 scalars; a plain float always serializes.
    auroc = float(linear_probe_auroc(feats[train_idx], y[train_idx], feats[test_idx], y[test_idx]))
    print(f"[eval] AF AUROC = {auroc:.4f}")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps({
        "ckpt": args.ckpt, "model": args.model, "auroc": auroc,
        "n_train": len(train_idx), "n_test": len(test_idx),
        "n_pos": n_pos, "n_neg": n_neg,
    }, indent=2))
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|