#!/usr/bin/env python3
"""Run linear probes on all trained checkpoints and write results to JSON."""
import argparse
import json
import sys
from pathlib import Path
import torch
from pawn.config import CLMConfig
from pawn.model import PAWNCLM
from pawn.eval_suite.probes import extract_probe_data, train_all_probes
def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
from pawn.checkpoint import load_backbone_weights
state_dict, model_config = load_backbone_weights(checkpoint_path, device)
if model_config:
cfg = CLMConfig(**model_config)
else:
# Fallback: infer from state dict shapes
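        # nn.Embedding weights have shape (vocab_size, d_model); layer params are
        # keyed "layers.<idx>.<...>", so max index + 1 gives the layer count.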
d_model = state_dict["embed.src_embed.weight"].shape[1]
n_layers = max(int(k.split(".")[1]) for k in state_dict if k.startswith("layers.")) + 1
if d_model == 256 and n_layers == 8:
cfg = CLMConfig.small()
elif d_model == 512 and n_layers == 8:
cfg = CLMConfig.base()
elif d_model == 640 and n_layers == 10:
cfg = CLMConfig.large()
else:
cfg = CLMConfig(d_model=d_model, n_layers=n_layers)
model = PAWNCLM(cfg).to(device)
model.load_state_dict(state_dict)
model.eval()
return model
def main():
parser = argparse.ArgumentParser(description="Run linear probes on checkpoints")
parser.add_argument("--log-dir", type=str, default="logs", help="Log directory containing run dirs")
parser.add_argument("--n-games", type=int, default=4096, help="Games for probe train set")
parser.add_argument("--n-val-games", type=int, default=1024, help="Games for probe val set")
parser.add_argument("--n-epochs", type=int, default=20, help="Probe training epochs")
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--run", type=str, default=None, help="Only evaluate this run dir name")
args = parser.parse_args()
device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
from pawn.gpu import configure_gpu
gpu_cfg = configure_gpu()
import pawn.model as model_module
model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")
log_dir = Path(args.log_dir)
# Find all runs with checkpoints
runs = []
for config_path in sorted(log_dir.glob("run_*/config.json")):
run_dir = config_path.parent
if args.run and run_dir.name != args.run:
continue
# Find checkpoints: directory-based (safetensors) or legacy .pt
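        # Lexicographic sort assumes zero-padded step names (e.g. step_000100),
        # so the last entry is the newest checkpoint.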
checkpoints = sorted(
[d for d in run_dir.glob("checkpoints/step_*") if d.is_dir()]
or list(run_dir.glob("checkpoints/step_*.pt"))
)
if not checkpoints:
continue
latest = checkpoints[-1]
with open(config_path) as f:
cfg = json.load(f)
runs.append((run_dir, latest, cfg))
if not runs:
print("No runs with checkpoints found.")
sys.exit(1)
print(f"Found {len(runs)} runs to evaluate")
# Generate probe data once (shared across all models with same max_ply)
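    # 256 plies (128 full moves) comfortably exceeds typical game lengths.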
max_ply = 256
print(f"\nGenerating probe data: {args.n_games} train + {args.n_val_games} val games...")
train_data = extract_probe_data(args.n_games, max_ply, seed=12345)
val_data = extract_probe_data(args.n_val_games, max_ply, seed=54321)
print("Done.")
for run_dir, ckpt_path, run_cfg in runs:
model_cfg = run_cfg.get("model", {})
train_cfg = run_cfg.get("training", {})
variant = f"{model_cfg.get('d_model', '?')}d/{model_cfg.get('n_layers', '?')}L"
discard = train_cfg.get("discard_ply_limit", False)
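        # Path.stem drops ".pt" for legacy files; checkpoint dirs have no suffix,
        # so both forms reduce to the numeric step string (e.g. "000100").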
step = ckpt_path.stem.replace("step_", "")
print(f"\n{'='*60}")
print(f"Run: {run_dir.name} ({variant}, discard_ply={discard}, step={step})")
print(f"Checkpoint: {ckpt_path}")
print(f"{'='*60}")
model = load_model_from_checkpoint(str(ckpt_path), device)
results = train_all_probes(
model, train_data, val_data, device,
per_layer=True, n_epochs=args.n_epochs, verbose=True,
)
# Save results
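        # Schematic shape of probe_results.json (values illustrative):
        #   {"run": ..., "step": int, "variant": "512d/8L",
        #    "probes": {probe_name: {layer_name: {metric: float, ...}}}}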
output = {
"run": run_dir.name,
"checkpoint": str(ckpt_path),
"step": int(step),
"variant": variant,
"discard_ply_limit": discard,
"model_config": model_cfg,
"probes": {
pname: {
lname: {k: round(v, 6) if isinstance(v, float) else v for k, v in metrics.items()}
for lname, metrics in layer_results.items()
}
for pname, layer_results in results.items()
},
}
out_path = run_dir / "probe_results.json"
with open(out_path, "w") as f:
json.dump(output, f, indent=2)
print(f"\nSaved: {out_path}")
del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
if __name__ == "__main__":
main()