"""Direct Preference Optimization (DPO) for Bee using TRL.""" import argparse import logging import sys from pathlib import Path from datasets import load_dataset from transformers import AutoTokenizer, TrainingArguments, set_seed from trl import DPOTrainer, DPOConfig sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) from bee.register import register from bee.modeling_bee import BeeForCausalLM register() logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s") logger = logging.getLogger("bee.dpo") def get_args(): parser = argparse.ArgumentParser(description="DPO train Bee") parser.add_argument("--model_path", type=str, required=True, help="SFT checkpoint to align") parser.add_argument("--dataset", type=str, default="trl-lib/ultrafeedback_binarized", help="HF preference dataset") parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--max_length", type=int, default=2048) parser.add_argument("--batch_size", type=int, default=2) parser.add_argument("--gradient_accumulation_steps", type=int, default=8) parser.add_argument("--learning_rate", type=float, default=5e-7) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument("--beta", type=float, default=0.1) parser.add_argument("--save_steps", type=int, default=500) parser.add_argument("--logging_steps", type=int, default=50) parser.add_argument("--bf16", action="store_true", default=True) parser.add_argument("--seed", type=int, default=42) return parser.parse_args() def main(): args = get_args() set_seed(args.seed) logger.info("Loading model from %s", args.model_path) model = BeeForCausalLM.from_pretrained(args.model_path) ref_model = BeeForCausalLM.from_pretrained(args.model_path) tokenizer = AutoTokenizer.from_pretrained(args.model_path) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token logger.info("Loading preference dataset: %s", args.dataset) ds = load_dataset(args.dataset, split="train") training_args = DPOConfig( output_dir=args.output_dir, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, learning_rate=args.learning_rate, beta=args.beta, logging_steps=args.logging_steps, save_steps=args.save_steps, save_strategy="steps", bf16=args.bf16, max_length=args.max_length, report_to=["tensorboard"], ) trainer = DPOTrainer( model=model, ref_model=ref_model, args=training_args, train_dataset=ds, tokenizer=tokenizer, ) logger.info("Starting DPO training...") trainer.train() logger.info("DPO complete. Saving to %s", args.output_dir) trainer.save_model(args.output_dir) tokenizer.save_pretrained(args.output_dir) if __name__ == "__main__": main()