| """Direct Preference Optimization (DPO) for Bee using TRL.""" |
|
|
| import argparse |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| from datasets import load_dataset |
| from transformers import AutoTokenizer, TrainingArguments, set_seed |
| from trl import DPOTrainer, DPOConfig |
|
|
# Make the repo root importable so `bee.*` resolves when this script is run
# directly (i.e. not installed as a package).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from bee.register import register
from bee.modeling_bee import BeeForCausalLM


# Register the Bee architecture with transformers' Auto* machinery before any
# checkpoint is loaded below.
register()


logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger("bee.dpo")
|
|
|
|
def get_args():
    """Parse command-line arguments for DPO training.

    Returns:
        argparse.Namespace with model/data paths and DPO hyperparameters.
    """
    parser = argparse.ArgumentParser(description="DPO train Bee")
    parser.add_argument("--model_path", type=str, required=True, help="SFT checkpoint to align")
    parser.add_argument("--dataset", type=str, default="trl-lib/ultrafeedback_binarized", help="HF preference dataset")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=5e-7)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)
    # BUGFIX: the original used action="store_true" with default=True, which made
    # the flag a no-op (bf16 could never be disabled). BooleanOptionalAction keeps
    # `--bf16` working and adds `--no-bf16` to turn it off.
    parser.add_argument("--bf16", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()
|
|
|
|
def main():
    """Run DPO alignment of a Bee SFT checkpoint and save the result."""
    cli = get_args()
    set_seed(cli.seed)

    # Policy and frozen reference start from the same SFT checkpoint; DPO
    # regularizes the policy toward the reference via the beta-weighted KL term.
    logger.info("Loading model from %s", cli.model_path)
    policy = BeeForCausalLM.from_pretrained(cli.model_path)
    reference = BeeForCausalLM.from_pretrained(cli.model_path)
    tokenizer = AutoTokenizer.from_pretrained(cli.model_path)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token when the checkpoint ships without one.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading preference dataset: %s", cli.dataset)
    preference_data = load_dataset(cli.dataset, split="train")

    dpo_config = DPOConfig(
        output_dir=cli.output_dir,
        num_train_epochs=cli.num_train_epochs,
        per_device_train_batch_size=cli.batch_size,
        gradient_accumulation_steps=cli.gradient_accumulation_steps,
        learning_rate=cli.learning_rate,
        beta=cli.beta,
        logging_steps=cli.logging_steps,
        save_steps=cli.save_steps,
        save_strategy="steps",
        bf16=cli.bf16,
        max_length=cli.max_length,
        report_to=["tensorboard"],
    )

    trainer = DPOTrainer(
        model=policy,
        ref_model=reference,
        args=dpo_config,
        train_dataset=preference_data,
        tokenizer=tokenizer,
    )

    logger.info("Starting DPO training...")
    trainer.train()
    logger.info("DPO complete. Saving to %s", cli.output_dir)
    trainer.save_model(cli.output_dir)
    tokenizer.save_pretrained(cli.output_dir)
|
|
|
|
# Standard script entry guard: only train when executed directly.
if __name__ == "__main__":
    main()
|
|