#!/usr/bin/env python3
"""PAWN training entry point."""

import argparse
import sys

import torch
import torch.multiprocessing as mp

from pawn.config import CLMConfig, TrainingConfig
from pawn.trainer import CLMTrainer


def parse_args():
    parser = argparse.ArgumentParser(description="Train PAWN model")
    parser.add_argument("--variant", type=str, default="base",
                        choices=["small", "base", "large", "toy"],
                        help="Model variant: small (~10M), base (~36M), large (~68M), toy (testing)")
    parser.add_argument("--toy", action="store_true", help="Alias for --variant toy")
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from")
    parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
    parser.add_argument("--total-steps", type=int, default=None, help="Override total training steps")
    parser.add_argument("--batch-size", type=int, default=None, help="Override batch size")
    parser.add_argument("--num-workers", type=int, default=None, help="Override num data workers")
    parser.add_argument("--accumulation-steps", type=int, default=None, help="Gradient accumulation steps")
    parser.add_argument("--resume-logs", action="append", default=[], metavar="RUN_DIR",
                        help="Run directory whose metrics.jsonl should be spliced into the new log")
    parser.add_argument("--wandb", action="store_true", help="Enable W&B logging")
    parser.add_argument("--checkpoint-dir", type=str, default=None, help="Override checkpoint dir")
    parser.add_argument("--log-dir", type=str, default=None, help="Override log directory")
    parser.add_argument("--discard-ply-limit", action="store_true",
                        help="Only train on games that ended naturally (no ply limit truncation)")

    # Architecture and optimizer overrides (for sweeps; take precedence over the named variant)
    parser.add_argument("--d-model", type=int, default=None, help="Override d_model")
    parser.add_argument("--n-layers", type=int, default=None, help="Override n_layers")
    parser.add_argument("--n-heads", type=int, default=None, help="Override n_heads")
    parser.add_argument("--d-ff", type=int, default=None, help="Override d_ff")
    parser.add_argument("--lr", type=float, default=None, help="Override learning rate")
    parser.add_argument("--weight-decay", type=float, default=None, help="Override weight decay")
    parser.add_argument("--warmup-steps", type=int, default=None, help="Override warmup steps")

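    # Exactly one checkpoint destination must be chosen: push to HuggingFace or keep local.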
    ckpt_group = parser.add_mutually_exclusive_group(required=True)
    ckpt_group.add_argument("--hf-repo", type=str, default=None,
                            help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
    ckpt_group.add_argument("--local-checkpoints", action="store_true",
                            help="Save checkpoints locally only (no HuggingFace push)")
    return parser.parse_args()


def main():
    args = parse_args()

    variant = "toy" if args.toy else args.variant
    variant_factory = {"small": CLMConfig.small, "base": CLMConfig.base,
                       "large": CLMConfig.large, "toy": CLMConfig.toy}
    model_cfg = variant_factory[variant]()
    train_cfg = TrainingConfig.toy() if variant == "toy" else TrainingConfig()

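    # Device selection: an explicit --device wins; otherwise CUDA is required except in toy mode.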
    if args.device:
        train_cfg.device = args.device
    elif not torch.cuda.is_available():
        if variant == "toy":
            train_cfg.device = "cpu"
            print("CUDA not available, falling back to CPU (toy mode)")
        else:
            print("ERROR: CUDA is required for full model training.")
            print("Use --toy for CPU-based testing, or --device cpu to force CPU.")
            sys.exit(1)

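    # Training-loop overrides: apply only the flags the user set explicitly.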
    if args.total_steps is not None:
        train_cfg.total_steps = args.total_steps
    if args.batch_size is not None:
        train_cfg.batch_size = args.batch_size
    if args.num_workers is not None:
        train_cfg.num_workers = args.num_workers
    if args.accumulation_steps is not None:
        train_cfg.accumulation_steps = args.accumulation_steps
    if args.wandb:
        train_cfg.use_wandb = True
    if args.checkpoint_dir:
        train_cfg.checkpoint_dir = args.checkpoint_dir
    if args.log_dir:
        train_cfg.log_dir = args.log_dir
    if args.discard_ply_limit:
        train_cfg.discard_ply_limit = True

    # Architecture and optimizer hyperparameter overrides
    if args.d_model is not None:
        model_cfg.d_model = args.d_model
    if args.n_layers is not None:
        model_cfg.n_layers = args.n_layers
    if args.n_heads is not None:
        model_cfg.n_heads = args.n_heads
    if args.d_ff is not None:
        model_cfg.d_ff = args.d_ff
    if args.lr is not None:
        train_cfg.lr = args.lr
    if args.weight_decay is not None:
        train_cfg.weight_decay = args.weight_decay
    if args.warmup_steps is not None:
        train_cfg.warmup_steps = args.warmup_steps

    print(f"Model config: {model_cfg}")
    print(f"Training config: {train_cfg}")

    trainer = CLMTrainer(train_cfg, model_cfg, hf_repo=args.hf_repo)

    if args.resume:
        trainer.load_checkpoint(args.resume)

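    # Splice metrics.jsonl from the given prior run directories into the new log,
    # up to the step we resumed at.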
    if args.resume_logs:
        trainer.seed_logs(args.resume_logs, trainer.global_step)

    trainer.train()


if __name__ == "__main__":
    try:
        mp.set_start_method("forkserver", force=True)
    except ValueError:
        mp.set_start_method("spawn", force=True)
    main()