| |
| """Run linear probes on all trained checkpoints and write results to JSON.""" |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| import torch |
|
|
| from pawn.config import CLMConfig |
| from pawn.model import PAWNCLM |
| from pawn.eval_suite.probes import extract_probe_data, train_all_probes |
|
|
|
|
def load_model_from_checkpoint(checkpoint_path: str, device: str) -> PAWNCLM:
    """Build a PAWNCLM from a checkpoint, move it to ``device`` and put it in eval mode.

    Args:
        checkpoint_path: Path passed to ``pawn.checkpoint.load_backbone_weights``.
        device: Device string used both for loading the weights and for ``.to(device)``.

    Returns:
        The model with the checkpoint weights loaded via a strict
        ``load_state_dict``, on ``device``, in eval mode.

    When the checkpoint carries no (or an empty) model config, the architecture
    is inferred from the weights:

    - ``d_model`` is the second dimension of ``embed.src_embed.weight``.
    - The layer count is one more than the highest ``layers.<i>.`` index.

    Known (d_model, n_layers) pairs select the small/base/large presets. Any other
    pair gets a config built from just those two values.
    """
    from pawn.checkpoint import load_backbone_weights

    weights, saved_config = load_backbone_weights(checkpoint_path, device)

    if saved_config:
        config = CLMConfig(**saved_config)
    else:
        width = weights["embed.src_embed.weight"].shape[1]
        # Keys look like "layers.<i>.<...>"; the highest <i> plus one is the depth.
        depth = 1 + max(
            int(key.split(".")[1]) for key in weights if key.startswith("layers.")
        )
        presets = {
            (256, 8): CLMConfig.small,
            (512, 8): CLMConfig.base,
            (640, 10): CLMConfig.large,
        }
        preset = presets.get((width, depth))
        if preset is not None:
            config = preset()
        else:
            config = CLMConfig(d_model=width, n_layers=depth)

    net = PAWNCLM(config).to(device)
    net.load_state_dict(weights)
    net.eval()
    return net
|
|
|
|
def _checkpoint_step(path: Path) -> "int | None":
    """Return the integer N encoded in a ``step_<N>`` / ``step_<N>.pt`` name, or None if N is not an int."""
    try:
        return int(path.stem.replace("step_", ""))
    except ValueError:
        return None


def _latest_checkpoint(run_dir: Path) -> "Path | None":
    """Return the highest-step checkpoint under ``run_dir/checkpoints``, or None if there is none.

    Directories named ``step_*`` take precedence. ``step_*.pt`` files are only
    considered when no such directory exists.

    Selection is by numeric step, not lexicographic name, so ``step_1000`` wins
    over ``step_500`` even without zero padding. Names whose step is not an
    integer rank below every numeric one, with ties broken by name.
    """
    candidates = [d for d in run_dir.glob("checkpoints/step_*") if d.is_dir()]
    if not candidates:
        candidates = list(run_dir.glob("checkpoints/step_*.pt"))
    if not candidates:
        return None

    def rank(path: Path):
        step = _checkpoint_step(path)
        return (step is not None, step if step is not None else -1, path.name)

    return max(candidates, key=rank)


def _discover_runs(log_dir: Path, only_run=None) -> list:
    """Collect ``(run_dir, latest_checkpoint, run_config)`` for each ``run_*`` dir with a config and a checkpoint.

    If ``only_run`` is truthy, only the run directory with that exact name is considered.
    """
    runs = []
    for config_path in sorted(log_dir.glob("run_*/config.json")):
        run_dir = config_path.parent
        if only_run and run_dir.name != only_run:
            continue
        latest = _latest_checkpoint(run_dir)
        if latest is None:
            continue
        with open(config_path, encoding="utf-8") as f:
            runs.append((run_dir, latest, json.load(f)))
    return runs


