| """Supervised Fine-Tuning (SFT) for Bee using TRL + Accelerate.""" |
|
|
| import argparse |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| from datasets import load_dataset |
| from transformers import AutoTokenizer, TrainingArguments, set_seed |
| from trl import SFTTrainer, SFTConfig |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| from bee.register import register |
| from bee.config import BeeConfig |
| from bee.modeling_bee import BeeForCausalLM |
|
|
| register() |
|
|
| logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") |
| logger = logging.getLogger("bee.sft") |
|
|
|
|
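
# Typical launch (script name and paths are illustrative):
#   accelerate launch sft.py --model_path checkpoints/bee-base --output_dir runs/bee-sft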


def get_args():
    parser = argparse.ArgumentParser(description="SFT train Bee")
    parser.add_argument("--model_path", type=str, required=True, help="Path to pretrained Bee checkpoint")
    parser.add_argument("--dataset", type=str, default="tatsu-lab/alpaca", help="HF dataset for SFT")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)
    # BooleanOptionalAction provides both --bf16 and --no-bf16; a store_true flag with
    # default=True could never be disabled from the command line.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    return parser.parse_args()
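
# tatsu-lab/alpaca records carry `instruction`, `input`, and `output` columns; the
# helper below flattens each record into a single Alpaca-style prompt string, e.g.
# "### Instruction:\n...\n### Input:\n...\n### Response:\n...".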
def formatting_alpaca(examples):
    texts = []
    # `input` may be absent in some datasets; fall back to empty strings so zip stays
    # aligned with the other columns instead of silently producing no rows.
    inputs = examples.get("input", [""] * len(examples["instruction"]))
    for instruction, input_text, output in zip(examples["instruction"], inputs, examples["output"]):
        if input_text:
            text = f"### Instruction:\n{instruction}\n### Input:\n{input_text}\n### Response:\n{output}"
        else:
            text = f"### Instruction:\n{instruction}\n### Response:\n{output}"
        texts.append(text)
    return {"text": texts}


def main():
    args = get_args()
    set_seed(args.seed)

    logger.info("Loading model from %s", args.model_path)
    model = BeeForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # The tokenizer may not define a dedicated pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading SFT dataset: %s", args.dataset)
    ds = load_dataset(args.dataset, split="train")
    # Only Alpaca-style datasets are reformatted here; other datasets are expected to
    # already provide a "text" column.
    if "alpaca" in args.dataset.lower():
        ds = ds.map(formatting_alpaca, batched=True)
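
    # SFTConfig extends transformers.TrainingArguments, and SFTTrainer tokenizes the
    # `dataset_text_field` column itself, so no separate data collator is configured here.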
    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        bf16=args.bf16,
        max_seq_length=args.max_seq_length,
        dataset_text_field="text",
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
    )
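
    # Recent TRL releases deprecate `tokenizer=` in favor of `processing_class=`;
    # keep whichever keyword matches the installed TRL version.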
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=training_args,
    )
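
    # Under `accelerate launch`, the Trainer picks up the distributed setup (DDP,
    # mixed precision, etc.) from the environment; no manual wrapping is needed.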
| logger.info("Starting SFT training...") |
| trainer.train() |
| logger.info("SFT complete. Saving to %s", args.output_dir) |
| trainer.save_model(args.output_dir) |
| tokenizer.save_pretrained(args.output_dir) |


if __name__ == "__main__":
    main()