"""Pre-train Bee from scratch on a text corpus (e.g. TinyStories, OpenWebText).""" import argparse import logging import os import sys from pathlib import Path import torch from datasets import load_dataset from transformers import ( AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, set_seed, ) # Ensure bee is discoverable sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from bee.register import register from bee.config import BeeConfig from bee.modeling_bee import BeeForCausalLM register() logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", ) logger = logging.getLogger("bee.pretrain") def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Pre-train Bee from scratch") parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="HF dataset name") parser.add_argument("--dataset_text_field", type=str, default="text", help="Text column name") parser.add_argument("--output_dir", type=str, required=True, help="Where to save checkpoints") parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M", help="Tokenizer to use") parser.add_argument("--vocab_size", type=int, default=49152) parser.add_argument("--hidden_size", type=int, default=768) parser.add_argument("--num_layers", type=int, default=12) parser.add_argument("--num_heads", type=int, default=12) parser.add_argument("--intermediate_size", type=int, default=1536) parser.add_argument("--max_seq_length", type=int, default=2048) parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--gradient_accumulation_steps", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--num_train_epochs", type=int, default=3) parser.add_argument("--warmup_steps", type=int, default=1000) parser.add_argument("--save_steps", type=int, default=2000) parser.add_argument("--eval_steps", type=int, default=2000) parser.add_argument("--logging_steps", type=int, default=100) parser.add_argument("--bf16", action="store_true", default=True) parser.add_argument("--fp16", action="store_true", default=False) parser.add_argument("--gradient_checkpointing", action="store_true", default=True) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--push_to_hub", action="store_true", default=False) parser.add_argument("--hub_model_id", type=str, default=None) return parser.parse_args() def main(): args = get_args() set_seed(args.seed) config = BeeConfig( vocab_size=args.vocab_size, hidden_size=args.hidden_size, num_hidden_layers=args.num_layers, num_attention_heads=args.num_heads, intermediate_size=args.intermediate_size, max_position_embeddings=args.max_seq_length, tie_word_embeddings=False, ) logger.info("Initializing model with config: %s", config.to_dict()) model = BeeForCausalLM(config) n_params = sum(p.numel() for p in model.parameters()) logger.info("Model parameters: %.2fM", n_params / 1e6) tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token logger.info("Loading dataset: %s", args.dataset) ds = load_dataset(args.dataset, split="train", streaming=True) eval_ds = load_dataset(args.dataset, split="validation", streaming=True) if "validation" in load_dataset(args.dataset).keys() else None def tokenize_function(examples): return tokenizer(examples[args.dataset_text_field], truncation=True, 
    ds = ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])
    if eval_ds is not None:
        eval_ds = eval_ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        # max_steps takes precedence over num_train_epochs; it is mandatory
        # here because a streaming dataset has no length to derive epochs from.
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        evaluation_strategy="steps" if eval_ds is not None else "no",
        save_strategy="steps",
        # Only enable bf16 when the hardware actually supports it.
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()
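# Example invocation (filename, output path, and step budget are illustrative
# placeholders; adjust to your checkout and compute budget):
#   python pretrain.py \
#       --output_dir checkpoints/bee-tinystories \
#       --max_steps 20000 \
#       --bf16 --gradient_checkpointing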