"""Supervised Fine-Tuning (SFT) for Bee using TRL + Accelerate.""" import argparse import logging import sys from pathlib import Path from datasets import load_dataset from transformers import AutoTokenizer, TrainingArguments, set_seed from trl import SFTTrainer, SFTConfig sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from bee.register import register from bee.config import BeeConfig from bee.modeling_bee import BeeForCausalLM register() logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") logger = logging.getLogger("bee.sft") def get_args(): parser = argparse.ArgumentParser(description="SFT train Bee") parser.add_argument("--model_path", type=str, required=True, help="Path to pretrained Bee checkpoint") parser.add_argument("--dataset", type=str, default="tatsu-lab/alpaca", help="HF dataset for SFT") parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--max_seq_length", type=int, default=2048) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--gradient_accumulation_steps", type=int, default=4) parser.add_argument("--learning_rate", type=float, default=2e-5) parser.add_argument("--num_train_epochs", type=int, default=3) parser.add_argument("--warmup_ratio", type=float, default=0.03) parser.add_argument("--save_steps", type=int, default=500) parser.add_argument("--logging_steps", type=int, default=50) parser.add_argument("--bf16", action="store_true", default=True) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--push_to_hub", action="store_true", default=False) parser.add_argument("--hub_model_id", type=str, default=None) return parser.parse_args() def formatting_alpaca(examples): texts = [] for instruction, input_text, output in zip(examples["instruction"], examples.get("input", []), examples["output"]): if input_text: text = f"### Instruction:\n{instruction}\n### Input:\n{input_text}\n### Response:\n{output}" else: text = f"### Instruction:\n{instruction}\n### Response:\n{output}" texts.append(text) return {"text": texts} def main(): args = get_args() set_seed(args.seed) logger.info("Loading model from %s", args.model_path) model = BeeForCausalLM.from_pretrained(args.model_path) tokenizer = AutoTokenizer.from_pretrained(args.model_path) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token logger.info("Loading SFT dataset: %s", args.dataset) ds = load_dataset(args.dataset, split="train") if "alpaca" in args.dataset.lower(): ds = ds.map(formatting_alpaca, batched=True) training_args = SFTConfig( output_dir=args.output_dir, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, warmup_ratio=args.warmup_ratio, logging_steps=args.logging_steps, save_steps=args.save_steps, save_strategy="steps", bf16=args.bf16, max_seq_length=args.max_seq_length, dataset_text_field="text", report_to=["tensorboard"], push_to_hub=args.push_to_hub, hub_model_id=args.hub_model_id, ) trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=ds, args=training_args, ) logger.info("Starting SFT training...") trainer.train() logger.info("SFT complete. Saving to %s", args.output_dir) trainer.save_model(args.output_dir) tokenizer.save_pretrained(args.output_dir) if __name__ == "__main__": main()