File size: 2,952 Bytes
31e2456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Evaluate a trained checkpoint on PTB-XL AF + downstream probes.

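Loads the model from `--ckpt`, reads a cached PTB-XL AF array from `--ptbxl_npz`
(keys `X` and `y`), extracts pooled latents from the ECG encoder, runs a
logistic-regression linear probe on a fixed-seed 80/20 split, and writes the
resulting AUROC plus split sizes to a JSON file (`--out`).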

Used at epoch 25 (K-gate eval) and epoch 100 (final eval).
"""
from __future__ import annotations

import argparse
import json
import os
from pathlib import Path

import numpy as np
import torch
from dotenv import load_dotenv

# Load variables from a local .env file (python-dotenv; by default it does not
# override variables that are already set) before anything below reads os.environ.
load_dotenv()
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN unless HF_TOKEN is already set; if
# neither is defined, HF_TOKEN is set to "". NOTE(review): nothing in this
# script contacts the HF Hub directly (data comes from the local .npz cache),
# so this is presumably kept for imported project code -- confirm or drop.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))

import sys
# Put `<parent of this script's directory>/src` first on the import path so the
# `physiojepa` imports below resolve from the source tree. This must run before
# those imports.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from physiojepa.models import MODEL_REGISTRY, ModelConfig
from physiojepa.probe import linear_probe_auroc, pooled_features


def get_ecg_encoder(model_letter: str, model: torch.nn.Module) -> torch.nn.Module:
    """Return the ECG encoder submodule for the given model variant.

    Variants ``"A"`` and ``"C"`` expose the encoder directly as ``model.ecg``.
    Every other letter reaches it one level down, via ``model.bb.ecg``
    (``bb`` is presumably a wrapped backbone -- confirm against the model defs).
    """
    direct_variants = ("A", "C")
    if model_letter not in direct_variants:
        return model.bb.ecg
    return model.ecg


def main() -> None:
    """Evaluate one checkpoint with an AF linear probe and write a JSON report.

    Command-line arguments:
        --ckpt       checkpoint read with ``torch.load``; must contain ``"model"``
                     (a state dict) and may contain ``"cfg"`` (a mapping of
                     training-config values used to rebuild ``ModelConfig``).
        --model      variant letter selecting the class in ``MODEL_REGISTRY``.
        --ptbxl_npz  ``.npz`` cache holding arrays ``X`` (encoder inputs) and
                     ``y`` (labels, summed as positives / ``1 - y`` as negatives).
        --out        destination path for the results JSON; parent dirs are created.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True)
    ap.add_argument("--model", required=True, choices=["A", "B", "C", "F"])
    ap.add_argument("--ptbxl_npz", default="/workspace/cache/ptbxl_af.npz")
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): weights_only=False lets the checkpoint unpickle arbitrary
    # objects (and execute code on load). Only pass checkpoints from a trusted source.
    sd = torch.load(args.ckpt, map_location=device, weights_only=False)
    # `or {}` also covers a checkpoint that stored cfg=None.
    saved_cfg = sd.get("cfg") or {}
    # Respect ablation knobs saved in the TrainConfig; fall back to these
    # defaults when a knob is absent.
    cfg = ModelConfig(
        pred_depth=saved_cfg.get("pred_depth", 4),
        query_mode=saved_cfg.get("query_mode", "learned"),
        mask_ratio=saved_cfg.get("mask_ratio", 0.50),
    )
    print(f"[eval] model cfg: pred_depth={cfg.pred_depth} query_mode={cfg.query_mode} mask_ratio={cfg.mask_ratio}")
    model = MODEL_REGISTRY[args.model](cfg)
    model.load_state_dict(sd["model"])
    model.to(device)
    model.eval()  # same as model.train(False)
    enc = get_ecg_encoder(args.model, model)

    print(f"[eval] loading PTB-XL cache from {args.ptbxl_npz}")
    # Close the NpzFile's handle deterministically. Indexing an NpzFile reads
    # the whole array into memory, so X and y remain valid after the block.
    with np.load(args.ptbxl_npz) as arr:
        X, y = arr["X"], arr["y"]
    n_pos = int(y.sum())
    n_neg = int((1 - y).sum())
    print(f"[eval] X={X.shape} y_pos={n_pos} y_neg={n_neg}")
    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        feats = pooled_features(enc, torch.from_numpy(X), device=device, batch_size=64)

    # Fixed seed => the same 80/20 partition for every run on the same cache.
    # NOTE(review): the split is not stratified by label; with a rare positive
    # class the test fold could end up single-class -- consider stratifying.
    rng = np.random.default_rng(0)
    idx = rng.permutation(len(y))
    cut = int(len(idx) * 0.8)
    train_idx, test_idx = idx[:cut], idx[cut:]
    # float() keeps the value JSON-serializable even if the probe returns a
    # NumPy scalar (e.g. float32) or a 0-d tensor.
    auroc = float(
        linear_probe_auroc(feats[train_idx], y[train_idx], feats[test_idx], y[test_idx])
    )
    print(f"[eval] AF AUROC = {auroc:.4f}")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps({
        "ckpt": args.ckpt, "model": args.model, "auroc": auroc,
        "n_train": int(cut), "n_test": int(len(idx) - cut),
        "n_pos": n_pos, "n_neg": n_neg,
    }, indent=2))


if __name__ == "__main__":
    main()