# bee/scripts/train_sft.py
"""Supervised Fine-Tuning (SFT) for Bee using TRL + Accelerate."""

import argparse
import logging
import sys
from pathlib import Path

from datasets import load_dataset
from transformers import AutoTokenizer, set_seed
from trl import SFTConfig, SFTTrainer

# Make the local bee/ package importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from bee.register import register
from bee.modeling_bee import BeeForCausalLM

# Register the custom Bee architecture with transformers so that
# from_pretrained() can resolve Bee checkpoints.
register()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.sft")


def get_args():
    parser = argparse.ArgumentParser(description="Supervised fine-tuning for Bee")
    parser.add_argument("--model_path", type=str, required=True, help="Path to pretrained Bee checkpoint")
    parser.add_argument("--dataset", type=str, default="tatsu-lab/alpaca", help="HF dataset for SFT")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)
    # BooleanOptionalAction adds a --no-bf16 switch; a plain store_true flag
    # with default=True could never be turned off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    parser.add_argument("--hub_model_id", type=str, default=None)
    return parser.parse_args()


def formatting_alpaca(examples):
    """Render batched Alpaca records into single prompt/response strings."""
    texts = []
    # Default to empty inputs when the dataset has no "input" column;
    # zip() against an empty default would silently drop every row.
    inputs = examples.get("input") or [""] * len(examples["instruction"])
    for instruction, input_text, output in zip(examples["instruction"], inputs, examples["output"]):
        if input_text:
            text = f"### Instruction:\n{instruction}\n### Input:\n{input_text}\n### Response:\n{output}"
        else:
            text = f"### Instruction:\n{instruction}\n### Response:\n{output}"
        texts.append(text)
    return {"text": texts}


def main():
    args = get_args()
    set_seed(args.seed)

    logger.info("Loading model from %s", args.model_path)
    model = BeeForCausalLM.from_pretrained(args.model_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading SFT dataset: %s", args.dataset)
    ds = load_dataset(args.dataset, split="train")
    if "alpaca" in args.dataset.lower():
        ds = ds.map(formatting_alpaca, batched=True)
    # Non-Alpaca datasets must already provide the "text" column that
    # dataset_text_field below refers to.

    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        bf16=args.bf16,
        max_seq_length=args.max_seq_length,  # newer TRL releases rename this to max_length
        dataset_text_field="text",
        report_to=["tensorboard"],
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,  # newer TRL releases take processing_class=tokenizer instead
        train_dataset=ds,
        args=training_args,
    )

    logger.info("Starting SFT training...")
    trainer.train()

    logger.info("SFT complete. Saving to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()