File size: 5,614 Bytes
db82745
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Pre-train Bee from scratch on a text corpus (e.g. TinyStories, OpenWebText)."""

import argparse
import logging
import os
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
)

# Ensure bee is discoverable
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.config import BeeConfig
from bee.modeling_bee import BeeForCausalLM

register()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("bee.pretrain")


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Pre-train Bee from scratch")
    parser.add_argument("--dataset", type=str, default="roneneldan/TinyStories", help="HF dataset name")
    parser.add_argument("--dataset_text_field", type=str, default="text", help="Text column name")
    parser.add_argument("--output_dir", type=str, required=True, help="Where to save checkpoints")
    parser.add_argument("--tokenizer_name", type=str, default="HuggingFaceTB/SmolLM2-135M", help="Tokenizer to use")
    parser.add_argument("--vocab_size", type=int, default=49152)
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--intermediate_size", type=int, default=1536)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--warmup_steps", type=int, default=1000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--eval_steps", type=int, default=2000)
    parser.add_argument("--logging_steps", type=int, default=100)
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    return parser.parse_args()


def main():
    args = get_args()
    set_seed(args.seed)

    config = BeeConfig(
        vocab_size=args.vocab_size,
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        intermediate_size=args.intermediate_size,
        max_position_embeddings=args.max_seq_length,
        tie_word_embeddings=False,
    )

    logger.info("Initializing model with config: %s", config.to_dict())
    model = BeeForCausalLM(config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("Model parameters: %.2fM", n_params / 1e6)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading dataset: %s", args.dataset)
    ds = load_dataset(args.dataset, split="train", streaming=True)
    eval_ds = load_dataset(args.dataset, split="validation", streaming=True) if "validation" in load_dataset(args.dataset).keys() else None

    def tokenize_function(examples):
        return tokenizer(examples[args.dataset_text_field], truncation=True, max_length=args.max_seq_length)

    ds = ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])
    if eval_ds is not None:
        eval_ds = eval_ds.map(tokenize_function, batched=True, remove_columns=[args.dataset_text_field])

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        evaluation_strategy="steps" if eval_ds is not None else "no",
        save_strategy="steps",
        bf16=args.bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=args.fp16,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds,
        eval_dataset=eval_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    logger.info("Starting training...")
    trainer.train()
    logger.info("Training complete. Saving final model to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()