| """Pre-train Bee from scratch on a text corpus (e.g. TinyStories, OpenWebText).""" |
import argparse
import logging
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)

# Make the repository root importable so the local `bee` package resolves when
# this script is run directly from its own subdirectory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

# Register the custom Bee config/model with the transformers Auto* machinery.
register()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("bee.pretrain")


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Pre-train Bee from scratch")
    parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="HF dataset name")
    parser.add_argument("--dataset_text_field", type=str, default="text", help="Text column name")
    parser.add_argument("--output_dir", type=str, required=True, help="Where to save checkpoints")
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M", help="Tokenizer to use")
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--intermediate_size", type=int, default=1536)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--logging_steps", type=int, default=100)
    # BooleanOptionalAction lets these default to True while remaining switchable
    # on the command line (--no-bf16 / --no-gradient_checkpointing); with
    # action="store_true" a default of True could never be turned off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    return parser.parse_args()


def main():
    args = get_args()
    set_seed(args.seed)

    config = BeeConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        tie_word_embeddings=False,
    )

    logger.info("Initializing model with config: %s", config.to_dict())
    model = BeeForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fM", n_params / 1e6)

    # Note: args.vocab_size should match the tokenizer's vocabulary size
    # (49152 for the SmolLM2 tokenizer).
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Causal-LM tokenizers often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

| logger.info("Loading dataset: %s", args.dataset) |
| ds = load_dataset(args.dataset, split="train", streaming=True) |
| eval_ds = load_dataset(args.dataset, split="validation", streaming=True) if "validation" in load_dataset(args.dataset).keys() else None |
|
|
| def tokenize_function(examples): |
| return tokenizer(examples[args.dataset_text_field], truncation=True, max_length=args.max_seq_length) |
|
|
| ds = ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field]) |
| if eval_ds is not None: |
| eval_ds = eval_ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field]) |
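
    # Optional: pre-training pipelines often pack the tokenized corpus into
    # fixed-length blocks instead of truncating each document, so no tokens are
    # wasted on padding. The helper below is a minimal sketch of that approach
    # (an assumption about desired behaviour; it is not wired into the pipeline).
    # To use it, tokenize without truncation and then map group_texts with batched=True.
    def group_texts(examples):
        # Concatenate every field (input_ids, attention_mask, ...) across the batch.
        concatenated = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated["input_ids"])
        # Drop the ragged tail so every block is exactly max_seq_length tokens.
        total_length = (total_length // args.max_seq_length) * args.max_seq_length
        return {
            k: [t[i : i + args.max_seq_length] for i in range(0, total_length, args.max_seq_length)]
            for k, t in concatenated.items()
        }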

    # mlm=False yields standard causal-LM batches: labels are the input ids
    # (shifted inside the model), with padding positions masked out of the loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        evaluation_strategy="steps" if eval_ds is not None else "no",
        save_strategy="steps",
        # Only enable bf16 when the hardware actually supports it.
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )
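    # With the defaults (per-device batch 8, gradient accumulation 4), each
    # optimizer step sees 32 sequences per device, times the number of devices
    # when training on multiple GPUs.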

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

| logger.info("Starting training...") |
| trainer.train() |
| logger.info("Training complete. Saving final model to %s", args.output_dir) |
| trainer.save_model(args.output_dir) |
| tokenizer.save_pretrained(args.output_dir) |
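    # The saved checkpoint can later be reloaded through the Auto classes, e.g.
    # AutoModelForCausalLM.from_pretrained(args.output_dir, trust_remote_code=True),
    # assuming register() / the config's auto_map exposes the Bee classes.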


if __name__ == "__main__":
    main()