"""Pre-train Bee from scratch on a text corpus (e.g. TinyStories, OpenWebText)."""
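
# Example invocation (illustrative path and flags; adjust to your setup):
#   python scripts/pretrain.py --output_dir checkpoints/bee-base \
#       --dataset roneneldan/TinyStories --max_steps 20000 --bf16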
import argparse
import logging
import sys
from pathlib import Path

import torch
from datasets import get_dataset_split_names, load_dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)

# Make the local `bee` package importable when running this script directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

# Register the custom Bee architecture with transformers.
register()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("bee.pretrain")


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Pre-train Bee from scratch")
    parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="HF dataset name")
    parser.add_argument("--dataset_text_field", type=str, default="text", help="Text column name")
    parser.add_argument("--output_dir", type=str, required=True, help="Where to save checkpoints")
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M", help="Tokenizer to use")
    # Model shape.
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--intermediate_size", type=int, default=1536)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    # Optimization.
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--max_steps", type=int, default=20000, help="Total optimizer steps; required because streaming datasets have no length")
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--logging_steps", type=int, default=100)
    # `store_true` with default=True made these flags no-ops; BooleanOptionalAction
    # (Python 3.9+) adds --no-bf16 / --no-gradient_checkpointing to disable them.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    return parser.parse_args()


def main():
    args = get_args()
    set_seed(args.seed)

    config = BeeConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        tie_word_embeddings=False,
    )
    logger.info("Initializing model with config: %s", config.to_dict())

    # Build the model with randomly initialized weights (training from scratch).
    model = BeeForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fM", n_params / 1e6)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading dataset: %s", args.dataset)
    ds = load_dataset(args.dataset, split="train", streaming=True)
    # Probe for a validation split without materializing the dataset (a second
    # non-streaming load_dataset() call here would download everything).
    eval_ds = (
        load_dataset(args.dataset, split="validation", streaming=True)
        if "validation" in get_dataset_split_names(args.dataset)
        else None
    )

    def tokenize_function(examples):
        return tokenizer(examples[args.dataset_text_field], truncation=True, max_length=args.max_seq_length)

    ds = ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])
    if eval_ds is not None:
        eval_ds = eval_ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
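    # Note: with mlm=False the collator copies input_ids into labels (pad
    # positions become -100); the causal one-token shift happens inside the
    # model's loss computation, as is standard for HF causal LM models.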

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        # Streaming (iterable) datasets have no __len__, so the Trainer needs
        # max_steps; when max_steps > 0 it takes precedence over num_train_epochs.
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        evaluation_strategy="steps" if eval_ds is not None else "no",
        save_strategy="steps",
        # Enable bf16 only when the hardware actually supports it.
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )
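    # With the defaults this trains on up to batch_size * grad_accum * seq_len
    # = 8 * 4 * 2048 = 65,536 tokens per optimizer step per device (less when
    # examples are shorter than max_seq_length, since sequences are not packed).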

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()
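
# After training, the checkpoint can be reloaded for a quick smoke test, e.g.
# (a sketch; the checkpoint path is illustrative):
#   from transformers import AutoTokenizer
#   from bee.modeling_bee import BeeForCausalLM
#   model = BeeForCausalLM.from_pretrained("checkpoints/bee-base")
#   tok = AutoTokenizer.from_pretrained("checkpoints/bee-base")
#   ids = tok("Once upon a time", return_tensors="pt").input_ids
#   print(tok.decode(model.generate(ids, max_new_tokens=40)[0]))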