# bee/scripts/train_pretrain.py
"""Pre-train Bee from scratch on a text corpus (e.g. TinyStories, OpenWebText)."""
import argparse
import logging
import os
import sys
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)

# Ensure bee is discoverable
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM
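# register() makes the custom Bee config/model known to transformers before they
# are instantiated below (see bee/register.py for the exact registration).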
register()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("bee.pretrain")


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Pre-train Bee from scratch")
    parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="HF dataset name")
    parser.add_argument("--dataset_text_field", type=str, default="text", help="Text column name")
    parser.add_argument("--output_dir", type=str, required=True, help="Where to save checkpoints")
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M", help="Tokenizer to use")
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--intermediate_size", type=int, default=1536)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    # The training data is streamed (no dataset length), so the Trainer needs an
    # explicit step budget; the default is a placeholder, tune it to your corpus.
    parser.add_argument("--max_steps", type=int, default=100000, help="Total optimizer steps (required because the dataset is streamed)")
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    return parser.parse_args()


def main():
    args = get_args()
    set_seed(args.seed)

    config = BeeConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        tie_word_embeddings=False,
    )
    logger.info("Initializing model with config: %s", config.to_dict())
    model = BeeForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fM", n_params / 1e6)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading dataset: %s", args.dataset)
    ds = load_dataset(args.dataset, split="train", streaming=True)
    # Probe for a validation split without materializing the full dataset;
    # load_dataset raises ValueError when the split does not exist.
    try:
        eval_ds = load_dataset(args.dataset, split="validation", streaming=True)
    except ValueError:
        eval_ds = None

    def tokenize_function(examples):
        return tokenizer(examples[args.dataset_text_field], truncation=True, max_length=args.max_seq_length)

    ds = ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])
    if eval_ds is not None:
        eval_ds = eval_ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])

    # Causal-LM collator: mlm=False makes the labels a copy of the input ids.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        # Streamed datasets have no __len__, so the Trainer needs max_steps.
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        evaluation_strategy="steps" if eval_ds is not None else "no",
        save_strategy="steps",
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()