"""Main training entry point for Vortex models."""
import argparse
import sys
from pathlib import Path

import torch

from configs.vortex_7b_config import VORTEX_7B_CONFIG
from configs.vortex_13b_config import VORTEX_13B_CONFIG
from configs.training_config import (
    TRAINING_CONFIG_7B_CUDA,
    TRAINING_CONFIG_13B_CUDA,
    TRAINING_CONFIG_MPS,
)

from models.vortex_model import VortexModel
from tokenizer.vortex_tokenizer import VortexScienceTokenizer
from training.trainer import VortexTrainer, VortexDataset


def parse_args():
    parser = argparse.ArgumentParser(description="Train Vortex scientific language model")
    parser.add_argument("--model_size", type=str, choices=["7b", "13b"], default="7b",
                        help="Model size to train")
    parser.add_argument("--device", type=str, default="cuda",
                        choices=["cuda", "mps", "cpu"],
                        help="Device to train on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")
    parser.add_argument("--data_dir", type=str, default="./data/processed",
                        help="Directory with processed data shards")
    parser.add_argument("--tokenizer_path", type=str, default=None,
                        help="Path to pretrained tokenizer")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from checkpoint")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Override max training steps")
    parser.add_argument("--micro_batch_size", type=int, default=None,
                        help="Override micro batch size")
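    # None is omitted from choices: leaving the flag off already yields the
    # default of None, and argparse only validates values given on the CLI.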
| parser.add_argument("--quantization", type=str, choices=[None, "int8", "int4"], default=None,
|
| help="Quantization for 13B on 8GB")
|
| return parser.parse_args()
|


def main():
    args = parse_args()

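    # Select the model and training configs for the requested size.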
| if args.model_size == "7b":
|
| model_config = VORTEX_7B_CONFIG.copy()
|
| train_config = TRAINING_CONFIG_7B_CUDA.copy()
|
| else:
|
| model_config = VORTEX_13B_CONFIG.copy()
|
| train_config = TRAINING_CONFIG_13B_CUDA.copy()
|
|
|
|
|
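    # --use_mps implies the MPS device; keep args.device in sync so the
    # torch.device created below matches the training config.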
| if args.use_mps or args.device == "mps":
|
| train_config = TRAINING_CONFIG_MPS.copy()
|
| train_config["use_mps"] = True
|
|
|
|
|
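    # Apply explicit CLI overrides on top of the selected config.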
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size
    if args.quantization:
        train_config["quantization"] = args.quantization
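    # Forward --output_dir into the config (it was parsed but never used);
    # assumes VortexTrainer looks up "output_dir" when writing checkpoints.
    train_config["output_dir"] = args.output_dir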

    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training Vortex-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")

| print("Loading tokenizer...")
|
| tokenizer = VortexScienceTokenizer(
|
| model_config,
|
| tokenizer_path=args.tokenizer_path,
|
| )
|
| print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
|
|
|
|
|
| print("Creating model...")
|
| model = VortexModel(model_config)
|
| print(f"Model parameters: {model.get_num_params():,}")
|
|
|
|
|
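    # Report a rough memory estimate for the chosen batch size and sequence
    # length before any training starts.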
    mem = model.estimate_memory_usage(
        train_config["micro_batch_size"],
        model_config["max_seq_len"],
    )
    print("Memory estimate:")
    for k, v in mem.items():
        print(f"  {k}: {v:.2f} GB")

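    # Discover processed shards; expects Parquet files named train_*.parquet
    # produced by the data pipeline.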
| print("Loading dataset...")
|
| data_dir = Path(args.data_dir)
|
| shard_files = sorted(list(data_dir.glob("train_*.parquet")))
|
| if not shard_files:
|
| print(f"No training shards found in {data_dir}")
|
| print("Please run data pipeline first.")
|
| sys.exit(1)
|
|
|
    train_dataset = VortexDataset(
        shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )
    print(f"Training dataset size: {len(train_dataset)} samples")

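    # NOTE: the eval set is just the first *training* shard, so eval loss is
    # not held-out; switch to dedicated eval shards once the pipeline emits them.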
    eval_shard_files = shard_files[:1]
    eval_dataset = VortexDataset(
        eval_shard_files,
        tokenizer,
        max_seq_len=model_config["max_seq_len"],
    )

    trainer = VortexTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

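    # Optionally restore model/optimizer state from a saved checkpoint.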
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    trainer.train()
| print("Training complete!")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|