#!/usr/bin/env python3
"""PAWN training entry point."""
import argparse
import sys
import torch
import torch.multiprocessing as mp
from pawn.config import CLMConfig, TrainingConfig
from pawn.trainer import CLMTrainer
def parse_args():
    """Parse the training CLI.

    Returns an ``argparse.Namespace``.  Flags that default to ``None`` mean
    "keep the value from the chosen config"; exactly one of ``--hf-repo`` /
    ``--local-checkpoints`` must be supplied.
    """
    p = argparse.ArgumentParser(description="Train PAWN model")
    p.add_argument("--variant", type=str, default="base",
                   choices=["small", "base", "large", "toy"],
                   help="Model variant: small (~10M), base (~36M), large (~68M), toy (testing)")
    p.add_argument("--toy", action="store_true", help="Alias for --variant toy")
    p.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from")
    p.add_argument("--device", type=str, default=None, help="Device (cuda/cpu)")
    # Training-loop overrides: all optional ints, None = keep config default.
    for flag, text in [("--total-steps", "Override total training steps"),
                       ("--batch-size", "Override batch size"),
                       ("--num-workers", "Override num data workers"),
                       ("--accumulation-steps", "Gradient accumulation steps")]:
        p.add_argument(flag, type=int, default=None, help=text)
    p.add_argument("--resume-logs", action="append", default=[], metavar="RUN_DIR",
                   help="Run directory whose metrics.jsonl should be spliced into the new log")
    p.add_argument("--wandb", action="store_true", help="Enable W&B logging")
    p.add_argument("--checkpoint-dir", type=str, default=None, help="Override checkpoint dir")
    p.add_argument("--log-dir", type=str, default=None, help="Override log directory")
    p.add_argument("--discard-ply-limit", action="store_true",
                   help="Only train on games that ended naturally (no ply limit truncation)")
    # Architecture overrides (for sweeps — override the named variant)
    for flag, typ, text in [("--d-model", int, "Override d_model"),
                            ("--n-layers", int, "Override n_layers"),
                            ("--n-heads", int, "Override n_heads"),
                            ("--d-ff", int, "Override d_ff"),
                            ("--lr", float, "Override learning rate"),
                            ("--weight-decay", float, "Override weight decay"),
                            ("--warmup-steps", int, "Override warmup steps")]:
        p.add_argument(flag, type=typ, default=None, help=text)
    # Checkpoint destination: exactly one of the two must be chosen.
    dest = p.add_mutually_exclusive_group(required=True)
    dest.add_argument("--hf-repo", type=str, default=None,
                      help="Push checkpoints to this HuggingFace repo (requires HF_TOKEN)")
    dest.add_argument("--local-checkpoints", action="store_true",
                      help="Save checkpoints locally only (no HuggingFace push)")
    return p.parse_args()
def main():
    """Build model/training configs from CLI flags and launch training."""
    args = parse_args()

    # --toy is a convenience alias that wins over --variant.
    variant = "toy" if args.toy else args.variant
    variant_factory = {"small": CLMConfig.small, "base": CLMConfig.base,
                       "large": CLMConfig.large, "toy": CLMConfig.toy}
    model_cfg = variant_factory[variant]()
    train_cfg = TrainingConfig.toy() if variant == "toy" else TrainingConfig()

    # Device selection: an explicit --device always wins; otherwise require
    # CUDA for full-size variants and fall back to CPU only in toy mode.
    if args.device:
        train_cfg.device = args.device
    elif not torch.cuda.is_available():
        if variant != "toy":
            print("ERROR: CUDA is required for full model training.")
            print("Use --toy for CPU-based testing, or --device cpu to force CPU.")
            sys.exit(1)
        train_cfg.device = "cpu"
        print("CUDA not available, falling back to CPU (toy mode)")

    # Scalar training-config overrides; None means "keep the config default".
    # These attributes are independent, so application order does not matter.
    for attr in ("total_steps", "batch_size", "num_workers",
                 "accumulation_steps", "lr", "weight_decay", "warmup_steps"):
        value = getattr(args, attr)
        if value is not None:
            setattr(train_cfg, attr, value)

    # Flag-style overrides (only ever turn a setting on / replace a path).
    if args.wandb:
        train_cfg.use_wandb = True
    if args.checkpoint_dir:
        train_cfg.checkpoint_dir = args.checkpoint_dir
    if args.log_dir:
        train_cfg.log_dir = args.log_dir
    if args.discard_ply_limit:
        train_cfg.discard_ply_limit = True

    # Architecture overrides (for sweeps — override the named variant).
    for attr in ("d_model", "n_layers", "n_heads", "d_ff"):
        value = getattr(args, attr)
        if value is not None:
            setattr(model_cfg, attr, value)

    print(f"Model config: {model_cfg}")
    print(f"Training config: {train_cfg}")

    trainer = CLMTrainer(train_cfg, model_cfg, hf_repo=args.hf_repo)
    if args.resume:
        trainer.load_checkpoint(args.resume)
    if args.resume_logs:
        # Splice prior runs' metrics.jsonl up to the (possibly resumed) step.
        trainer.seed_logs(args.resume_logs, trainer.global_step)
    trainer.train()
def _configure_multiprocessing():
    """Choose a multiprocessing start method before any workers spawn.

    Tries ``forkserver`` first; on platforms where it is unsupported,
    ``set_start_method`` raises ValueError and we fall back to ``spawn``.
    """
    try:
        mp.set_start_method("forkserver", force=True)
    except ValueError:
        mp.set_start_method("spawn", force=True)


if __name__ == "__main__":
    _configure_multiprocessing()
    main()