def _round_probe_metrics(results: dict) -> dict:
    """Copy the nested ``{probe: {layer: {metric: value}}}`` results, rounding float values to 6 decimals."""
    return {
        probe_name: {
            layer_name: {
                key: round(value, 6) if isinstance(value, float) else value
                for key, value in metrics.items()
            }
            for layer_name, metrics in layer_results.items()
        }
        for probe_name, layer_results in results.items()
    }


def main():
    """CLI entry point: probe the latest checkpoint of each run and write ``probe_results.json``.

    The steps are:

    1. Scan ``<log-dir>/run_*/config.json``.
    2. For each run with a checkpoint, load its highest-step checkpoint.
    3. Train per-layer linear probes on one probe dataset. The dataset uses fixed
       seeds, is generated once, and is shared by all runs.
    4. Write the metrics to ``<run_dir>/probe_results.json``.

    Exits with status 1 if no run has a checkpoint.
    """
    parser = argparse.ArgumentParser(description="Run linear probes on checkpoints")
    parser.add_argument("--log-dir", type=str, default="logs", help="Log directory containing run dirs")
    parser.add_argument("--n-games", type=int, default=4096, help="Games for probe train set")
    parser.add_argument("--n-val-games", type=int, default=1024, help="Games for probe val set")
    parser.add_argument("--n-epochs", type=int, default=20, help="Probe training epochs")
    parser.add_argument("--max-ply", type=int, default=256, help="Max plies per game when extracting probe data")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--run", type=str, default=None, help="Only evaluate this run dir name")
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    # .type normalises indexed devices such as "cuda:1" to "cuda".
    on_cuda = torch.device(device).type == "cuda"
    if on_cuda:
        from pawn.gpu import configure_gpu

        gpu_cfg = configure_gpu()
        import pawn.model as model_module

        # NOTE(review): presumably pawn.model reads SDPA_BACKEND when attention runs — confirm.
        model_module.SDPA_BACKEND = gpu_cfg.get("sdpa_backend")

    runs = _discover_runs(Path(args.log_dir), args.run)
    if not runs:
        print("No runs with checkpoints found.")
        sys.exit(1)

    print(f"Found {len(runs)} runs to evaluate")

    # Generated once with fixed seeds and reused for every run below.
    print(f"\nGenerating probe data: {args.n_games} train + {args.n_val_games} val games...")
    train_data = extract_probe_data(args.n_games, args.max_ply, seed=12345)
    val_data = extract_probe_data(args.n_val_games, args.max_ply, seed=54321)
    print("Done.")

    for run_dir, ckpt_path, run_cfg in runs:
        model_cfg = run_cfg.get("model", {})
        train_cfg = run_cfg.get("training", {})
        variant = f"{model_cfg.get('d_model', '?')}d/{model_cfg.get('n_layers', '?')}L"
        discard = train_cfg.get("discard_ply_limit", False)
        step_label = ckpt_path.stem.replace("step_", "")

        print(f"\n{'='*60}")
        print(f"Run: {run_dir.name} ({variant}, discard_ply={discard}, step={step_label})")
        print(f"Checkpoint: {ckpt_path}")
        print(f"{'='*60}")

        model = load_model_from_checkpoint(str(ckpt_path), device)

        results = train_all_probes(
            model, train_data, val_data, device,
            per_layer=True, n_epochs=args.n_epochs, verbose=True,
        )

        output = {
            "run": run_dir.name,
            "checkpoint": str(ckpt_path),
            # None (instead of a crash) when the checkpoint name has no integer step.
            "step": _checkpoint_step(ckpt_path),
            "variant": variant,
            "discard_ply_limit": discard,
            "model_config": model_cfg,
            "probes": _round_probe_metrics(results),
        }

        out_path = run_dir / "probe_results.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2)
        print(f"\nSaved: {out_path}")

        # Free this run's model before loading the next one.
        del model
        if on_cuda:
            torch.cuda.empty_cache()
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
|
|