# bee/scripts/train_dpo.py
"""Direct Preference Optimization (DPO) for Bee using TRL."""
import argparse
import logging
import sys
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer, TrainingArguments, set_seed
from trl import DPOTrainer, DPOConfig
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM
register()
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.dpo")
def get_args():
    """Parse command-line arguments for DPO training.

    Returns:
        argparse.Namespace with model/data paths and DPO hyperparameters.
    """
    parser = argparse.ArgumentParser(description="DPO train Bee")
    parser.add_argument("--model_path", type=str, required=True, help="SFT checkpoint to align")
    parser.add_argument("--dataset", type=str, default="trl-lib/ultrafeedback_binarized", help="HF preference dataset")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=5e-7)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1, help="DPO KL-penalty coefficient")
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)
    # BUGFIX: `action="store_true"` with `default=True` made the flag inert --
    # bf16 was always on and could never be disabled. BooleanOptionalAction
    # keeps `--bf16` working (backward compatible) and adds `--no-bf16`.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()
def main():
    """Entry point: DPO-align a Bee SFT checkpoint on a preference dataset."""
    args = get_args()
    set_seed(args.seed)

    logger.info("Loading model from %s", args.model_path)
    policy = BeeForCausalLM.from_pretrained(args.model_path)
    # Frozen reference policy for the DPO KL term; starts from the same weights.
    reference = BeeForCausalLM.from_pretrained(args.model_path)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the checkpoint defines none.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading preference dataset: %s", args.dataset)
    train_ds = load_dataset(args.dataset, split="train")

    dpo_cfg = DPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        beta=args.beta,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_strategy="steps",
        bf16=args.bf16,
        max_length=args.max_length,
        report_to=["tensorboard"],
    )
    # NOTE(review): newer TRL versions rename `tokenizer=` to
    # `processing_class=` -- confirm against the pinned trl version.
    trainer = DPOTrainer(
        model=policy,
        ref_model=reference,
        args=dpo_cfg,
        train_dataset=train_ds,
        tokenizer=tokenizer,
    )

    logger.info("Starting DPO training...")
    trainer.train()
    logger.info("DPO complete. Saving to %s", args.output_dir)
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()