diff --git a/ICL/RL/trl_source/.github/workflows/tests-experimental.yml b/ICL/RL/trl_source/.github/workflows/tests-experimental.yml new file mode 100644 index 0000000000000000000000000000000000000000..41f99855da8f415aee65d39ab126ce9f29517d85 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/tests-experimental.yml @@ -0,0 +1,70 @@ +name: Tests (experimental) + +on: + pull_request: + paths: + # Run only when relevant files are modified + - "trl/experimental/**" + - "tests/experimental/**" + +env: + TQDM_DISABLE: 1 + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + TRL_EXPERIMENTAL_SILENCE: 1 + +jobs: + check_code_quality: + name: Check code quality + runs-on: ubuntu-latest + if: github.event.pull_request.draft == false + steps: + - uses: actions/checkout@v6 + - name: Set up Python 3.13 + uses: actions/setup-python@v6 + with: + python-version: 3.13 + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --all-files + + tests: + name: Tests (experimental) + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.13 + uses: actions/setup-python@v6 + with: + python-version: 3.13 + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + + - name: Test with pytest + run: | + source .venv/bin/activate + make test_experimental diff --git a/ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml b/ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml new file mode 100644 index 0000000000000000000000000000000000000000..df5b21ac919bdb785d103ee06683dc7c60bcb153 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml @@ -0,0 +1,121 @@ +name: Tests against Transformers branch + +on: + workflow_dispatch: + inputs: + transformers_ref: + description: "Transformers git ref (branch, tag, or commit SHA)" + required: true + default: "main" + +env: + TQDM_DISABLE: 1 + CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + +jobs: + tests_transformers_branch: + name: Tests with Transformers ${{ inputs.transformers_ref }} + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }} + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Transformers ${{ inputs.transformers_ref }} + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + distributed_smoke: + name: Distributed smoke tests with Transformers ${{ inputs.transformers_ref }} + runs-on: + group: aws-g5-12xlarge-cache + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + env: + CUDA_VISIBLE_DEVICES: "0,1" + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }} + + - name: Run distributed smoke tests + run: | + source .venv/bin/activate + pytest -v tests/distributed/test_distributed.py + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results of distributed smoke tests with Transformers ${{ inputs.transformers_ref }} + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py b/ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py new file mode 100644 index 0000000000000000000000000000000000000000..25bee0fae6a0a2c2d3f65c37965605662290b49a --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py @@ -0,0 +1,108 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl[vllm]", +# ] +# /// + +from dataclasses import dataclass, field + +from datasets import load_dataset +from transformers import HfArgumentParser +from vllm import LLM, SamplingParams + +from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge + + +""" +Examples: + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --num_examples 1000 +Model win rate: 31.40% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000 +Model win rate: 51.60% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 51.20% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --num_examples 1000 +Model win rate: 46.30% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000 +Model win rate: 52.50% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 63.00% +""" + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + model_name_or_path (`str`): + Model name or path to the model to evaluate. + judge_model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`): + Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or + 'meta-llama/Meta-Llama-3-70B-Instruct'. + num_examples (`int`, *optional*): + Number of examples to evaluate. + """ + + model_name_or_path: str = field(metadata={"help": "Model name or path to the model to evaluate."}) + judge_model: str = field( + default="meta-llama/Meta-Llama-3-70B-Instruct", + metadata={ + "help": "Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or " + "'meta-llama/Meta-Llama-3-70B-Instruct'." + }, + ) + num_examples: int | None = field(default=None, metadata={"help": "Number of examples to evaluate."}) + + +if __name__ == "__main__": + # Parse the arguments + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + # Load the dataset + dataset = load_dataset("trl-lib/tldr", split="validation") + if script_args.num_examples is not None: + dataset = dataset.select(range(script_args.num_examples)) + + # Extract the prompts and reference completions + prompts = dataset["prompt"] + reference_completions = dataset["completion"] + + # Generate the model completions + sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200) # very generous max token length + llm = LLM(model=script_args.model_name_or_path, tensor_parallel_size=1) + outputs = llm.generate(prompts, sampling_params) + model_completions = [output.outputs[0].text.strip() for output in outputs] + + # Judge the outputs + if "gpt" in script_args.judge_model: + judge = OpenAIPairwiseJudge(script_args.judge_model) + else: + judge = HfPairwiseJudge(script_args.judge_model) + + completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions, strict=True)] + best_idxs = judge.judge(prompts, completions) + model_win_rate = best_idxs.count(1) / len(best_idxs) + print(f"Model win rate: {model_win_rate * 100:.2f}%") diff --git a/ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml b/ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac6ad51adb0fe9b190d88cdd8588d5ededd6540a --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 4 +num_processes: 32 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh b/ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh new file mode 100644 index 0000000000000000000000000000000000000000..c819c0fa45dc68dcc336ea622e4928c995cad8ba --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh @@ -0,0 +1,112 @@ +#!/bin/bash +#SBATCH -A account +#SBATCH -p partition +#SBATCH -N 5 +#SBATCH --gres gpu:8 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --time=4:00:00 +#SBATCH --job-name=trl_nemo_gym +#SBATCH --output=logs/%j/slurm.out +#SBATCH --error=logs/%j/slurm.err + +CONTAINER_IMAGE="nvcr.io/nvidia/pytorch:25.12-py3" +MOUNTS="/path/to/mounts:/path/to/mounts" + +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) + +TRAIN_NODE_0="${NODELIST[0]}" +TRAIN_NODE_1="${NODELIST[1]}" +TRAIN_NODE_2="${NODELIST[2]}" +TRAIN_NODE_3="${NODELIST[3]}" +VLLM_NODE="${NODELIST[4]}" + +echo "Training Nodes: $TRAIN_NODE_0, $TRAIN_NODE_1, $TRAIN_NODE_2, $TRAIN_NODE_3" +echo "vLLM Node: $VLLM_NODE" +echo "Main process IP: $TRAIN_NODE_0" + +LOG_DIR="logs/${SLURM_JOB_ID}" +mkdir -p ${LOG_DIR} + +echo "Starting ng_run and vLLM on ${VLLM_NODE}..." +echo "Logs will be saved to: ${LOG_DIR}" + +# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below! + +srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MOUNTS}" \ + --container-mount-home \ + bash -c " + LOG_DIR=/path/to/logs + mkdir -p \${LOG_DIR} + + # Install uv if not already installed + curl -LsSf https://astral.sh/uv/install.sh | sh + source \$HOME/.local/bin/env + + # Start nemo gym servers + (set -x && \ + export HOME=/path/to/user && \ + export PATH=\$HOME/.local/bin:\$PATH && \ + cd /path/to/user/Gym && \ + uv venv --python 3.12 && \ + source .venv/bin/activate && \ + uv sync && \ + ray stop --force && \ + ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 & + + sleep 10 + + # Start trl vllm server + (set -x && \ + export HOME=/path/to/user && \ + export HF_HOME=/path/to/user/hf_home && \ + cd /path/to/user/trl && \ + rm -rf .venv && uv venv && source .venv/bin/activate && uv sync && uv pip install -e .[vllm] && uv pip install fastapi uvicorn && \ + python -m trl.scripts.vllm_serve \ + --model Qwen/Qwen3-4B-Instruct-2507 \ + --host 0.0.0.0 \ + --tensor-parallel-size 8 \ + --data-parallel-size 1 \ + --max-model-len 16384 \ + --gpu-memory-utilization 0.7 \ + --port 8000) > \${LOG_DIR}/vllm_serve.log 2>&1 & + + wait +" & + +echo "Waiting for nemo gym and vllm to start..." +sleep 120 + +echo "Launching training on 4 nodes..." + +TRAIN_NODES_LIST="${TRAIN_NODE_0},${TRAIN_NODE_1},${TRAIN_NODE_2},${TRAIN_NODE_3}" + +srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \ + --container-image="${CONTAINER_IMAGE}" \ + --container-mounts="${MOUNTS}" \ + --container-mount-home \ + bash -c " + set -x && \ + export HOME=/path/to/user && \ + export HF_HOME=/path/to/user/hf_home && \ + cd /path/to/user/trl && \ + source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \ + cd examples/scripts/nemo_gym && \ + export WANDB_API_KEY= && \ + accelerate launch \ + --config_file deepspeed_zero3.yaml \ + --num_processes 32 \ + --num_machines 4 \ + --machine_rank \$SLURM_PROCID \ + --main_process_ip ${TRAIN_NODE_0} \ + --main_process_port 29500 \ + --rdzv_backend c10d \ + train_multi_environment.py \ + --config config.yaml \ + --vllm_server_host ${VLLM_NODE} \ + --head_server_host ${VLLM_NODE}" & + +wait + diff --git a/ICL/RL/trl_source/examples/scripts/online_dpo.py b/ICL/RL/trl_source/examples/scripts/online_dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..75fd964f22f5216c972793d63c61df418809d40b --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/online_dpo.py @@ -0,0 +1,159 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +Usage: + +python examples/scripts/online_dpo.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-online-dpo \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 16 \ + --warmup_steps 0.1 \ + --missing_eos_penalty 1.0 + +With LoRA: +python examples/scripts/online_dpo.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-6 \ + --output_dir pythia-1b-tldr-online-dpo \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 8 \ + --warmup_steps 0.1 \ + --missing_eos_penalty 1.0 \ + --use_peft +""" + +import os + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig + +from trl import ( + LogCompletionsCallback, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge +from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer + + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge} + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + use_cache=False if training_args.gradient_checkpointing else True, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + if training_args.reward_model_path is not None: + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + num_labels=1, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + reward_tokenizer = AutoTokenizer.from_pretrained( + training_args.reward_model_path, + trust_remote_code=model_args.trust_remote_code, + truncation=True, + truncation_side="left", # since we judge the completion, truncating left is more appropriate + ) + if reward_tokenizer.pad_token_id is None: + reward_tokenizer.pad_token = reward_tokenizer.eos_token + else: + reward_model = None + reward_tokenizer = None + + if training_args.judge is not None: + judge_cls = JUDGES[training_args.judge] + judge = judge_cls() + else: + judge = None + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + padding_side="left", + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + trainer = OnlineDPOTrainer( + model=model, + reward_funcs=reward_model, + judge=judge, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + reward_processing_classes=reward_tokenizer, + peft_config=get_peft_config(model_args), + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py b/ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..1431681423465685aa385952d9512278bddf5bb8 --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py @@ -0,0 +1,506 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl[vllm]", +# "peft", +# "trackio", +# "kernels", +# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env", +# ] +# /// + +""" +Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM for LLMs. + +This script is optimized for text-only Language Models (LLMs). It uses the accessibility +tree text from BrowserGym, making it memory-efficient. + +The environment runs on a Hugging Face Space by default. + +Setup (Option A - Install from HF Space, recommended): + +```sh +uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env +``` + +Setup (Option B - Clone OpenEnv repo, for development): + +```sh +git clone https://github.com/meta-pytorch/OpenEnv.git +cd OpenEnv/envs/browsergym_env +uv pip install -e . +``` + +# Option 1: HF Spaces + Colocated vLLM (1 GPU required) +```sh +python examples/scripts/openenv/browsergym_llm.py --vllm-mode colocate +``` + +# Option 2: HF Spaces + Separate vLLM server (2 GPUs required) + +# Spin up vLLM server (Terminal 1) +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8001 +``` + +# Run training (Terminal 2) +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym_llm.py --vllm-mode server --vllm-server-url http://localhost:8001 +``` +""" + +from __future__ import annotations + +import argparse +from datetime import datetime +from pathlib import Path + +from browsergym_env import BrowserGymAction, BrowserGymEnv +from datasets import Dataset +from transformers import AutoTokenizer + +from trl import GRPOConfig, GRPOTrainer +from trl.experimental.openenv import generate_rollout_completions + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.") + parser.add_argument( + "--model-id", + default="Qwen/Qwen3-0.6B", + help="Model identifier passed to GRPOTrainer for fine-tuning.", + ) + parser.add_argument( + "--space-url", + type=str, + default="https://openenv-browsergym-env.hf.space", + help="URL for the Hugging Face Space running the BrowserGym environment.", + ) + parser.add_argument( + "--benchmark", + default="miniwob", + help="BrowserGym benchmark to use (miniwob, webarena, etc.).", + ) + parser.add_argument( + "--task-name", + default="click-test", + help="Specific task within the benchmark (e.g., click-test, click-button).", + ) + parser.add_argument( + "--dataset-prompt", + default="Complete the web task successfully.", + help="Prompt text used to seed the training dataset.", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=1000, + help="Number of entries to include in the synthetic training dataset.", + ) + parser.add_argument( + "--max-steps", + type=int, + default=10, + help="Maximum number of steps per episode.", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=32, + help="Maximum number of new tokens to request from vLLM for each action.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature used during rollout generation.", + ) + parser.add_argument( + "--top-k", + type=int, + default=50, + help="Top-k sampling parameter forwarded to vLLM.", + ) + parser.add_argument( + "--top-p", + type=float, + default=None, + help="Optional top-p sampling parameter forwarded to vLLM.", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=5e-6, + help="Learning rate for GRPO training.", + ) + parser.add_argument( + "--weight-decay", + type=float, + default=0.0, + help="Weight decay applied during optimization.", + ) + parser.add_argument( + "--gradient-accumulation-steps", + type=int, + default=32, + help="Gradient accumulation steps for GRPO training.", + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=10, + help="Warmup steps for the scheduler.", + ) + parser.add_argument( + "--per-device-batch-size", + type=int, + default=1, + help="Per-device train batch size.", + ) + parser.add_argument( + "--num-generations", + type=int, + default=4, + help="Number of rollout generations per dataset prompt.", + ) + parser.add_argument( + "--num-epochs", + type=int, + default=1, + help="Number of training epochs.", + ) + parser.add_argument( + "--save-interval", + type=int, + default=50, + help="Interval (in steps) between checkpoint saves.", + ) + parser.add_argument( + "--save-total-limit", + type=int, + default=None, + help="Maximum number of checkpoints to keep.", + ) + parser.add_argument( + "--output-dir", + default=None, + help="Directory where training outputs and checkpoints are stored.", + ) + parser.add_argument( + "--run-name", + default=None, + help="Optional run name for logging systems.", + ) + parser.add_argument( + "--project", + default=None, + help="Optional project identifier for logging systems.", + ) + parser.add_argument( + "--vllm-mode", + choices=("colocate", "server"), + default="colocate", + help="vLLM execution mode: 'colocate' or 'server'.", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default="http://localhost:8001", + help="URL for the vLLM server (only used when --vllm-mode=server).", + ) + parser.add_argument( + "--logging-steps", + type=int, + default=1, + help="Frequency of logging steps for GRPO training.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Enable verbose debugging output during rollouts.", + ) + return parser.parse_args() + + +def sanitize_name(name: str) -> str: + return name.replace("/", "-") + + +# --------------------------------------------------------------------------- +# System Prompt +# --------------------------------------------------------------------------- + +SYSTEM_PROMPT = """You control a web browser through BrowserGym actions. +You must complete the given web task by interacting with the page. + +Available actions: +- noop() - Do nothing +- click(bid) - Click element with BrowserGym ID (the number in brackets) +- fill(bid, text) - Fill input field with text +- send_keys(text) - Send keyboard input +- scroll(direction) - Scroll up/down + +The page structure shows elements as: [bid] element_type 'element_text' +For example: [13] button 'Click Me!' means bid='13' + +Reply with exactly ONE action on a single line, e.g.: +click('13') +fill('42', 'hello world') +noop() + +Do not include explanations or multiple actions.""" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = "") -> str: + """Create user prompt from observation.""" + prompt_parts = [f"Step {step_num + 1}"] + + if goal: + prompt_parts.append(f"Goal: {goal}") + + if error: + prompt_parts.append(f"Previous action error: {error}") + + # Include accessibility tree (truncated for context) + if axtree: + max_len = 2000 + axtree_truncated = axtree[:max_len] + "..." if len(axtree) > max_len else axtree + prompt_parts.append(f"Page structure:\n{axtree_truncated}") + + prompt_parts.append("What action do you take?") + + return "\n\n".join(prompt_parts) + + +def parse_action(response_text: str) -> str: + """Parse BrowserGym action from model response.""" + # Extract first line that looks like an action + for line in response_text.strip().split("\n"): + line = line.strip() + if "(" in line and ")" in line: + return line + + # Fallback to noop if no valid action found + return "noop()" + + +def rollout_once( + trainer: GRPOTrainer, + env: BrowserGymEnv, + tokenizer: AutoTokenizer, + dataset_prompt: str, + max_steps: int, + debug: bool = False, +) -> dict[str, list]: + """Run one episode and collect training data (text-only, no screenshots).""" + result = env.reset() + observation = result.observation + + prompt_ids: list[int] = [] + completion_ids: list[int] = [] + logprobs: list[float] = [] + step_rewards: list[float] = [] + completion_rewards: list[float] = [] + + for step_num in range(max_steps): + if result.done: + break + + # Create prompt from observation (text-only using accessibility tree) + goal = observation.goal or dataset_prompt + axtree = observation.axtree_txt or "" + error = observation.error if observation.last_action_error else "" + + user_prompt = make_user_prompt(goal, step_num, axtree, error) + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + prompt_text = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + ) + + # Generate action with vLLM + rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0] + prompt_ids.extend(rollout_outputs["prompt_ids"]) + completion_ids.extend(rollout_outputs["completion_ids"]) + logprobs.extend(rollout_outputs["logprobs"]) + + completion_text = rollout_outputs.get("text") or tokenizer.decode( + rollout_outputs["completion_ids"], skip_special_tokens=True + ) + + # Parse and execute action + action_str = parse_action(completion_text) + + if debug: + print(f"Step {step_num + 1}: {action_str}") + + # Take action in environment + result = env.step(BrowserGymAction(action_str=action_str)) + observation = result.observation + + # Track rewards + step_reward = float(result.reward or 0.0) + step_rewards.append(step_reward) + + # Reward shaping: success is most important + if result.done and step_reward > 0: + completion_rewards.append(1.0) # Task completed successfully + elif result.done and step_reward == 0: + completion_rewards.append(0.0) # Task failed + else: + completion_rewards.append(step_reward) # Intermediate reward + + # Final reward is based on task completion + final_reward = completion_rewards[-1] if completion_rewards else 0.0 + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + "logprobs": logprobs, + "step_rewards": step_rewards, + "completion_reward": final_reward, + } + + +# --------------------------------------------------------------------------- +# Rewards +# --------------------------------------------------------------------------- + + +def reward_completion(completions: list[str], **kwargs) -> list[float]: + """Reward for task completion.""" + rewards = kwargs.get("completion_reward") if kwargs else None + if rewards is None: + return [0.0 for _ in completions] + return [float(r) for r in rewards] + + +# --------------------------------------------------------------------------- +# Main entrypoint +# --------------------------------------------------------------------------- + + +def main() -> None: + args = parse_args() + + # Connect to BrowserGym environment via Hugging Face Space + client = BrowserGymEnv(base_url=args.space_url) + print(f"๐ŸŒ Using Hugging Face Space environment at: {args.space_url}") + + dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size}) + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}" + output_dir = Path(args.output_dir or default_output_dir) + + grpo_config = GRPOConfig( + use_vllm=True, + vllm_mode=args.vllm_mode, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, + vllm_gpu_memory_utilization=0.4, + output_dir=str(output_dir), + num_train_epochs=args.num_epochs, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + gradient_accumulation_steps=args.gradient_accumulation_steps, + per_device_train_batch_size=args.per_device_batch_size, + warmup_steps=args.warmup_steps, + num_generations=args.num_generations, + generation_batch_size=args.num_generations, # Must be divisible by num_generations + max_completion_length=args.max_new_tokens, + logging_steps=args.logging_steps, + report_to="trackio", + trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}", + save_strategy="steps", + save_steps=args.save_interval, + save_total_limit=args.save_total_limit, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + ) + + grpo_config.run_name = args.run_name or f"run-{timestamp}" + grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" + + def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + episode_prompt_ids: list[list[int]] = [] + episode_completion_ids: list[list[int]] = [] + episode_logprobs: list[list[float]] = [] + completion_rewards: list[float] = [] + + if args.debug: + print(f"\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)") + + for i, prompt_text in enumerate(prompts): + if args.debug: + print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}") + episode = rollout_once( + trainer=trainer, + env=client, + tokenizer=trainer.processing_class, + dataset_prompt=prompt_text, + max_steps=args.max_steps, + debug=args.debug, + ) + episode_prompt_ids.append(episode["prompt_ids"]) + episode_completion_ids.append(episode["completion_ids"]) + episode_logprobs.append(episode["logprobs"]) + completion_rewards.append(episode["completion_reward"]) + + return { + "prompt_ids": episode_prompt_ids, + "completion_ids": episode_completion_ids, + "logprobs": episode_logprobs, + "completion_reward": completion_rewards, + } + + trainer = GRPOTrainer( + model=args.model_id, + reward_funcs=[reward_completion], + train_dataset=dataset, + args=grpo_config, + rollout_func=rollout_func, + ) + + print("=" * 80) + print("Starting GRPO training with BrowserGym environment (LLM mode)") + print(f"Benchmark: {args.benchmark}") + print(f"Task: {args.task_name}") + print(f"Model: {args.model_id}") + print("Mode: LLM (text-only, using accessibility tree)") + print(f"Using {args.num_generations} rollouts per dataset prompt") + print(f"Output directory: {output_dir}") + print("=" * 80) + + try: + trainer.train() + print("\nTraining completed successfully!") + finally: + client.close() + + +if __name__ == "__main__": + main() diff --git a/ICL/RL/trl_source/examples/scripts/openenv/echo.py b/ICL/RL/trl_source/examples/scripts/openenv/echo.py new file mode 100644 index 0000000000000000000000000000000000000000..f52a7a1850d8e7750e1f1d46c14793b3c898059b --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/openenv/echo.py @@ -0,0 +1,248 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl[vllm]", +# "peft", +# "trackio", +# "kernels", +# "openenv-echo-env @ git+https://huggingface.co/spaces/openenv/echo_env", +# ] +# /// + + +""" +Simple script to run GRPO training with OpenEnv's Echo environment and vLLM. The reward function encourages +longer completions. + +Setup (Option A - Install from HF Space, recommended): + +```sh +uv pip install git+https://huggingface.co/spaces/openenv/echo_env +``` + +Setup (Option B - Clone OpenEnv repo, for development): + +```sh +git clone https://github.com/meta-pytorch/OpenEnv.git +cd OpenEnv/envs/echo_env +uv pip install -e . +``` + +# Option 1: HF Spaces + Colocated vLLM (1 GPU required) +```sh +python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode colocate +``` + +# Option 2: HF Spaces + Separate vLLM server (2 GPUs required) + +# Spin up vLLM server (Terminal 1) +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 +``` + +# Run training (Terminal 2) +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode server --vllm-server-url http://localhost:8000 +``` + +# Option 3: Local + Colocated vLLM (1 GPU required) + +# Start the environment only if using --env-mode docker-local +```sh +docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest +``` + +```sh +python examples/scripts/openenv/echo.py --env-mode docker-local --vllm-mode colocate +``` +""" + +# ruff: noqa: T201 +import argparse +import os +import subprocess +import sys +import time +from pathlib import Path + +import requests +from datasets import load_dataset +from echo_env import EchoEnv +from echo_env.models import EchoAction + +from trl import GRPOConfig, GRPOTrainer, RichProgressCallback +from trl.experimental.openenv import generate_rollout_completions + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run GRPO training with Echo environment and vLLM.") + + parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the Echo environment.") + parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.") + parser.add_argument( + "--env-mode", + choices=["local", "docker-local", "docker-image", "docker-hub", "space"], + default="docker-image", + help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", + ) + parser.add_argument( + "--model", + type=str, + default="Qwen/Qwen2.5-0.5B-Instruct", + help="Model to use for training.", + ) + parser.add_argument( + "--dataset", + type=str, + default="trl-lib/ultrafeedback-prompt", + help="Dataset to use for training.", + ) + parser.add_argument( + "--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment." + ) + parser.add_argument( + "--vllm-mode", + choices=["colocate", "server"], + default="colocate", + help="vLLM execution mode: 'colocate' or 'server'.", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default="http://localhost:8000", + help="URL for the vLLM server (only used when --vllm-mode=server).", + ) + + return parser.parse_args() + + +def start_env_server(env_host: str, env_port: int): + """Launch the Echo environment server locally.""" + env_url = f"http://{env_host}:{env_port}" + print(f"โšก Starting FastAPI server for Echo Environment on {env_url}...") + + work_dir = str(Path.cwd().parent.absolute()) + process = subprocess.Popen( + [sys.executable, "-m", "uvicorn", "echo_env.server.app:app", "--host", env_host, "--port", str(env_port)], + env={**os.environ, "PYTHONPATH": f"{work_dir}/src"}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=work_dir, + ) + + print("โณ Waiting for server to start...") + time.sleep(5) + + try: + requests.get(f"{env_url}/health", timeout=2) + print("\nโœ… Echo Environment server is running!") + except Exception as e: + print(f"\nโŒ Server failed to start: {e}") + if process.stderr: + print(process.stderr.read()) + raise + + return process + + +def reward_from_env(completions, **kwargs): + """Extract environment rewards for training.""" + env_rewards = kwargs.get("env_reward", []) + return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions) + + +def main(): + args = parse_args() + + # Select environment mode + if args.env_mode == "local": + env_url = f"http://{args.env_host}:{args.env_port}" + server_process = start_env_server(args.env_host, args.env_port) + elif args.env_mode == "docker-local": + env_url = f"http://{args.env_host}:{args.env_port}" + server_process = None + print(f"๐ŸŒ Using existing Echo Environment (Docker) at: {env_url}") + elif args.env_mode == "docker-image": + client = EchoEnv.from_docker_image(args.env_image) + server_process = None + print("๐ŸŒ Using Echo Environment (Docker) from local Image") + elif args.env_mode == "docker-hub": + client = EchoEnv.from_hub(args.env_image) + server_process = None + print("๐ŸŒ Using existing Echo Environment (Docker) from Hub Image") + elif args.env_mode == "space": + env_url = args.env_host + server_process = None + print(f"๐ŸŒ Using Hugging Face Space environment at: {env_url}") + else: + raise ValueError(f"Unknown environment mode: {args.env_mode}") + + if args.env_mode != "docker-hub" and args.env_mode != "docker-image": + client = EchoEnv(base_url=env_url) + dataset = load_dataset(args.dataset, split="train[:1000]") + + training_args = GRPOConfig( + output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout", + use_vllm=True, + vllm_mode=args.vllm_mode, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, + logging_steps=1, + report_to="trackio", + trackio_space_id=f"{args.model.split('/')[-1]}-GRPO-Rollout", + num_train_epochs=1, + max_completion_length=2048, + gradient_accumulation_steps=4, + ) + + def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + outputs = generate_rollout_completions(trainer, prompts) + tokenizer = trainer.processing_class + + completions_text = [tokenizer.decode(output["completion_ids"], skip_special_tokens=True) for output in outputs] + + env_result = client.reset() + env_rewards: list[float] = [] + for message in completions_text: + env_result = client.step(EchoAction(message=message)) + env_rewards.append(env_result.reward) + + return { + "prompt_ids": [output["prompt_ids"] for output in outputs], + "completion_ids": [output["completion_ids"] for output in outputs], + "logprobs": [output["logprobs"] for output in outputs], + "env_reward": env_rewards, + } + + trainer = GRPOTrainer( + model=args.model, + reward_funcs=reward_from_env, + args=training_args, + train_dataset=dataset, + rollout_func=rollout_func, + callbacks=[RichProgressCallback()], + ) + + trainer.train() + time.sleep(5) + + if server_process: + print("๐Ÿ›‘ Terminating Echo Environment server...") + server_process.terminate() + + +if __name__ == "__main__": + main() diff --git a/ICL/RL/trl_source/examples/scripts/openenv/wordle.py b/ICL/RL/trl_source/examples/scripts/openenv/wordle.py new file mode 100644 index 0000000000000000000000000000000000000000..b6fbabc968948f81d4a7e367dac22a1a9b9a555b --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/openenv/wordle.py @@ -0,0 +1,607 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl[vllm]", +# "peft", +# "trackio", +# "kernels", +# "openenv-textarena @ git+https://huggingface.co/spaces/sergiopaniego/wordle", +# ] +# /// + + +""" +Simple script to run GRPO training with OpenEnv's Wordle environment and vLLM. + +Setup (Option A - Install from HF Space, recommended): + +```sh +uv pip install git+https://huggingface.co/spaces/sergiopaniego/wordle +``` + +# Option 1: HF Spaces + Colocated vLLM (1 GPU required) +```sh +python examples/scripts/openenv/wordle.py --vllm-mode colocate +``` + +# Option 2: HF Spaces + Separate vLLM server (2 GPUs required) + +# Spin up vLLM server (Terminal 1) +```sh +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000 +``` + +# Run training (Terminal 2) +```sh +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000 +``` + +# Option 3: Local Environment + Colocated vLLM (1 GPU required) + +To run the Wordle environment locally, you have several options: + +## Option 3a: Using Docker Image (Recommended) + +First, build the Docker image from the textarena_env directory: +```sh +cd 3rd_party/OpenEnv/envs/textarena_env +docker build -t textarena-env:latest -f server/Dockerfile . +``` + +Then run the environment server: +```sh +docker run -d -p 8001:8001 textarena-env:latest +``` + +Finally, run training pointing to local server: +```sh +python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001 +``` + +## Option 3b: Running Server Directly + +From the textarena_env directory: +```sh +cd 3rd_party/OpenEnv/envs/textarena_env +uv venv && source .venv/bin/activate +uv pip install -e . +python -m uvicorn server.app:app --reload --port 8001 +``` + +Then in another terminal, run training: +```sh +python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001 +``` + +## Option 3c: Using Pre-built HF Space Image + +```sh +docker run -d -p 8001:8001 registry.hf.space/burtenshaw-wordle:latest +python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001 +``` +""" + +import argparse +import re +import sys +from collections.abc import Iterable +from datetime import datetime +from pathlib import Path + +from datasets import Dataset +from transformers import AutoTokenizer + +from trl import GRPOConfig, GRPOTrainer +from trl.experimental.openenv import generate_rollout_completions + + +# Ensure src/ is on the path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from textarena_env import TextArenaAction, TextArenaEnv +from textarena_env.models import TextArenaMessage +from textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run GRPO training for Wordle using the TextArena OpenEnv environment." + ) + parser.add_argument( + "--tokenizer-id", + default="Qwen/Qwen3-1.7B", + help="Model identifier used to load the tokenizer.", + ) + parser.add_argument( + "--model-id", + default="Qwen/Qwen3-1.7B", + help="Model identifier passed to GRPOTrainer for fine-tuning.", + ) + parser.add_argument( + "--env-url", type=str, default="https://sergiopaniego-wordle.hf.space", help="URL for the environment server." + ) + parser.add_argument( + "--system-prompt-path", + default="wordle_prompt.txt", + help="Path to the file containing the system prompt.", + ) + parser.add_argument( + "--dataset-prompt", + default="Play Wordle like an expert.", + help="Prompt text used to seed the training dataset.", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=3000, + help="Number of entries to include in the synthetic training dataset.", + ) + parser.add_argument( + "--max-turns", + type=int, + default=6, + help="Maximum number of turns to play in the Wordle environment per episode.", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=8, + help="Maximum number of new tokens to request from vLLM for each guess.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature used during rollout generation.", + ) + parser.add_argument( + "--top-k", + type=int, + default=10, + help="Top-k sampling parameter forwarded to vLLM.", + ) + parser.add_argument( + "--top-p", + type=float, + default=None, + help="Optional top-p sampling parameter forwarded to vLLM.", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=1e-6, + help="Learning rate for GRPO training.", + ) + parser.add_argument( + "--weight-decay", + type=float, + default=0.0, + help="Weight decay applied during optimization.", + ) + parser.add_argument( + "--gradient-accumulation-steps", + type=int, + default=64, + help="Gradient accumulation steps for GRPO training.", + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=10, + help="Warmup steps for the scheduler.", + ) + parser.add_argument( + "--per-device-batch-size", + type=int, + default=1, + help="Per-device train batch size.", + ) + parser.add_argument( + "--num-generations", + type=int, + default=4, + help="Number of rollout generations per dataset prompt.", + ) + parser.add_argument( + "--num-epochs", + type=int, + default=1, + help="Number of training epochs.", + ) + parser.add_argument( + "--save-interval", + type=int, + default=10, + help="Interval (in steps) between checkpoint saves.", + ) + parser.add_argument( + "--save-total-limit", + type=int, + default=None, + help="Maximum number of checkpoints to keep.", + ) + parser.add_argument( + "--output-dir", + default=None, + help="Directory where training outputs and checkpoints are stored.", + ) + parser.add_argument( + "--run-name", + default=None, + help="Optional run name for logging systems.", + ) + parser.add_argument( + "--project", + default=None, + help="Optional project identifier for logging systems.", + ) + parser.add_argument( + "--trackio-space-id", + default="Wordle-GRPO", + help="TrackIO space identifier.", + ) + parser.add_argument( + "--vllm-mode", + choices=("colocate", "server"), + default="colocate", + help="vLLM execution mode: 'colocate' or 'server'.", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default="http://localhost:8000", + help="URL for the vLLM server (only used when --vllm-mode=server).", + ) + parser.add_argument( + "--logging-steps", + type=int, + default=1, + help="Frequency of logging steps for GRPO training.", + ) + return parser.parse_args() + + +def resolve_system_prompt(path: str) -> str: + prompt_path = Path(path) + if not prompt_path.is_file(): + prompt_path = Path(__file__).parent / path + return prompt_path.read_text() + + +def sanitize_name(name: str) -> str: + return name.replace("/", "-") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def format_history(messages: Iterable[TextArenaMessage]) -> str: + lines: list[str] = [] + for message in messages: + tag = message.category or "MESSAGE" + content = message.content.strip() + if not content: + continue + lines.append(f"[{tag}] {content}") + return "\n".join(lines) + + +def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str: + history = format_history(messages) + # Only use messages for conversation history - the prompt is already included as the first message + history_section = history if history else "[PROMPT] Awaiting first feedback." + return f"Conversation so far:\n{history_section}\n\nReply with your next guess enclosed in square brackets." + + +def rollout_once( + trainer: GRPOTrainer, + env: TextArenaEnv, + tokenizer: AutoTokenizer, + dataset_prompt: str, + system_prompt: str, + max_turns: int, + max_new_tokens: int = 16, +) -> dict[str, list]: + result = env.reset() + observation = result.observation + + prompt_ids: list[int] = [] + completion_ids: list[int] = [] + logprobs: list[float] = [] + env_mask: list[int] = [] # 1 for model-generated tokens, 0 for environment tokens + model_outputs: list[str] = [] + raw_rewards: list[float] = [] + position_scores: list[float] = [] + correct_scores: list[float] = [] + prev_env_output_len: int = 0 # Track length to only add NEW portion each turn + + accumulated_messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}] + # Build initial prompt (only once, at the start) + # The initial env messages are included in the prompt, not completion + base_prompt = observation.prompt or dataset_prompt + initial_user_prompt = make_user_prompt(base_prompt, observation.messages) + # Track initial env output length so we don't add it again + initial_env_output = format_history(observation.messages) if observation.messages else "" + prev_env_output_len = len(initial_env_output) + initial_messages = accumulated_messages + [{"role": "user", "content": initial_user_prompt}] + initial_prompt_text = tokenizer.apply_chat_template( + initial_messages, + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, + ) + # Tokenize initial prompt once - this is the base prompt for the entire episode. + # GRPO expects one prompt-completion pair per episode, where: + # - prompt_ids = the initial/base prompt (what the model sees at episode start) + # - completion_ids = all model responses + env feedback from all turns concatenated + # Note: The actual prompts used for generation in each turn are longer (include conversation history), + # but we only count the initial prompt tokens here. + initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False) + prompt_ids.extend(initial_prompt_ids) + + for _turn in range(max_turns): + if result.done: + break + + base_prompt = observation.prompt or dataset_prompt + user_prompt = make_user_prompt(base_prompt, observation.messages) + messages = accumulated_messages + [{"role": "user", "content": user_prompt}] + prompt_text = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, + ) + + rollout_outputs = generate_rollout_completions( + trainer, [prompt_text], generation_overrides={"max_tokens": max_new_tokens} + )[0] + # Add model-generated completion tokens and logprobs with newlines for readability + newline_tokens = tokenizer.encode("\n", add_special_tokens=False) + completion_ids.extend(newline_tokens) # newline before guess + logprobs.extend([0.0] * len(newline_tokens)) + env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format + + completion_ids.extend(rollout_outputs["completion_ids"]) + logprobs.extend(rollout_outputs["logprobs"]) + env_mask.extend([1] * len(rollout_outputs["completion_ids"])) # model-generated tokens + + completion_ids.extend(newline_tokens) # newline after guess + logprobs.extend([0.0] * len(newline_tokens)) + env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format + completion_text = rollout_outputs.get("text") or tokenizer.decode( + rollout_outputs["completion_ids"], skip_special_tokens=True + ) + guess = extract_guess(completion_text) + model_outputs.append(completion_text.strip()) # Store raw model output for format reward + + result = env.step(TextArenaAction(message=guess)) + + raw_rewards.append(float(result.reward or 0.0)) + observation = result.observation + correct_score = float(result.reward or 0.0) + feedback = extract_wordle_feedback(observation) + + full_env_output = format_history(observation.messages) if observation.messages else "" + new_env_output = full_env_output[prev_env_output_len:].lstrip("\n") + prev_env_output_len = len(full_env_output) + + if new_env_output: + env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False) + completion_ids.extend(env_output_tokens) # Add to completion_ids + logprobs.extend([0.0] * len(env_output_tokens)) # Placeholder (ignored via env_mask=0) + env_mask.extend([0] * len(env_output_tokens)) # Environment tokens - mask out from loss + completion_with_env = completion_text + "\n" + new_env_output + else: + completion_with_env = completion_text + + accumulated_messages.append({"role": "user", "content": user_prompt}) + accumulated_messages.append({"role": "assistant", "content": completion_with_env}) + + if not feedback: + position_score = 0.0 + else: + green_count, yellow_count = extract_feedback_counts(feedback) + position_score = (green_count + 0.5 * yellow_count) / 5.0 + + position_scores.append(position_score) + correct_scores.append(correct_score) + + # Use the final correct reward (win/lose is binary at end) + correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0) + + # Position reward as shaping signal: + # - If model WINS: position_reward = 1.0 (no penalty for winning fast) + # - If model LOSES: position_reward = last attempt (where it ended up) + if correct_reward_value >= 1.0: + final_position_reward = 1.0 + else: + final_position_reward = position_scores[-1] if position_scores else 0.0 + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + "logprobs": logprobs, + "env_mask": env_mask, + "raw_rewards": raw_rewards, + "correct_reward": correct_reward_value, + "position_reward": final_position_reward, + "model_outputs": model_outputs, + } + + +# --------------------------------------------------------------------------- +# Rewards +# --------------------------------------------------------------------------- + + +def reward_correct(completions: list[str], **kwargs) -> list[float]: + """Reward from environment (correct answer).""" + rewards = kwargs.get("correct_reward") if kwargs else None + if rewards is None: + return [0.0 for _ in completions] + return [float(r) for r in rewards] + + +def reward_position(completions: list[str], **kwargs) -> list[float]: + """Position reward: green worth 1.0, yellow worth 0.5, normalized by 5.""" + rewards = kwargs.get("position_reward") if kwargs else None + if rewards is None: + return [0.0 for _ in completions] + return [float(r) for r in rewards] + + +def compute_format_reward(model_outputs: list[str]) -> float: + """Compute format reward from a list of model outputs (one per turn). + + Each output should be exactly [5 letters] with optional whitespace. + Returns proportion of correctly formatted outputs. + """ + if not model_outputs: + return 0.0 + + exact_pattern = re.compile(r"^\s*\[[A-Za-z]{5}\]\s*$") + correct_count = sum(1 for output in model_outputs if exact_pattern.match(output)) + + return correct_count / len(model_outputs) + + +def reward_format_strict(completions: list[str], **kwargs) -> list[float]: + """Format reward - pre-computed in rollout_func.""" + rewards = kwargs.get("format_reward") if kwargs else None + if rewards is None: + return [0.0 for _ in completions] + return [float(r) for r in rewards] + + +# --------------------------------------------------------------------------- +# Main entrypoint +# --------------------------------------------------------------------------- + + +def main() -> None: + args = parse_args() + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id) + tokenizer.pad_token = tokenizer.eos_token + + client = TextArenaEnv(base_url=args.env_url) + + system_prompt = resolve_system_prompt(args.system_prompt_path) + + dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size}) + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}" + output_dir = Path(args.output_dir or default_output_dir) + + grpo_config = GRPOConfig( + use_vllm=True, + vllm_mode=args.vllm_mode, + vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None, + output_dir=str(output_dir), + num_train_epochs=args.num_epochs, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + gradient_accumulation_steps=args.gradient_accumulation_steps, + per_device_train_batch_size=args.per_device_batch_size, + warmup_steps=args.warmup_steps, + num_generations=args.num_generations, + max_completion_length=1024, # Full episode length, not per-turn + logging_steps=args.logging_steps, + log_completions=True, + report_to="trackio", + trackio_space_id=f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}", + save_strategy="steps", + save_steps=args.save_interval, + save_total_limit=args.save_total_limit, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + vllm_gpu_memory_utilization=0.25, + vllm_max_model_length=8192, + vllm_importance_sampling_mode="token_truncate", # Less aggressive than default sequence_mask + optim="adamw_torch", + max_grad_norm=1.0, # Clip gradients to prevent explosion + ) + + grpo_config.run_name = args.run_name or f"run-{timestamp}" + grpo_config.project = args.project or f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}" + grpo_config.trackio_space_id = args.trackio_space_id + + def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + episode_prompt_ids: list[list[int]] = [] + episode_completion_ids: list[list[int]] = [] + episode_logprobs: list[list[float]] = [] + episode_env_masks: list[list[int]] = [] + correctness_rewards: list[float] = [] + position_rewards: list[float] = [] + format_rewards: list[float] = [] + + for prompt_text in prompts: + episode = rollout_once( + trainer=trainer, + env=client, + tokenizer=tokenizer, + dataset_prompt=prompt_text, + system_prompt=system_prompt, + max_turns=args.max_turns, + max_new_tokens=args.max_new_tokens, + ) + episode_prompt_ids.append(episode["prompt_ids"]) + episode_completion_ids.append(episode["completion_ids"]) + episode_logprobs.append(episode["logprobs"]) + episode_env_masks.append(episode["env_mask"]) + correctness_rewards.append(episode["correct_reward"]) + position_rewards.append(episode["position_reward"]) + format_rewards.append(compute_format_reward(episode["model_outputs"])) + + return { + "prompt_ids": episode_prompt_ids, + "completion_ids": episode_completion_ids, + "logprobs": episode_logprobs, + "env_mask": episode_env_masks, + "correct_reward": correctness_rewards, + "position_reward": position_rewards, + "format_reward": format_rewards, + } + + trainer = GRPOTrainer( + model=args.model_id, + processing_class=tokenizer, + reward_funcs=[ + reward_correct, + reward_position, + reward_format_strict, + ], + train_dataset=dataset, + args=grpo_config, + rollout_func=rollout_func, + ) + + print("Starting GRPO training with Wordle environment...") + print(f"Using {args.num_generations} rollouts per dataset prompt") + + try: + trainer.train() + finally: + client.close() + + +if __name__ == "__main__": + main() diff --git a/ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt b/ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt new file mode 100644 index 0000000000000000000000000000000000000000..af8001c77f954750a8be679662c4bcea44ff7e90 --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt @@ -0,0 +1,105 @@ +You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies. + +## GAME RULES + +1. The target is a 5-letter English word +2. You have 6 attempts to guess the correct word +3. After each guess, you receive color-coded feedback: + - GREEN: Letter is correct and in the correct position + - YELLOW: Letter is in the word but in the wrong position + - GRAY: Letter is not in the word at all +4. All guesses must be valid 5-letter English words +5. You cannot reuse a word you've already guessed + +## RESPONSE FORMAT + +Only respond with your next guess in square brackets, e.g., [crane]. + +Format: +``` +[guess] +``` + + +## STRATEGIC APPROACH + +Do not repeat the same guess twice. + +### Opening Strategy +- Start with words rich in common vowels (A, E, I, O, U) and consonants (R, S, T, L, N) +- Optimal starters: CRANE, SLATE, STARE, AROSE, IRATE +- Prioritize words that test the most common letters in different positions + +### Mid-Game Strategy +- Use confirmed GREEN letters in their correct positions +- Place YELLOW letters in different positions than where they appeared +- Eliminate GRAY letters entirely from consideration +- If multiple letters are unknown, prioritize common letter combinations (TH, CH, ST, ER, etc.) +- Consider letter frequency: E is most common, followed by A, R, I, O, T, N, S + +### Vowel Placement +- Most 5-letter words have 2 vowels +- Common patterns: vowel-consonant-vowel (like CRANE) or consonant-vowel-vowel-consonant-vowel (like QUEUE) +- If you have 1-2 vowels confirmed, consider where the others might be + +### Advanced Tactics +- Use "sacrificial" guesses to test multiple new letters if you have attempts to spare +- Avoid repeating letter patterns unless you're certain (e.g., SPEED has two E's) +- Think about word endings: -ER, -LY, -ED, -ING are common but may not fit the 5-letter constraint +- Consider less common letters (Q, X, Z, J) only when you've eliminated most common options + +### Common Pitfalls to Avoid +- Don't reuse X letters +- Don't place Y letters in the same position they appeared +- Don't ignore confirmed G letters +- Don't guess words that contradict known information + +## EXAMPLES + +### Example 1: Opening Guess +"Starting with a word that tests common vowels and consonants in varied positions." +[crane] + +### Example 2: After Receiving Feedback +Previous guess: CRANE +Feedback: C=gray, R=yellow, A=green, N=gray, E=yellow + +"A is confirmed in position 2. R and E are in the word but need different positions. C and N are eliminated. I'll try a word with A in position 2, and test R and E in new positions along with common letters like S and T." +[spare] + +### Example 3: Narrowing Down +Previous guesses: CRANE (C=gray, R=yellow, A=green, N=gray, E=yellow), SPARE (S=gray, P=gray, A=green, R=green, E=green) +Feedback summary: _ARE_ with R in position 4, A in position 2, E in position 5 + +"I have _AR E_ confirmed. Position 1 and 3 are unknown. Common letters to try: T, L, D, B, F, G. Testing with TARED." +[tared] + +### Example 4: Final Deduction +Previous feedback shows: _ARED with position 1 unknown and all common consonants tested + +"Only position 1 remains. I've eliminated S, P, C, N. Common starting consonants left are B, F, G, H. BARED is a common word." +[bared] + +## LETTER FREQUENCY REFERENCE + +Most common letters in 5-letter words (in order): +S, E, A, O, R, I, L, T, N, U, D, Y, C, P, M, H, G, B, K, F + +Most common starting letters: +S, C, B, T, P, A, F, G, D, M + +Most common ending letters: +E, Y, T, S, R, L, N, D + +## IMPORTANT CONSTRAINTS + +- Use lowercase only +- One guess per response +- Must be exactly 5 letters +- Must be a real English word from standard dictionaries +- Never repeat a previous guess +- Always include brief reasoning before your guess + +## YOUR GOAL + +Solve the Wordle in as few guesses as possible by strategically using feedback to eliminate impossible words and narrow down the solution space efficiently. \ No newline at end of file diff --git a/ICL/RL/trl_source/examples/scripts/ppo/ppo.py b/ICL/RL/trl_source/examples/scripts/ppo/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..d60b5688b319f294b198636cedfc73cd10fc3dc3 --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/ppo/ppo.py @@ -0,0 +1,180 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +import os +import shutil + +import torch +from accelerate import PartialState +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, +) + +from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config +from trl.experimental.ppo import PPOConfig, PPOTrainer + + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +""" +python -i examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --output_dir pythia-1b-deduped-descriptiveness-sentiment-trl-style-ppo \ + --per_device_train_batch_size 64 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --missing_eos_penalty 1.0 + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --output_dir pythia-1b-deduped-descriptiveness-sentiment-trl-style-ppo \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path EleutherAI/pythia-1b-deduped \ + --reward_model_path EleutherAI/pythia-1b-deduped \ + --local_rollout_forward_batch_size 1 \ + --missing_eos_penalty 1.0 +""" + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + # remove output_dir if exists + shutil.rmtree(training_args.output_dir, ignore_errors=True) + + ################ + # Model & Tokenizer + ################ + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + value_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + trust_remote_code=model_args.trust_remote_code, + num_labels=1, + **model_kwargs, + ) + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + trust_remote_code=model_args.trust_remote_code, + num_labels=1, + **model_kwargs, + ) + policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + else: + ref_policy = None + + ################ + # Dataset + ################ + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split + ) + eval_samples = 100 + train_dataset = dataset.select(range(len(dataset) - eval_samples)) + eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) + dataset_text_field = "prompt" + + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field], + padding=False, + ) + return {"input_ids": outputs["input_ids"]} + + return dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + + ################ + # Training + ################ + trainer = PPOTrainer( + args=training_args, + processing_class=tokenizer, + model=policy, + ref_model=ref_policy, + reward_model=reward_model, + value_model=value_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + ) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + trainer.generate_completions() diff --git a/ICL/RL/trl_source/examples/scripts/reward_modeling.py b/ICL/RL/trl_source/examples/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..4b860aff9db13be7806e4accba7e146867d410e6 --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/reward_modeling.py @@ -0,0 +1,136 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "trackio", +# "kernels", +# ] +# /// + +""" +Full training: +python examples/scripts/reward_modeling.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --output_dir Qwen2-0.5B-Reward \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 1.0e-5 \ + --eval_strategy steps \ + --eval_steps 50 \ + --max_length 2048 + +LoRA: +python examples/scripts/reward_modeling.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --output_dir Qwen2-0.5B-Reward-LoRA \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 1.0e-4 \ + --eval_strategy steps \ + --eval_steps 50 \ + --max_length 2048 \ + --use_peft \ + --lora_task_type SEQ_CLS \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import os + +import torch +from accelerate import logging +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, HfArgumentParser + +from trl import ( + ModelConfig, + RewardConfig, + RewardTrainer, + ScriptArguments, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + + ################ + # Model & Tokenizer + ################ + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + use_cache=False if training_args.gradient_checkpointing else True, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS": + logger.warning( + "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs" + " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.", + ) + + ############## + # Load dataset + ############## + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ########## + # Training + ########## + trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + trainer.train() + + ############################ + # Save model and push to Hub + ############################ + trainer.save_model(training_args.output_dir) + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py b/ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7467ee4c5bed679df5a8495fd30dfe566e5fae --- /dev/null +++ b/ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py @@ -0,0 +1,194 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "Pillow>=9.4.0", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image). + +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/sft_vlm_gemma3.py \ + --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ + --model_name_or_path google/gemma-3-4b-it \ + --per_device_train_batch_size 1 \ + --output_dir Gemma-3-4B-SFT-MMIU \ + --dtype bfloat16 \ + --use_peft \ + --lora_target_modules all-linear \ + --attn_implementation eager + +Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image). + +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/sft_vlm_gemma3.py \ + --dataset_name FanqingM/MMIU-Benchmark \ + --dataset_train_split test \ + --model_name_or_path google/gemma-3-4b-it \ + --per_device_train_batch_size 1 \ + --output_dir Gemma-3-4B-SFT-MMIU \ + --dtype bfloat16 \ + --use_peft \ + --lora_target_modules all-linear \ + --attn_implementation eager +""" + +import io +import os +import zipfile + +import torch +from datasets import DatasetDict, load_dataset +from huggingface_hub import hf_hub_download, list_repo_files +from PIL import Image +from transformers import AutoModelForImageTextToText + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +# For multi-image example +def process_vision_info(messages: list[dict]) -> list[Image.Image]: + image_inputs = [] + for msg in messages: + content = msg.get("content", []) + if not isinstance(content, list): + content = [content] + + for element in content: + if isinstance(element, dict) and ("image" in element or element.get("type") == "image"): + if "image" in element: + image = element["image"] + else: + image = element + if image is not None: + image = Image.open(io.BytesIO(image["bytes"])) + image_inputs.append(image.convert("RGB")) + return image_inputs + + +def format_data(samples: dict[str, any]) -> dict[str, list]: + formatted_samples = {"messages": []} + for cont in range(len(samples["question"])): + images = [] + for img_path in samples["input_image_path"][cont]: + try: + with open(img_path, "rb") as f: + img_bytes = f.read() + image = Image.open(io.BytesIO(img_bytes)).convert("RGB") + images.append({"type": "image", "image": image}) + except Exception as e: + print(f"Error processing image {img_path}: {e}") + continue + + formatted_samples["messages"].append( + [ + {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]}, + {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]}, + {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]}, + ] + ) + return formatted_samples + + +# For multi-image example +def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict: + all_files = list_repo_files(dataset_name, repo_type="dataset") + zip_files = [f for f in all_files if f.endswith(".zip")] + + for zip_filename in zip_files: + zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset") + extract_folder = zip_filename.replace(".zip", "") + os.makedirs(extract_folder, exist_ok=True) + + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_folder) + + dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16) + return dataset + + +def main(): + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.max_length = None + + ################ + # Model + ################ + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + model = AutoModelForImageTextToText.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + if script_args.dataset_name == "FanqingM/MMIU-Benchmark": + dataset = prepare_dataset(dataset, script_args.dataset_name) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc b/ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..624dbb8197f15436676f0615d7d6ff1bad9ddea2 Binary files /dev/null and b/ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc b/ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..842deab5e702a73aacb6beff1aecf98b2b79d365 Binary files /dev/null and b/ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc b/ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9563bf2163db8ac93f63c58f1e2d452d01d90e0d Binary files /dev/null and b/ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc b/ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cca5bdb96c690260b450b71375599b3cd5ca680f Binary files /dev/null and b/ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc b/ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e5ddb5076b12866224a1ee05010711f8658715 Binary files /dev/null and b/ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml b/ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c01b0b567bc93bf87ec136ea975b3793d273a45c --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml @@ -0,0 +1,28 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml b/ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af498f3eced9c2434b80113f2f22d40395e0ab8a --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml @@ -0,0 +1,25 @@ +# Requires accelerate 1.7.0 or higher +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml b/ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15dad9be3ba44f7c934e1ecab98a93cb83cbc79a --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml b/ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebd00a067118e56f3d63ab0f24827cfea21b24b9 --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: "NO" +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml b/ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5b5f782fb30f9fcbcc8fc58262f09eaf2e10368 --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml b/ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..239b14ac3a9ae8de73122d1154bf0d71903dc15f --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml b/ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5a1201f8a2ee8706b63f0f80c664a1fc61a7d9d --- /dev/null +++ b/ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/ICL/RL/trl_source/trl/experimental/__init__.py b/ICL/RL/trl_source/trl/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..041a38542a6fbaba81fbdb3d278e737109c95f83 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Experimental submodule for TRL. + +This submodule contains unstable or incubating features. Anything here may change (or be removed) in any release +without deprecation. Use at your own risk. + +To silence this notice set environment variable TRL_EXPERIMENTAL_SILENCE=1. +""" + +import os +import warnings + +from ..import_utils import TRLExperimentalWarning + + +if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): + warnings.warn( + "You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without " + "notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.", + TRLExperimentalWarning, + stacklevel=2, + ) diff --git a/ICL/RL/trl_source/trl/experimental/bco/__init__.py b/ICL/RL/trl_source/trl/experimental/bco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6871ca60e002b21ac74a69e4113ded3391f622 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/bco/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bco_config import BCOConfig +from .bco_trainer import BCOTrainer diff --git a/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py b/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8d7de0bee7977851b77a803f297a91ed6d170e27 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .callback import BEMACallback +from .dpo_trainer import DPOTrainer diff --git a/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py b/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8c50a232dc38c236dc64d79d32714b799ceeea8d --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py @@ -0,0 +1,30 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...trainer.dpo_trainer import DPOTrainer as _DPOTrainer +from .callback import CallbackHandlerWithRefModel + + +class DPOTrainer(_DPOTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Replace with a new one that calls the events with the reference model + self.callback_handler = CallbackHandlerWithRefModel( + self.callback_handler.callbacks, + self.model, + self.ref_model, + self.processing_class, + self.optimizer, + self.lr_scheduler, + ) diff --git a/ICL/RL/trl_source/trl/experimental/cpo/__init__.py b/ICL/RL/trl_source/trl/experimental/cpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57e151f53f63b9af52f3b7ba9355e1b3e223fc93 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/cpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cpo_config import CPOConfig +from .cpo_trainer import CPOTrainer + + +__all__ = ["CPOConfig", "CPOTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py b/ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dad89dc2556f52d8bad266e83d043a954079ebdb --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py @@ -0,0 +1,207 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class CPOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.cpo.CPOTrainer`]. + + This class includes only the parameters that are specific to CPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_completion_length (`int`, *optional*): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), ฮฒ is the regularization parameter denoted by ฯ„ in + the [paper](https://huggingface.co/papers/2310.12036). + label_smoothing (`float`, *optional*, defaults to `0.0`): + Label smoothing factor. This argument is required if you want to use the default data collator. + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"hinge"`: hinge loss on the normalized likelihood from the + [SLiC](https://huggingface.co/papers/2305.10425) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper. + - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This + automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. + + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + cpo_alpha (`float`, *optional*, defaults to `1.0`): + Weight of the BC regularizer in CPO training. + simpo_gamma (`float`, *optional*, defaults to `0.5`): + Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`. + alpha (`float`, *optional*, defaults to `0.0`): + Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses + standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha)) + / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all + loss types. + truncation_mode (`str`,*optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from the model to W&B or Comet during evaluation. + is_encoder_decoder (`bool`, *optional*): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_completion_length: int | None = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from " + "the reference model." + }, + ) + label_smoothing: float = field( + default=0.0, + metadata={"help": "Label smoothing factor."}, + ) + loss_type: str = field( + default="sigmoid", + metadata={ + "help": "Type of loss to use.", + "choices": ["sigmoid", "hinge", "ipo", "simpo", "alphapo"], + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + cpo_alpha: float = field( + default=1.0, + metadata={"help": "Weight of the BC regularizer in CPO training."}, + ) + simpo_gamma: float = field( + default=0.5, + metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."}, + ) + alpha: float = field( + default=0.0, + metadata={ + "help": "Alpha parameter that controls reward function shape across all loss types. When alpha=0 " + "(default), uses standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: " + "`r = (1 - p^(-alpha)) / alpha` from the AlphaPO paper. This parameter works with all loss types." + }, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, + ) + is_encoder_decoder: bool | None = field( + default=None, + metadata={"help": "Whether the model is an encoder-decoder model."}, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + # Syntactic sugar for AlphaPO: set loss_type to "simpo" and cpo_alpha to 0.0 + if self.loss_type == "alphapo": + self.loss_type = "simpo" + self.cpo_alpha = 0.0 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py b/ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8cae8405c611d3034781bb5f951c01378e9bc632 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py @@ -0,0 +1,1057 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import random +import textwrap +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from accelerate import PartialState, logging +from datasets import Dataset +from packaging.version import Version +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_comet_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available, is_torch_fx_proxy + +from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt +from ...models.utils import peft_module_casting_to_bf16 +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + disable_dropout_in_model, + log_table_to_comet_experiment, + pad_to_length, + selective_log_softmax, +) +from ..utils import DPODataCollatorWithPadding, add_bos_token_if_needed, add_eos_token_if_needed +from .cpo_config import CPOConfig + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + + +logger = logging.get_logger(__name__) + + +class CPOTrainer(BaseTrainer): + r""" + Initialize CPOTrainer. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. + args ([`experimental.cpo.CPOConfig`]): + The CPO config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + """ + + _tag_names = ["trl", "cpo"] + _name = "CPO" + _paper = { + "title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation", + "id": "2401.08417", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{xu2024contrastive, + title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}}, + author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=51iwkioZpn} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + args: CPOConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + dtype = model_init_kwargs.get("dtype", "auto") + if dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." + ) + model_init_kwargs["dtype"] = dtype + model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " + "merge and unload the existing adapter, save the resulting base model, and then pass that base " + "model along with the new `peft_config` to the trainer." + ) + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if processing_class is None: + raise ValueError("processing_class must be specified to tokenize a CPO dataset.") + if args.max_length is None: + logger.warning( + "`max_length` is not set in the CPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + ) + max_length = 512 + else: + max_length = args.max_length + + if args.max_completion_length is None and self.is_encoder_decoder: + logger.warning( + "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + ) + max_completion_length = 128 + else: + max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + logger.warning( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.truncation_mode = args.truncation_mode + self.max_completion_length = max_completion_length + self.processing_class = processing_class + + if processing_class.pad_token is None: + processing_class.pad_token = processing_class.eos_token + self.pad_token_id = processing_class.pad_token_id + + if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0: + logger.warning( + f"You are using the {args.loss_type} loss type that does not support label smoothing. The " + "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", + ) + if args.loss_type == "kto_pair": + raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.") + + self.beta = args.beta + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.cpo_alpha = args.cpo_alpha + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + + if args.loss_type == "simpo": + self.simpo_gamma = args.simpo_gamma + + # AlphaPO parameter for reward shaping + self.alpha = args.alpha + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed, and apply the chat template if needed + train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + + # tokenize the dataset + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + + b)[len(enc(a)):]`. Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict: + """Tokenize a single row from a CPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + + chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, + we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length + of the prompt and the chosen/rejected response, with `-100` for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.processing_class(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + a != b + for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True) + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt. Avoid adding if it's already there + prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( + self.processing_class.bos_token_id, + prompt_len_input_ids, + prompt_tokens, + chosen_prompt_len_input_ids, + chosen_tokens, + rejected_prompt_len_input_ids, + rejected_tokens, + ) + + # add EOS token to end of answer. Avoid adding if it's already there + chosen_tokens, rejected_tokens = add_eos_token_if_needed( + self.processing_class.eos_token_id, chosen_tokens, rejected_tokens + ) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - longer_response_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [-100] * len( + chosen_tokens["prompt_input_ids"] + ) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [-100] * len( + rejected_tokens["prompt_input_ids"] + ) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.processing_class( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.processing_class( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.processing_class(prompt, add_special_tokens=True) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + return batch + + @staticmethod + def concatenated_inputs( + batch: dict[str, list | torch.LongTensor], + is_encoder_decoder: bool = False, + padding_value: int = 0, + device: torch.device | None = None, + ) -> dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: + A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors + of shape (batch_size, sequence_length). + is_encoder_decoder: + Whether the model is an encoder-decoder model. + padding_value: + The padding value to use for the concatenated inputs_ids. + device: + The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = -100 + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = -100 + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def cpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the CPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: + Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: + Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO + loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for + the chosen and rejected responses, respectively. + """ + # Apply AlphaPO reward transformation if alpha != 0 + if self.alpha != 0.0: + # Compute probabilities + chosen_probs = torch.exp(policy_chosen_logps) + rejected_probs = torch.exp(policy_rejected_logps) + + # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha + policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha + policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha + + logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device) + else: + # Standard log probability rewards when alpha = 0 + logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device) + + # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative CPO loss. + + if self.loss_type == "simpo": + gamma_logratios = self.simpo_gamma / self.beta + logits = logits - gamma_logratios + # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "sigmoid": + # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']" + ) + + # Calculate rewards for logging + if self.alpha != 0.0: + # When using AlphaPO transformation, use the transformed rewards + chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach() + rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach() + else: + # Standard log probability rewards + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: + Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored. + Shape: (batch_size, sequence_length) + average_log_prob: + If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the + log probabilities of the (non-masked) tokens. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the + given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + padding_value=self.pad_token_id, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = concatenated_batch["concatenated_labels"].clone() + + if self.cpo_alpha == 0: + nll_loss = torch.tensor(0.0).to(self.accelerator.device) + else: + nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=self.loss_type in ["ipo", "simpo"], + is_encoder_decoder=self.is_encoder_decoder, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, list | torch.LongTensor], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the CPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards = self.cpo_loss( + policy_chosen_logps, + policy_rejected_logps, + ) + + loss = losses.mean() + self.cpo_alpha * policy_nll_loss + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() + metrics[f"{prefix}rewards/margins"] = ( + self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() + ) + metrics[f"{prefix}logps/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item() + ) + metrics[f"{prefix}logps/chosen"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item() + ) + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item() + ) + metrics[f"{prefix}logits/chosen"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item() + ) + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item() + + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.generate_from_model(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy"], + data=[ + [prompt, pol[len(prompt) :]] + for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py b/ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ae529987306eadc276115e425f4509e6dd7f1e1c --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py @@ -0,0 +1,35 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ...trainer.grpo_config import GRPOConfig as _GRPOConfig + + +@dataclass +class GFPOConfig(_GRPOConfig): + num_remains_in_group: int | None = field( + default=None, + metadata={ + "help": "number inputs remains after group filter function, `'num_remains_in_group'` must be >=2 if given." + }, + ) + + def __post_init__(self): + super().__post_init__() + + if self.num_remains_in_group is not None and self.num_remains_in_group >= self.num_generations: + raise ValueError( + f"Number remains in Group {self.num_remains_in_group} must be less than num_generations : {self.num_generations}." + ) diff --git a/ICL/RL/trl_source/trl/experimental/gkd/__init__.py b/ICL/RL/trl_source/trl/experimental/gkd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d49fe68a623fbca5fdc6cd4581464ac65d892fa9 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gkd/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gkd_config import GKDConfig +from .gkd_trainer import GKDTrainer + + +__all__ = ["GKDConfig", "GKDTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py b/ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ad0e854ba6c2c8ecc5824051a0cab6f3dae6b729 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py @@ -0,0 +1,112 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + +from ...trainer.sft_config import SFTConfig + + +@dataclass +class GKDConfig(SFTConfig): + """ + Configuration class for [`experimental.gkd.GKDTrainer`]. + + This class includes only the parameters that are specific to GKD training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation. + + Args: + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + lmbda (`float`, *optional*, defaults to `0.5`): + Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy + student-generated outputs). + beta (`float`, *optional*, defaults to `0.5`): + Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When + beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. + max_new_tokens (`int`, *optional*, defaults to `128`): + Maximum number of tokens to generate per completion. + teacher_model_name_or_path (`str`, *optional*): + Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being + trained. + teacher_model_init_kwargs (`dict[str, Any]]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model + from a string. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + seq_kd (`bool`, *optional*, defaults to `False`): + Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on + teacher-generated output). + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] + + temperature: float = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + lmbda: float = field( + default=0.5, + metadata={ + "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy " + "student-generated outputs)." + }, + ) + beta: float = field( + default=0.5, + metadata={ + "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence " + "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL " + "Divergence." + }, + ) + max_new_tokens: int = field( + default=128, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + teacher_model_name_or_path: str | None = field( + default=None, + metadata={ + "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the " + "model being trained." + }, + ) + teacher_model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "teacher model from a string." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropouts in `model`."}, + ) + seq_kd: bool = field( + default=False, + metadata={ + "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised " + "FT on teacher-generated output)." + }, + ) + + def __post_init__(self): + super().__post_init__() + # check lmbda and beta are in the range [0, 1] + if self.lmbda < 0.0 or self.lmbda > 1.0: + raise ValueError("lmbda must be in the range [0.0, 1.0].") + if self.beta < 0.0 or self.beta > 1.0: + raise ValueError("beta must be in the range [0.0, 1.0].") diff --git a/ICL/RL/trl_source/trl/experimental/gold/__init__.py b/ICL/RL/trl_source/trl/experimental/gold/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2a8c1149130a77081ed794d02c7d7629d265bb --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gold/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gold_config import GOLDConfig +from .gold_trainer import GOLDTrainer + + +__all__ = ["GOLDConfig", "GOLDTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/gold/gold.py b/ICL/RL/trl_source/trl/experimental/gold/gold.py new file mode 100644 index 0000000000000000000000000000000000000000..81954b5c753451b746a2ed09cb2d576a98930321 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gold/gold.py @@ -0,0 +1,155 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl @ git+https://github.com/huggingface/trl.git", +# "peft", +# "trackio", +# ] +# /// + +# docstyle-ignore +""" +# Full training: +python trl/experimental/gold/gold.py \ + --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --output_dir gold-model \ + --num_train_epochs 1 \ + --push_to_hub + +# LoRA: +python trl/experimental/gold/gold.py \ + --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-4 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --output_dir gold-model \ + --num_train_epochs 1 \ + --push_to_hub \ + --use_peft \ + --lora_r 64 \ + --lora_alpha 16 +""" + +import logging + +from datasets import load_dataset +from transformers import AutoTokenizer, GenerationConfig + +from trl import ( + LogCompletionsCallback, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.experimental.gold.gold_config import GOLDConfig +from trl.experimental.gold.gold_trainer import GOLDTrainer + + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, GOLDConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + ################ + # Model & Tokenizer + ################ + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=training_args.student_model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.model_init_kwargs = model_kwargs + + if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss: + training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path + teacher_model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.dtype, + use_cache=True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.teacher_model_init_kwargs = teacher_model_kwargs + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + # Handle eval dataset - check if test split exists, fallback to validation or None + eval_dataset = None + if training_args.eval_strategy != "no": + if script_args.dataset_test_split in dataset: + eval_dataset = dataset[script_args.dataset_test_split] + elif "validation" in dataset: + eval_dataset = dataset["validation"] + elif "dev" in dataset: + eval_dataset = dataset["dev"] + + trainer = GOLDTrainer( + model=model_args.model_name_or_path, + teacher_model=training_args.teacher_model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=eval_dataset, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_completion_length, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/ICL/RL/trl_source/trl/experimental/gold/gold_config.py b/ICL/RL/trl_source/trl/experimental/gold/gold_config.py new file mode 100644 index 0000000000000000000000000000000000000000..827b639dec844e5f9a4f53b1cdac461070cb7ab1 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gold/gold_config.py @@ -0,0 +1,419 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + +from ...trainer.sft_config import SFTConfig + + +@dataclass +class GOLDConfig(SFTConfig): + r""" + Configuration class for [`GOLDTrainer`]. + + This class includes only the parameters that are specific to GOLD training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation. + + Args: + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + lmbda (`float`, *optional*, defaults to `0.5`): + Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy + student-generated outputs). + beta (`float`, *optional*, defaults to `0.5`): + Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When + beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. + max_completion_length (`int`, *optional*, defaults to `128`): + Maximum number of tokens to generate per completion. + teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): + Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being + trained. + teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model + from a string. + teacher_tokenizer_name_or_path (`str` or `None`, *optional*, defaults to `None`): + Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same tokenizer as + the student model (not recommended for cross-tokenizer distillation). + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + seq_kd (`bool`, *optional*, defaults to `False`): + Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on + teacher-generated output). + use_uld_loss (`bool`, *optional*, defaults to `False`): + Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence + loss. + uld_crossentropy_weight (`float`, *optional*, defaults to `0.0`): + Weight for the cross-entropy loss component in ULD loss. If 0, only ULD distillation loss is used. + uld_distillation_weight (`float`, *optional*, defaults to `1.0`): + Weight for the distillation loss component in ULD loss. + uld_student_temperature (`float`, *optional*, defaults to `1.0`): + Temperature for student logits in ULD loss computation. + uld_teacher_temperature (`float`, *optional*, defaults to `1.0`): + Temperature for teacher logits in ULD loss computation. + uld_skip_student_eos (`bool`, *optional*, defaults to `True`): + Whether to skip EOS token for student in ULD loss computation. + uld_skip_teacher_eos (`bool`, *optional*, defaults to `True`): + Whether to skip EOS token for teacher in ULD loss computation. + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed. + vllm_mode (`str`, *optional*, defaults to `"server"`): + Mode for student vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or `"colocate"` + (run vLLM in the same process). + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server for the student model (if `vllm_mode="server"`). + vllm_server_port (`int`, *optional*, defaults to `8001`): + Port of the vLLM server for the student model (if `vllm_mode="server"`). + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Timeout for connecting to the student vLLM server (if `vllm_mode="server"`). + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + GPU memory utilization for the colocated student vLLM engine (if `vllm_mode="colocate"`). It is recommended + to set this to a low value if the student and teacher models share the same GPU. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`). + vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`): + Regex for vLLM structured outputs for the student model. + vllm_sync_frequency (`int`, *optional*, defaults to `1`): + Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after + every step. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU memory usage + low, but waking the engine adds hostโ€“device transfer latency. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-7, + metadata={"help": "The initial learning rate for AdamW."}, + ) + + # GOLD-specific parameters + temperature: float = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=0.95, + metadata={ + "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to " + "`top_p` or higher are kept for generation." + }, + ) + top_k: int = field( + default=0, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + lmbda: float = field( + default=0.5, + metadata={ + "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy " + "student-generated outputs)." + }, + ) + beta: float = field( + default=0.5, + metadata={ + "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence " + "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL " + "Divergence." + }, + ) + max_completion_length: int = field( + default=128, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + student_model_revision: str = field( + default="main", + metadata={ + "help": "Revision of the student model to use. If not specified, the default revision of the model will be used." + }, + ) + teacher_model_name_or_path: str | None = field( + default=None, + metadata={ + "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the " + "model being trained." + }, + ) + teacher_model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "teacher model from a string." + }, + ) + teacher_tokenizer_name_or_path: str | None = field( + default=None, + metadata={ + "help": "Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same " + "tokenizer as the student model (not recommended for cross-tokenizer distillation)." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropouts in `model`."}, + ) + seq_kd: bool = field( + default=False, + metadata={ + "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised " + "FT on teacher-generated output)." + }, + ) + steps_per_generation: int | None = field( + default=None, + metadata={ + "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." + }, + ) + + # ULD Loss parameters + use_uld_loss: bool = field( + default=False, + metadata={ + "help": "Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss." + }, + ) + use_extended_uld: bool = field( + default=True, + metadata={ + "help": ( + "Whether to enable extended ULD alignment that uses tokenizers to align and merge token " + "probabilities across student and teacher tokenizations. When True, the trainer will compute " + "token mappings and merge probabilities for split tokens; when False, ULD will use simple " + "positional truncation like in the original ULD paper." + ) + }, + ) + uld_use_hybrid_loss: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use a hybrid loss that combines ULD loss and JSD loss. When True, the final loss is a " + "a combination of JSD for known token mappings and ULD for unknown token mappings." + ) + }, + ) + uld_hybrid_matched_weight: float | None = field( + default=None, + metadata={ + "help": ( + "Weight for the matched token loss component when using hybrid ULD + JSD loss. This weight scales " + "the JSD loss computed over tokens that have a direct mapping between student and teacher " + "tokenizations. If None, uses adaptive weighting based on vocabulary overlap. Must be set together " + "with uld_hybrid_unmatched_weight (both None or both float)." + ) + }, + ) + uld_hybrid_unmatched_weight: float | None = field( + default=None, + metadata={ + "help": ( + "Weight for the unmatched token loss component when using hybrid ULD + JSD loss. This weight scales " + "the ULD loss computed over tokens that do not have a direct mapping between student and teacher " + "tokenizations. If None, uses adaptive weighting based on vocabulary overlap. Must be set together " + "with uld_hybrid_matched_weight (both None or both float)." + ) + }, + ) + uld_crossentropy_weight: float = field( + default=0.0, + metadata={"help": "Weight for the cross-entropy loss component in ULD loss."}, + ) + uld_distillation_weight: float = field( + default=1.0, + metadata={"help": "Weight for the distillation loss component in ULD loss."}, + ) + uld_student_temperature: float = field( + default=1.0, + metadata={"help": "Temperature for student logits in ULD loss computation."}, + ) + uld_teacher_temperature: float = field( + default=1.0, + metadata={"help": "Temperature for teacher logits in ULD loss computation."}, + ) + + uld_skip_student_eos: bool = field( + default=True, + metadata={"help": "Whether to skip EOS token for student in ULD loss computation."}, + ) + uld_skip_teacher_eos: bool = field( + default=True, + metadata={"help": "Whether to skip EOS token for teacher in ULD loss computation."}, + ) + + # transformers paged attention + use_transformers_paged: bool = field( + default=False, + metadata={ + "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the " + "`transformers` paged implementation will be used for generation instead of the default padded " + "implementation." + }, + ) + + # vLLM parameters + use_vllm: bool = field( + default=False, + metadata={"help": "Whether to use vLLM for generating completions. Requires `vllm` to be installed."}, + ) + vllm_mode: str = field( + default="server", + metadata={ + "help": 'Mode for vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).' + }, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": 'Host of the vLLM server when `vllm_mode="server"`.'}, + ) + vllm_server_port: int = field( + default=8001, + metadata={"help": 'Port of the vLLM server when `vllm_mode="server"`.'}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={"help": 'Timeout (in seconds) for connecting to the vLLM server when `vllm_mode="server"`.'}, + ) + vllm_gpu_memory_utilization: float = field( + default=0.9, + metadata={ + "help": 'GPU memory utilization for the colocated vLLM engine when `vllm_mode="colocate"`. Lower values reduce contention when sharing a device with the student/teacher models.' + }, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'}, + ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."}, + ) + vllm_sync_frequency: int = field( + default=1, + metadata={ + "help": "Frequency (in training steps) to synchronize model weights to the vLLM engine. Set to 1 to sync after every step." + }, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={ + "help": "Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU " + "memory usage low, but waking the engine adds hostโ€“device transfer latency." + }, + ) + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={ + "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " + "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." + }, + ) + log_completions_steps: int = field( + default=100, + metadata={ + "help": "Number of steps between logging (prompt, completion) pairs. Only used if `log_completions` is " + "set to `True`." + }, + ) + num_completions_to_print: int | None = field( + default=None, + metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, + ) + wandb_entity: str | None = field( + default=None, + metadata={"help": ("The entity to store runs under.")}, + ) + wandb_project: str | None = field( + default=None, + metadata={"help": ("The project to store runs under.")}, + ) + wandb_run_group: str | None = field( + default=None, + metadata={"help": ("The group to store runs under.")}, + ) + wandb_log_unique_prompts: bool = field( + default=True, + metadata={ + "help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.") + }, + ) + callbacks: list[str] = field( + default_factory=lambda: [], + metadata={"help": "The callbacks to run during training."}, + ) + hub_model_revision: str | None = field( + default="main", metadata={"help": "The Hub model branch to push the model to."} + ) + num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."}) + overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) + push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) + trl_project: str = field( + default="smollm3", + metadata={ + "help": "The TRL project to use for evaluation. This is used to determine the path to the evaluation script." + }, + ) + + def __post_init__(self): + super().__post_init__() + # check lmbda and beta are in the range [0, 1] + if self.lmbda < 0.0 or self.lmbda > 1.0: + raise ValueError("lmbda must be in the range [0.0, 1.0].") + if self.beta < 0.0 or self.beta > 1.0: + raise ValueError("beta must be in the range [0.0, 1.0].") + + # Validate that max_length is sufficient for max_completion_length + if self.max_length is not None and self.max_completion_length >= self.max_length: + raise ValueError( + f"max_completion_length ({self.max_completion_length}) must be smaller than max_length ({self.max_length}) " + f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length." + ) + + if self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + + # Validate ULD parameters + if self.use_uld_loss: + if self.uld_crossentropy_weight < 0.0: + raise ValueError("uld_crossentropy_weight must be non-negative.") + if self.uld_distillation_weight < 0.0: + raise ValueError("uld_distillation_weight must be non-negative.") + if self.uld_student_temperature <= 0.0: + raise ValueError("uld_student_temperature must be positive.") + if self.uld_teacher_temperature <= 0.0: + raise ValueError("uld_teacher_temperature must be positive.") + + # Validate hybrid loss weights - both must be None or both must be set + if self.uld_use_hybrid_loss: + if (self.uld_hybrid_matched_weight is None) != (self.uld_hybrid_unmatched_weight is None): + raise ValueError( + "uld_hybrid_matched_weight and uld_hybrid_unmatched_weight must both be None (for adaptive " + "weighting) or both be set to numeric values. Got uld_hybrid_matched_weight=" + f"{self.uld_hybrid_matched_weight} and uld_hybrid_unmatched_weight=" + f"{self.uld_hybrid_unmatched_weight}." + ) + if self.uld_hybrid_matched_weight is not None: + if self.uld_hybrid_matched_weight < 0.0: + raise ValueError("uld_hybrid_matched_weight must be non-negative.") + if self.uld_hybrid_unmatched_weight < 0.0: + raise ValueError("uld_hybrid_unmatched_weight must be non-negative.") diff --git a/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py b/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..708ea62d6740e6ca43c7de21bb716abce1e7c284 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig +from .grpo_with_replay_buffer_trainer import GRPOWithReplayBufferTrainer, ReplayBuffer diff --git a/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py b/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f271a40d5acae80f0cb00ba10501bfa6865b2653 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py @@ -0,0 +1,34 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ...trainer.grpo_config import GRPOConfig + + +@dataclass +class GRPOWithReplayBufferConfig(GRPOConfig): + """ + New Parameters: + replay_buffer_size (`int`, *optional*, defaults to `0`): + A cache that stores the rollouts with the highest advantage scores and variance per group. If a new + group has 0 variance, it is replaced with a group sampled from the replay buffer. + """ + + replay_buffer_size: int = field( + default=64, + metadata={ + "help": "A cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer." + }, + ) diff --git a/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3aba7a0edbe3c527e0305ba9811450f43ad5925 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -0,0 +1,731 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import heapq +from typing import Any + +import torch +from accelerate.utils import gather_object + +from ...data_utils import apply_chat_template, prepare_multimodal_messages +from ...models.utils import disable_gradient_checkpointing +from ...trainer.grpo_trainer import GRPOTrainer +from ...trainer.utils import nanmax, nanmin, nanstd, pad +from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig + + +class ReplayBuffer: + """ + A simple replay buffer to store and sample previously seen rollouts. + """ + + def __init__(self, max_size: int): + self.max_size = max_size + self.heap = [] # Min-heap of (score, data) tuples + + def add(self, scores: list[float], data: list[dict]): + for score, datum in zip(scores, data, strict=True): + if len(self.heap) < self.max_size: + heapq.heappush(self.heap, (score, datum)) + else: + # Only add if score is better than worst (minimum) item + if score > self.heap[0][0]: + heapq.heapreplace(self.heap, (score, datum)) + + def sample(self, num_samples: int) -> list[dict[str, torch.Tensor]]: + if not self.heap: + return None + + # Sample by normalized scores + scores = torch.tensor([item[0] for item in self.heap], dtype=torch.float32) + probabilities = scores / scores.sum() + replacement = False + if num_samples > len(self.heap): + replacement = True + chosen_indices = torch.multinomial(probabilities, num_samples, replacement=replacement).tolist() + return [self.heap[i][1] for i in chosen_indices] + + +class GRPOWithReplayBufferTrainer(GRPOTrainer): + def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **kwargs): + super().__init__(args=args, **kwargs) + self.replay_buffer = ReplayBuffer(args.replay_buffer_size) if args.replay_buffer_size > 0 else None + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [ + prepare_multimodal_messages(prompt, image_list) + for prompt, image_list in zip(prompts, images, strict=True) + ] + + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + ) = self._generate(prompts) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + if self.tools: + tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template( + {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + )["prompt"] + for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a + # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). + # Temporarily disable checkpointing to avoid this warning during inference. + with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + # If the generation and optimization steps are misalignedโ€”i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)โ€”then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. + # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. + generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) + importance_sampling_ratio = torch.clamp( + importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Merge extra_fields from rollout_func into inputs for reward functions + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + inp[key] = values + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = rewards - mean_grouped_rewards + + grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1) + grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0) + + if self.scale_rewards in ["group", "none"]: + # If self.scale_rewards = "none", we'll still log group level std + std_rewards = grouped_std_rewards.clone() + elif self.scale_rewards == "batch": + # Compute global std + std_rewards = rewards.std().expand_as(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." + ) + + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + grouped_std_rewards = grouped_std_rewards[process_slice] + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_rewards.mean().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_process_advantages.tolist()) + + if images is not None: + self._logs["images"].extend(gather_object(images)) + + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool() + delta = delta[mask] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + + flat_is_ratio = importance_sampling_ratio[mask] + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) + outputs_after_sampling_buffer = self.update_with_replay_buffer( + advantages, + grouped_std_rewards, + prompt_ids, + prompt_mask, + completion_ids, + completion_mask, + forward_kwargs, + num_items_in_batch, + old_per_token_logps, + ref_per_token_logps, + importance_sampling_ratio if self.use_vllm and self.vllm_importance_sampling_correction else None, + ) + if outputs_after_sampling_buffer is not None: + return outputs_after_sampling_buffer + else: + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if self.use_vllm and self.vllm_importance_sampling_correction: + output["importance_sampling_ratio"] = importance_sampling_ratio + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + if self.tools: + output["tool_mask"] = tool_mask + return output + + def slice_group_data( + self, data: torch.Tensor, mask: torch.Tensor, group_idx: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Slices the input data and mask tensors for a specific group index. Also trims the sequence length to the + maximum length in the group based on the mask. + + Args: + data: Tensor of shape (num_groups * num_generations, seq_length) + mask: Tensor of shape (num_groups * num_generations, seq_length) + group_idx: Index of the group to slice + Returns: + Tuple of (sliced_data, sliced_mask) for the specified group, with sequence length trimmed to the maximum + length in the group. + """ + start_idx = group_idx * self.num_generations + end_idx = (group_idx + 1) * self.num_generations + group_data = data[start_idx:end_idx] + group_mask = mask[start_idx:end_idx] + group_max_len = group_mask.sum(dim=1).max().item() + return group_data[:, :group_max_len], group_mask[:, :group_max_len] + + def update_replay_buffer( + self, + groups_with_variance: torch.Tensor, + group_advantages: torch.Tensor, + group_std_rewards: torch.Tensor, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + forward_kwargs: dict, + optional_vision_fields: list[str] = None, + old_per_token_logps: torch.Tensor | None = None, + ref_per_token_logps: torch.Tensor | None = None, + importance_sampling_ratio: float | None = None, + ) -> None: + """ + Update the replay buffer with groups that have reward variance (std > 0). + + Args: + groups_with_variance: Boolean tensor indicating which groups have reward variance + group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values + std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group + prompt_ids: Tensor containing prompt token IDs + prompt_mask: Tensor containing prompt attention masks + completion_ids: Tensor containing completion token IDs + completion_mask: Tensor containing completion attention masks + forward_kwargs: Dictionary containing additional prompt inputs (vision data, etc.) + optional_vision_fields: List of optional vision-related fields to include if present in forward_kwargs + old_per_token_logps: Optional tensor of old per-token log probabilities + ref_per_token_logps: Optional tensor of reference per-token log probabilities + importance_sampling_ratio: Optional importance sampling correction ratio + """ + # Prepare buffered outputs for groups with variance + buffered_outputs = [] + for _, group_idx in enumerate(groups_with_variance.nonzero(as_tuple=True)[0].unique().tolist()): + group_prompt_ids, group_prompt_mask = self.slice_group_data(prompt_ids, prompt_mask, group_idx) + group_completion_ids, group_completion_mask = self.slice_group_data( + completion_ids, completion_mask, group_idx + ) + + # Store unpadded data in the buffer + buffered_output = { + "prompt_ids": group_prompt_ids, + "completion_ids": group_completion_ids, + "advantages": group_advantages[group_idx].tolist(), + "prompt_mask": group_prompt_mask, + "completion_mask": group_completion_mask, + } + + # Add optional fields if they exist + optional_fields = { + "old_per_token_logps": old_per_token_logps if old_per_token_logps is not None else None, + "ref_per_token_logps": ref_per_token_logps if ref_per_token_logps is not None else None, + } + + for field_name, field_data in optional_fields.items(): + if field_data is not None: + buffered_output[field_name] = self.slice_group_data(field_data, completion_mask, group_idx)[0] + + # Add importance sampling if needed + if self.use_vllm and self.vllm_importance_sampling_correction: + buffered_output["importance_sampling_ratio"] = importance_sampling_ratio + + if optional_vision_fields: + # Add vision-related fields if they exist + for field_name in optional_vision_fields: + if field_name in forward_kwargs: + buffered_output[field_name] = self.slice_group_data( + forward_kwargs[field_name], prompt_mask, group_idx + )[0] + + buffered_outputs.append(buffered_output) + + if groups_with_variance.any(): + # Calculate replay buffer scores for groups with variance + replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance] + # Add all groups to replay buffer at once (batch operation) + self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs) + + def sample_from_replay_buffer( + self, num_samples: int, optional_vision_fields: list[str] = None, optional_tensor_fields: list[str] = None + ) -> list[dict]: + """ + Sample groups from the replay buffer. + + Args: + num_samples: Number of samples to draw from the replay buffer + optional_vision_fields: List of optional vision-related fields to include if present in sampled data + optional_tensor_fields: List of optional tensor fields to include if present in sampled data + Returns: + List of sampled data dictionaries from the replay buffer + """ + sampled = self.replay_buffer.sample(num_samples=num_samples) + + # Extract and concatenate sampled data + sampled_data = { + "prompt_ids": [], + "prompt_mask": [], + "completion_ids": [], + "completion_mask": [], + "advantages": [], + } + + all_optional_fields = (optional_tensor_fields or []) + (optional_vision_fields or []) + # Initialize containers for optional fields if they exist in sampled data + for field in all_optional_fields: + if sampled and field in sampled[0]: + sampled_data[field] = [] + + # Extract data from each sampled item + for item in sampled: + # Handle core fields + for key in ["prompt_ids", "prompt_mask", "completion_ids", "completion_mask"]: + sampled_data[key].append(item[key]) + + # Handle advantages (list, not tensor) + sampled_data["advantages"].append(item["advantages"]) + + # Handle optional fields + for field in all_optional_fields: + if field in item: + sampled_data[field].append(item[field]) + + return sampled_data + + def update_with_replay_buffer( + self, + group_advantages: torch.Tensor, + group_std_rewards: torch.Tensor, + prompt_ids: torch.Tensor, + prompt_mask: torch.Tensor, + completion_ids: torch.Tensor, + completion_mask: torch.Tensor, + forward_kwargs: dict, + num_items_in_batch: int, + old_per_token_logps: torch.Tensor | None = None, + ref_per_token_logps: torch.Tensor | None = None, + importance_sampling_ratio: float | None = None, + ) -> None: + """ + Update current batch data with samples from replay buffer. + + Groups with reward variance (std > 0) are added to the replay buffer and then replaced with samples from the + buffer to improve training stability. + + Args: + group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values + std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group + prompt_ids: Tensor containing prompt token IDs + prompt_mask: Tensor containing prompt attention masks + completion_ids: Tensor containing completion token IDs + completion_mask: Tensor containing completion attention masks + forward_kwargs: Dictionary containing additional prompt inputs (vision data, etc.) + num_items_in_batch: Number of items in the current batch + old_per_token_logps: Optional tensor of old per-token log probabilities + ref_per_token_logps: Optional tensor of reference per-token log probabilities + importance_sampling_ratio: Optional importance sampling correction ratio + """ + if self.replay_buffer.max_size <= 0: + return + + # Groups to consider for adding to the replay buffer + groups_with_variance = group_std_rewards.max(dim=0).values > 0 + # Groups to replace from the replay buffer + groups_without_variance = ~groups_with_variance + + # Track which optional fields are present in sampled data + optional_tensor_fields = ["old_per_token_logps", "ref_per_token_logps"] + vision_fields = ["pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes"] + + self.update_replay_buffer( + groups_with_variance, + group_advantages, + group_std_rewards, + prompt_ids, + prompt_mask, + completion_ids, + completion_mask, + forward_kwargs, + vision_fields, + old_per_token_logps, + ref_per_token_logps, + importance_sampling_ratio, + ) + + # Sample from replay buffer to replace groups with variance + num_groups_to_replace = groups_without_variance.sum().item() + if not num_groups_to_replace: + return + + sampled_data = self.sample_from_replay_buffer( + num_samples=num_groups_to_replace, + optional_vision_fields=vision_fields, + optional_tensor_fields=optional_tensor_fields, + ) + + # Pad sampled data if they are shorter than the current batch sequences + # Or pad the current batch if sampled are longer + current_batch_prompt_seq_len = prompt_ids.size(1) + current_batch_completion_seq_len = completion_ids.size(1) + + groups_to_replace_idxs = groups_with_variance.logical_not().nonzero(as_tuple=True)[0].unique().tolist() + + # Determine target (max) sequence lengths once + sampled_prompt_lengths = [t.size(1) for t in sampled_data["prompt_ids"]] + sampled_completion_lengths = [t.size(1) for t in sampled_data["completion_ids"]] + target_prompt_len = max([current_batch_prompt_seq_len] + sampled_prompt_lengths) + target_completion_len = max([current_batch_completion_seq_len] + sampled_completion_lengths) + + # If any sampled prompt is longer, pad the whole batch prompt tensors once (left padding) + if target_prompt_len > current_batch_prompt_seq_len: + prompt_ids = pad( + list(prompt_ids.unbind(0)), + padding_value=self.pad_token_id, + pad_to_multiple_of=target_prompt_len, + padding_side="left", + ) + prompt_mask = pad( + list(prompt_mask.unbind(0)), padding_value=0, pad_to_multiple_of=target_prompt_len, padding_side="left" + ) + # If any sampled completion is longer, pad the whole batch completion tensors once (right padding) + if target_completion_len > current_batch_completion_seq_len: + completion_ids = pad( + list(completion_ids.unbind(0)), + padding_value=self.pad_token_id, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + completion_mask = pad( + list(completion_mask.unbind(0)), + padding_value=0, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + if old_per_token_logps is not None: + old_per_token_logps = pad( + list(old_per_token_logps.unbind(0)), + padding_value=0.0, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + if ref_per_token_logps is not None: + ref_per_token_logps = pad( + list(ref_per_token_logps.unbind(0)), + padding_value=0.0, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + + # Replace per-group data, padding only sampled groups that are shorter than the target + for i, group_idx in enumerate(groups_to_replace_idxs): + start_idx = group_idx * self.num_generations + end_idx = (group_idx + 1) * self.num_generations + idx_range = slice(start_idx, end_idx) + + # Pad sampled prompt to target length if needed + if sampled_data["prompt_ids"][i].size(1) < target_prompt_len: + sampled_data["prompt_ids"][i] = pad( + sampled_data["prompt_ids"][i], + padding_value=self.pad_token_id, + pad_to_multiple_of=target_prompt_len, + padding_side="left", + ) + sampled_data["prompt_mask"][i] = pad( + sampled_data["prompt_mask"][i], + padding_value=0, + pad_to_multiple_of=target_prompt_len, + padding_side="left", + ) + + # Pad sampled completion to target length if needed + if sampled_data["completion_ids"][i].size(1) < target_completion_len: + sampled_data["completion_ids"][i] = pad( + sampled_data["completion_ids"][i], + padding_value=self.pad_token_id, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + sampled_data["completion_mask"][i] = pad( + sampled_data["completion_mask"][i], + padding_value=0, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + if "old_per_token_logps" in sampled_data: + sampled_data["old_per_token_logps"][i] = pad( + sampled_data["old_per_token_logps"][i], + padding_value=0.0, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + if "ref_per_token_logps" in sampled_data: + sampled_data["ref_per_token_logps"][i] = pad( + sampled_data["ref_per_token_logps"][i], + padding_value=0.0, + pad_to_multiple_of=target_completion_len, + padding_side="right", + ) + + # Assign (replace) group slice + prompt_ids[idx_range] = sampled_data["prompt_ids"][i] + prompt_mask[idx_range] = sampled_data["prompt_mask"][i] + completion_ids[idx_range] = sampled_data["completion_ids"][i] + completion_mask[idx_range] = sampled_data["completion_mask"][i] + group_advantages[group_idx] = sampled_data["advantages"][i] + + if "old_per_token_logps" in sampled_data: + old_per_token_logps[idx_range] = sampled_data["old_per_token_logps"][i] + if "ref_per_token_logps" in sampled_data: + ref_per_token_logps[idx_range] = sampled_data["ref_per_token_logps"][i] + + for field in vision_fields: + if field in sampled_data and field in forward_kwargs: + forward_kwargs[field][idx_range] = sampled_data[field][i] + + # Prepare final outputs after sampling and replacement + outputs_after_sampling_buffer = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": group_advantages, + } + + # Replace optional tensor fields if they exist + for field in optional_tensor_fields: + if field in sampled_data: + outputs_after_sampling_buffer[field] = ( + old_per_token_logps if field == "old_per_token_logps" else ref_per_token_logps + ) + + # Replace vision fields if they exist + for field in vision_fields: + if field in sampled_data and field in forward_kwargs: + outputs_after_sampling_buffer[field] = forward_kwargs[field] + + outputs_after_sampling_buffer["num_items_in_batch"] = num_items_in_batch + if self.use_vllm and self.vllm_importance_sampling_correction: + outputs_after_sampling_buffer["importance_sampling_ratio"] = importance_sampling_ratio + + return outputs_after_sampling_buffer diff --git a/ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py b/ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9814fddca65cfe8d5ccabccc9d646141e19364 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .grpo_trainer import GRPOTrainer diff --git a/ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py b/ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0522058bc6eeda33779b6182130478545d1ff9b6 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py @@ -0,0 +1,157 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer +from ...trainer.utils import nanmax, nanmin + + +class GRPOTrainer(_GRPOTrainer): + def _compute_loss(self, model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + ) + + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + + # Compute the loss + advantages = inputs["advantages"] + # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, + # old_per_token_logps == per_token_logps. In this case we can skip its computation + # (see _generate_and_score_completions) and instead use per_token_logps.detach(). + # The exception is when using vLLM, where we always compute old_per_token_logps + # for importance sampling + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "token": + log_importance_weights = log_ratio + elif self.importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + elif self.importance_sampling_level == "sequence_token": + # GSPO-token: sg[si(ฮธ)] * ฯ€ฮธ(yi,t)/sg[ฯ€ฮธ(yi,t)] + seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + + if self.use_vllm and self.vllm_importance_sampling_correction: + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + mode = "train" if self.model.training else "eval" + if self.loss_type == "grpo": + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer + elif self.loss_type == "bnpo": + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer + elif self.loss_type == "dapo": + normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes + loss = (per_token_loss * completion_mask).sum() / normalizer + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # Log the metrics + completion_token_count = completion_mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + mean_entropy = masked_batch_mean(entropies) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + return loss diff --git a/ICL/RL/trl_source/trl/experimental/judges/__init__.py b/ICL/RL/trl_source/trl/experimental/judges/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e99527ed868d90d73cd5b19b04d097c00607c7c3 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/judges/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .judges import ( + AllTrueJudge, + BaseBinaryJudge, + BaseJudge, + BasePairwiseJudge, + BaseRankJudge, + HfPairwiseJudge, + OpenAIPairwiseJudge, + PairRMJudge, +) + + +__all__ = [ + "AllTrueJudge", + "BaseBinaryJudge", + "BaseJudge", + "BasePairwiseJudge", + "BaseRankJudge", + "HfPairwiseJudge", + "OpenAIPairwiseJudge", + "PairRMJudge", +] diff --git a/ICL/RL/trl_source/trl/experimental/judges/judges.py b/ICL/RL/trl_source/trl/experimental/judges/judges.py new file mode 100644 index 0000000000000000000000000000000000000000..cbf5d14414b781454319fc8d8cac7fc3c7a51e3d --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/judges/judges.py @@ -0,0 +1,482 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import concurrent.futures +import logging +from abc import ABC, abstractmethod + +import numpy as np +from accelerate import Accelerator +from huggingface_hub import InferenceClient +from packaging.version import Version +from transformers.utils import is_openai_available + +from ...import_utils import is_llm_blender_available + + +DEFAULT_PAIRWISE_SYSTEM_PROMPT = '''I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective. + +## Instruction + +{{ + "instruction": """{prompt}""", +}} + +## Model Outputs + +Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier. + +{{ + {{ + "model_identifier": "0", + "output": """{response0}""" + }}, + {{ + "model_identifier": "1", + "output": """{response1}""" + }} +}} + +## Task + +Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...). +''' + + +def _ensure_llm_blender_importable() -> None: + """ + Pre-import shim to work around a known `llm-blender` issue. + + As of `llm-blender` v0.0.2 (see upstream issue: https://github.com/yuchenlin/LLM-Blender/issues/33), importing + `llm_blender` may fail on `transformers` >= 5.0.0 because it unconditionally accesses + `transformers.utils.hub.TRANSFORMERS_CACHE`. + + We set this attribute to a dummy value before importing `llm_blender` so that the import succeeds. This helper is + intentionally a no-op on older `transformers` versions. + + This shim can be removed once the upstream issue is fixed and the minimum required `llm-blender` version includes + that fix. + """ + import transformers.utils.hub + + if Version(transformers.__version__) >= Version("5.0.0"): + transformers.utils.hub.TRANSFORMERS_CACHE = None # unused; just needs to exist + + +class BaseJudge(ABC): + """ + Base class for judges. The subclasses of this class should implement the `judge` method. + """ + + @abstractmethod + def judge(self, prompts: list[str], completions: list[str], shuffle_order: bool = True) -> list: + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class BaseRankJudge(ABC): + """ + Base class for LLM ranking judges. + + **Example**: + ```python + class MyRankJudge(BaseRankJudge): + def judge(self, prompts, completions, shuffle_order=True): + return ... # Your ranking logic here + + + judge = MyRankJudge() + judge.judge( + prompts=["The capital of France is", "The capital of Germany is"], + completions=[[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]], + ) # [[0, 1, 2], [1, 0]] + ``` + """ + + @abstractmethod + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[list[int]]: + """ + Judge the completion for the given prompts and return the ranks of each completion. + + Args: + prompts (`list[str]`): + List of prompts. + completions (`list[list[str]]`): + List of completions list, where each element is a list of completions for the corresponding prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + `list[list[int]]`: + List of lists of idxs, where each list contains the ranks of the completions for the corresponding + prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the + third, and then the first. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class BasePairwiseJudge(BaseJudge): + """ + Base class for pairwise judges. + """ + + @abstractmethod + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]: + """ + Judge the completion pairs for the given prompts. + + Args: + prompts (`list[str]`): + List of prompts. + completions (`list[list[str]]`): + List of completions pairs, where each element is a pair of completions for the corresponding prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + `list[int]`: + List of idxs, where each idx is the rank of the best completion for the corresponding prompt. E.g., `1` + means that the second completion (`idx=1`) is the best. + + Note: + If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the + preference has failed. For instance, this could occur if the underlying language model returned an invalid + answer. In such cases, the caller should handle these invalid indices appropriately, possibly by + implementing fallback logic or error handling. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class BaseBinaryJudge(BaseJudge): + """ + Base class for binary judges. + """ + + @abstractmethod + def judge( + self, + prompts: list[str], + completions: list[str], + gold_completions: list[str] | None = None, + shuffle_order: bool = True, + ) -> list[int]: + """ + Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint. + + This base class should be used to implement binary evaluations as done in section 4.1.4 of the [CGPO + paper](https://huggingface.co/papers/2409.20370). It is relevant for assessing whether a prompt-completion pair + satisfies a specific constraint. + + Args: + prompts (`list[str]`): List of prompts. + completions (`list[str]`): List of completions. + gold_completions (`list[str]`, `optional`): List of gold completions if it exists. + shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + list[int]: A list of binary labels: + - 1 indicates that the completion satisfies the evaluated constraint. + - 0 indicates that the completion does not satisfy the evaluated constraint. + + Note: + If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference + has failed. For instance, this could occur if the underlying language model or rule based constraint + returned an invalid answer. In such cases, the caller should handle these invalid indices appropriately, + possibly by implementing fallback logic or error handling. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class PairRMJudge(BasePairwiseJudge): + # docstyle-ignore + """ + LLM judge based on the PairRM model from AllenAI. + + This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise + comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the + default Accelerator device. + + **Attributes**: + + blender (`llm_blender.Blender`): + An instance of the Blender class from llm-blender. + + **Example**: + ```python + >>> pairrm_judge = PairRMJudge() + >>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"] + >>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]] + >>> results = pairrm_judge.judge(prompts, completions) + >>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second) + ``` + + > [!TIP] + > This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`. + """ + + def __init__(self): + if not is_llm_blender_available(): + raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.") + import transformers + + if Version(transformers.__version__) >= Version("5.0.0"): + raise RuntimeError( + "llm-blender currently supports transformers < 5.0.0. Please install a compatible version: `pip install 'transformers<5.0.0'`. Check the issue tracker for updates: https://github.com/huggingface/trl/issues/4918" + ) + _ensure_llm_blender_importable() + import llm_blender + + self.blender = llm_blender.Blender() + self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device) + + def judge( + self, + prompts: list[str], + completions: list[list[str]], + shuffle_order: bool = True, + return_scores: bool = False, + temperature: float = 1.0, + ) -> list[int | float]: + """ + Judge the completion pairs for the given prompts using the PairRM model. + + Args: + prompts (`list[str]`): + List of prompts to judge. + completions (`list[list[str]]`): + List of completion pairs for each prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + return_scores (`bool`, *optional*, defaults to `False`): + If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*). + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for scaling logits if `return_scores` is True. + + Returns: + `list[int | float]`: + If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which + completion is preferred. If `return_scores` is `True`, returns softmax probabilities for the first + completion. + + Raises: + `ValueError`: + If the number of completions per prompt is not exactly 2. + + Note: + Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred). + """ + + if len(completions[0]) != 2: + raise ValueError("PairRM judge requires exactly 2 completions per prompt.") + + # Shuffle the order of the completions to avoid positional bias + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)] + + # Rank the completions + ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True) + if not return_scores: + ranks -= 1 # PairRM rank is 1-indexed, so we subtract 1 to make it 0-indexed + else: + # scale the logits by temperature + ranks /= temperature + + # Flip back the ranks or scores to the original order if needed + if shuffle_order: + ranks[flip_mask] = ranks[flip_mask][:, ::-1] + + # Return the ranks or score probability + if return_scores: + logit_max = np.amax(ranks, axis=-1, keepdims=True) + exp_logit_shifted = np.exp(ranks - logit_max) + probs = exp_logit_shifted / np.sum(exp_logit_shifted, axis=-1, keepdims=True) + return probs[:, 0].tolist() + else: + return ranks[:, 0].tolist() + + +class HfPairwiseJudge(BasePairwiseJudge): + """ + Pairwise judge based on the Hugging Face API with chat completion. + + This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt. + + Args: + model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`): + Model to use for the judge. + token (`str`, *optional*): + Hugging Face API token to use for the [`huggingface_hub.InferenceClient`]. + system_prompt (`str`, *optional*): + The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system + prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the + inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token + response. + """ + + def __init__( + self, + model="meta-llama/Meta-Llama-3-70B-Instruct", + token: str | None = None, + system_prompt: str | None = None, + ): + self.client = InferenceClient(model=model, token=token) + self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT + + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]: + # Shuffle the order of the completions to avoid positional bias + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)] + + # Define a function to get the rank for a single prompt, will be called concurrently + def get_rank(prompt, candidates): + content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1]) + completion = self.client.chat_completion(messages=[{"role": "user", "content": content}], max_tokens=1) + response = completion.choices[0].message.content + if response in ["0", "1"]: + return int(response) + else: + logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.") + return -1 + + # Call the completions concurrently + with concurrent.futures.ThreadPoolExecutor() as executor: + ranks = list(executor.map(get_rank, prompts, completions)) + + # Flip back the ranks to the original order if needed + if shuffle_order: + ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)] + + # Return the ranks + return ranks + + +class OpenAIPairwiseJudge(BasePairwiseJudge): + """ + Judge based on the OpenAI API. + + This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt. + + Args: + model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`): + Model to use for the judge. + system_prompt (`str`, *optional*): + System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system + prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the + inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token + response. + max_requests (`int` or `None`, *optional*, defaults to `1000`): + Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit. + """ + + def __init__( + self, model="gpt-4-turbo-preview", system_prompt: str | None = None, max_requests: int | None = 1_000 + ): + if not is_openai_available(): + raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.") + from openai import OpenAI + + self.client = OpenAI() + self.model = model + self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT + self.max_requests = max_requests + self.num_requests = 0 + self._warned = False + + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]: + # Check if the limit of requests is reached, if so, use random choice instead + if self.max_requests is not None and self.num_requests >= self.max_requests: + if not self._warned: # Print the warning only once + logging.warning( + f"Reached the maximum number of requests ({self.max_requests}). From now on, returning -1 instead. " + " To increase the limit, set `max_requests` to a higher value, or to `None` for no limit." + ) + self._warned = True + return [-1] * len(prompts) + + # Shuffle the order of the completions to avoid positional bias + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)] + + # Define a function to get the rank for a single prompt, will be called concurrently + def get_rank(prompt, candidates): + content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1]) + messages = [{"role": "user", "content": content}] + completion = self.client.chat.completions.create(model=self.model, messages=messages, max_tokens=1) + response = completion.choices[0].message.content + if response in ["0", "1"]: + return int(response) + else: + logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.") + return -1 + + # Call the completions concurrently + with concurrent.futures.ThreadPoolExecutor() as executor: + ranks = list(executor.map(get_rank, prompts, completions)) + + # Flip back the ranks to the original order if needed + if shuffle_order: + ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)] + + # Update the number of requests + self.num_requests += len(prompts) + + # Return the ranks + return ranks + + +class AllTrueJudge(BaseBinaryJudge): + """ + Unify the decision of multiple [`experimental.judges.BaseBinaryJudge`] instances. + + Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`. If any judge + returns `-1`, indicating a failure in its process, this judge will also return `-1`. + + Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370). + + Args: + judges (`list` of [`experimental.judges.BaseBinaryJudge`]): + A list of [`experimental.judges.BaseBinaryJudge`] instances whose decisions will be unified. + """ + + def __init__(self, judges: list[BaseBinaryJudge]): + self.judges = judges + + def judge( + self, + prompts: list[str], + completions: list[str], + gold_completions: list[str] | None = None, + shuffle_order: bool = True, + ) -> list[int]: + all_binary_judgments = [ + judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges + ] + output = [] + for binary_judgments in zip(*all_binary_judgments, strict=True): + # Check that all values are in {0, 1, -1} + if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments): + raise ValueError( + f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}." + ) + + # Unify the decision + if -1 in binary_judgments: + output.append(-1) + elif all(binary_judgment == 1 for binary_judgment in binary_judgments): + output.append(1) + else: + output.append(0) + return output diff --git a/ICL/RL/trl_source/trl/experimental/kto/__init__.py b/ICL/RL/trl_source/trl/experimental/kto/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c21e2316bff7d621faa06622584271ee176d0e --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/kto/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .kto_config import KTOConfig +from .kto_trainer import KTOTrainer + + +__all__ = ["KTOConfig", "KTOTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/kto/kto_config.py b/ICL/RL/trl_source/trl/experimental/kto/kto_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c69830669250008079ba142f45e6b8d29725468f --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/kto/kto_config.py @@ -0,0 +1,171 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class KTOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.kto.KTOTrainer`]. + + This class includes only the parameters that are specific to KTO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from the + reference model. + loss_type (`str`, *optional*, defaults to `"kto"`): + Type of loss to use. Possible values are: + + - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper. + - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the + [APO](https://huggingface.co/papers/2408.06266) paper. + + desirable_weight (`float`, *optional*, defaults to `1.0`): + Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris. + undesirable_weight (`float`, *optional*, defaults to `1.0`): + Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from both the model and the reference model to W&B or Comet + during evaluation. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute reference model log probabilities for training and evaluation datasets. This is + useful when training without the reference model to reduce the total GPU memory needed. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc: (`int`, *optional*): + Number of processes to use for processing the dataset. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from " + "the reference model." + }, + ) + loss_type: str = field( + default="kto", + metadata={ + "help": "Type of loss to use.", + "choices": ["kto", "apo_zero_unpaired"], + }, + ) + desirable_weight: float = field( + default=1.0, + metadata={ + "help": "Desirable losses are weighed by this factor to counter unequal number of desirable and " + "undesirable pairs.", + }, + ) + undesirable_weight: float = field( + default=1.0, + metadata={ + "help": "Undesirable losses are weighed by this factor to counter unequal number of desirable and " + "undesirable pairs.", + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={ + "help": "If `True`, generates and logs completions from both the model and the reference model to W&B " + "during evaluation." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + precompute_ref_log_probs: bool = field( + default=False, + metadata={ + "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " + "This is useful when training without the reference model to reduce the total GPU memory needed." + }, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py b/ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..78fea3b38b44514fd9634671b74402389e0766a7 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py @@ -0,0 +1,1511 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import random +import textwrap +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager, nullcontext +from operator import itemgetter +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from accelerate import PartialState, logging +from accelerate.utils import tqdm +from datasets import Dataset, concatenate_datasets +from packaging.version import Version +from torch import autocast +from torch.utils.data import DataLoader, SequentialSampler +from transformers import ( + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + TrainingArguments, + is_comet_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalLoopOutput, has_length +from transformers.utils import is_peft_available + +from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset +from ...import_utils import is_liger_kernel_available +from ...models.utils import create_reference_model, peft_module_casting_to_bf16, prepare_deepspeed +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + create_model_from_path, + disable_dropout_in_model, + log_table_to_comet_experiment, + pad_to_length, + selective_log_softmax, +) +from ..utils import DPODataCollatorWithPadding +from .kto_config import KTOConfig + + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + +logger = logging.get_logger(__name__) + +RUNNING_NAME = "running.pt" + + +def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]: + """ + Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of + completions. For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the + same set as the matched outputs y used to estimate the rewards in that batch, just paired with different x. + """ + batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1] + batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1] + return batch + + +def _tokenize( + batch: dict[str, list[Any]], + tokenizer: "PreTrainedTokenizer", +) -> dict[str, list[Any]]: + """Tokenize a batch from a KTO specific dataset.""" + prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) + prompt_input_ids = prompt_tokenized["input_ids"] + prompt_attention_mask = prompt_tokenized["attention_mask"] + prompt_and_completion = [ + prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True) + ] + full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False) + full_input_ids = full_tokenized["input_ids"] + full_attention_mask = full_tokenized["attention_mask"] + + answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)] + answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)] + # Prepare input tokens for token by token comparison + full_input_ids = [np.array(f) for f in full_input_ids] + for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True): + if len(full) != len(concat): + raise ValueError( + "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length." + ) + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = [len(p) for p in prompt_input_ids] + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)): + if not np.array_equal(p, f[:r]): + response_token_ids_start_idx[idx] -= 1 + + prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)] + prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)] + + for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True): + if len(p) != len(m): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)] + answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)] + + output = dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + answer_input_ids=answer_input_ids, + answer_attention_mask=answer_attention_mask, + ) + + return output + + +def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict: + """Process tokens of a KTO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + + completion responses is/are too long. We truncate from the end (completion) to fit within max_length. + + We also create the labels for the completion responses, which are of length equal to the sum of the length of the + prompt and the completion response, with `-100` for the prompt tokens. + """ + prompt = example["prompt"] + completion = example["completion"] + + batch = { + f"{kwargs['prefix']}prompt": prompt, + f"{kwargs['prefix']}completion": completion, + f"{kwargs['prefix']}label": example["label"], + } + + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + + if not isinstance(completion, str): + raise ValueError(f"completion should be an str but got {type(completion)}") + + # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer + all_tokens = { + "prompt_input_ids": example["prompt_input_ids"], + "prompt_attention_mask": example["prompt_attention_mask"], + "answer_input_ids": example["answer_input_ids"], + "answer_attention_mask": example["answer_attention_mask"], + } + + # calculate max length by checking if BOS/EOS is already there + max_length = kwargs["max_length"] + bos_token_id = kwargs["tokenizer"].bos_token_id + eos_token_id = kwargs["tokenizer"].eos_token_id + if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]: + max_length -= 1 + if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]: + max_length -= 1 + + # if combined sequence is too long, truncate the completion (answer) from the end + prompt_length = len(all_tokens["prompt_input_ids"]) + completion_length = len(all_tokens["answer_input_ids"]) + if prompt_length + completion_length > max_length: + max_completion_length = max_length - prompt_length + for k in ["answer_input_ids", "answer_attention_mask"]: + all_tokens[k] = all_tokens[k][:max_completion_length] + + # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens + batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] + batch[f"{kwargs['prefix']}completion_attention_mask"] = ( + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] + ) + + # add BOS, which affects both prompt and the full completion + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + # add EOS, which affects only the full completion + if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: + batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ + eos_token_id + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + [1] + + batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] + batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [-100] * len( + batch[f"{kwargs['prefix']}prompt_input_ids"] + ) + + return batch + + +class KTOTrainer(BaseTrainer): + r""" + Initialize KTOTrainer. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. + ref_model ([`~transformers.PreTrainedModel`]): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation + and loss. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized. + args ([`experimental.kto.KTOConfig`]): + The arguments to use for training. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + data_collator ([`~transformers.DataCollator`], *optional*): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + model_adapter_name (`str`, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + """ + + _tag_names = ["trl", "kto"] + _name = "KTO" + _paper = { + "title": "KTO: Model Alignment as Prospect Theoretic Optimization", + "id": "2402.01306", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{ethayarajh2024kto, + title = {{KTO: Model Alignment as Prospect Theoretic Optimization}}, + author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela}, + year = 2024, + eprint = {arXiv:2402.01306}, + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str = None, + ref_model: PreTrainedModel | nn.Module | str | None = None, + args: KTOConfig = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + data_collator: DataCollator | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + model_adapter_name: str | None = None, + ref_adapter_name: str | None = None, + ): + if type(args) is TrainingArguments: + raise ValueError("Please use `KTOConfig` instead TrainingArguments.") + + if not isinstance(model, str) and ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + + # Model initialization + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the KTOConfig, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Reference model initialization + if isinstance(ref_model, str): + ref_model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + ref_model_init_kwargs["device_map"] = None + ref_model = create_model_from_path(ref_model, **ref_model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " + "merge and unload the existing adapter, save the resulting base model, and then pass that base " + "model along with the new `peft_config` to the trainer." + ) + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + # KTO only supports causal language models, not encoder-decoder models + if model is not None and hasattr(model.config, "is_encoder_decoder") and model.config.is_encoder_decoder: + raise ValueError( + "KTO only supports causal language models. Encoder-decoder models are not supported. " + "Please use a causal LM (e.g., GPT, Llama, Mistral) instead of an encoder-decoder model (e.g., T5, BART)." + ) + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or args.precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if processing_class is None: + raise ValueError( + "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" + ) + if args.max_length is None: + logger.warning( + "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + ) + max_length = 512 + if args.max_length is not None: + max_length = args.max_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + logger.warning( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig" + " we have set it for you, but you should do it yourself in the future.", + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.loss_type = args.loss_type + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.processing_class = processing_class + self.precompute_ref_log_probs = args.precompute_ref_log_probs + + # Not all losses require a KL calculation + self.calculate_KL = True + if self.loss_type in ["apo_zero_unpaired"]: + self.calculate_KL = False + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + # metric + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # KTO parameter + self.beta = args.beta + self.desirable_weight = args.desirable_weight + self.undesirable_weight = args.undesirable_weight + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed + train_dataset = train_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" + ) + # Unpair the dataset if needed + train_dataset = maybe_unpair_preference_dataset( + train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" + ) + # Apply the chat template if needed + train_dataset = train_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to train dataset", + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" + ) + eval_dataset = maybe_unpair_preference_dataset( + eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" + ) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to eval dataset", + ) + + # Tokenize and prepare the training datasets + train_dataset = train_dataset.map( + _tokenize, + batched=True, + fn_kwargs={"tokenizer": self.processing_class}, + num_proc=args.dataset_num_proc, + desc="Tokenizing train dataset", + ) + + fn_kwargs = { + "prefix": "", + "tokenizer": self.processing_class, + "max_length": self.max_length, + } + + train_dataset = train_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized train dataset", + ) + + # Tokenize and prepare the eval datasets + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + _tokenize, + fn_kwargs={"tokenizer": self.processing_class}, + batched=True, + num_proc=args.dataset_num_proc, + desc="Tokenizing eval dataset", + ) + + eval_dataset = eval_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized eval dataset", + ) + + # Get KL datasets if needed + if self.calculate_KL: + if args.per_device_train_batch_size <= 1: + raise ValueError( + "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward." + ) + + # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size + # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n) + train_kl_dataset = train_dataset.map( + _get_kl_dataset, + batched=True, + batch_size=args.per_device_train_batch_size, + num_proc=args.dataset_num_proc, + desc="Extracting KL train dataset", + ) + + fn_kwargs["prefix"] = "KL_" + train_kl_dataset = train_kl_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names], + desc="Processing tokenized train KL dataset", + ) + + # merge the datasets + train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1) + + if eval_dataset is not None: + # Get KL dataset + eval_kl_dataset = eval_dataset.map( + _get_kl_dataset, + batched=True, + batch_size=args.per_device_train_batch_size, + num_proc=args.dataset_num_proc, + desc="Extracting eval KL dataset", + ) + + eval_kl_dataset = eval_kl_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names], + desc="Processing tokenized eval KL dataset", + ) + + # merge the datasets + eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1) + + # calculate dataset desirability balance + num_desirable = max(sum(train_dataset["label"]), 1) + num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary + + if num_desirable != num_undesirable: + # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306 + des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2) + des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2) + und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2) + und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2) + + des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound + und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound + + if not (des_weight_in_range or und_weight_in_range): + logger.warning( + "You have different amounts of desirable/positive and undesirable/negative examples but the " + "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based " + f"on your data, we recommend EITHER " + f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or " + f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). " + "See the documentation on how to optimally set these weights.", + ) + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + # Import Liger kernel if enabled + if self.args.use_liger_kernel: + if not is_liger_kernel_available(): + raise ImportError( + "You set `use_liger_kernel=True` but the liger kernel is not available. " + "Please install liger-kernel first: `pip install liger-kernel`" + ) + if self.loss_type in ["apo_zero_unpaired"]: + raise ValueError( + "You cannot set `loss_type='apo_zero_unpaired'` with liger-kernel." + "Only KTO loss is supported with liger-kernel." + ) + if self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set " + "`precompute_ref_log_probs=False`." + ) + if self.is_peft_model or self.ref_adapter_name is not None: + raise ValueError( + "You cannot use `use_liger_kernel=True` with Peft models. Please set `use_liger_kernel=False`." + ) + self.kto_loss_fn = LigerFusedLinearKTOLoss(beta=self.beta, use_ref_model=(self.ref_model is not None)) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + reference_completion_logps = [] + reference_KL_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + if self.calculate_KL: + reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) + reference_KL_logps.append(reference_KL_logp.cpu()) + + self.train_dataset = self.train_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + + if self.calculate_KL: + self.train_dataset = self.train_dataset.add_column( + name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_eval_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + reference_completion_logps = [] + reference_KL_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + if self.calculate_KL: + reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) + reference_KL_logps.append(reference_KL_logp.cpu()) + + eval_dataset = eval_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + if self.calculate_KL: + eval_dataset = eval_dataset.add_column( + name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() + ) + + # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + def compute_reference_log_probs(self, padded_batch: dict) -> dict: + """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset.""" + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + completion_logits = self.model( + padded_batch["completion_input_ids"], + attention_mask=padded_batch["completion_attention_mask"], + ).logits + + if self.calculate_KL: + KL_logits = self.model( + padded_batch["KL_completion_input_ids"], + attention_mask=padded_batch["KL_completion_attention_mask"], + ).logits + else: + completion_logits = self.ref_model( + padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] + ).logits + + if self.calculate_KL: + KL_logits = self.ref_model( + padded_batch["KL_completion_input_ids"], + attention_mask=padded_batch["KL_completion_attention_mask"], + ).logits + + completion_logps = self.get_batch_logps( + completion_logits, + padded_batch["completion_labels"], + average_log_prob=False, + ) + + if self.calculate_KL: + KL_logps = self.get_batch_logps( + KL_logits, + padded_batch["KL_completion_labels"], + average_log_prob=False, + ) + else: + KL_logps = None + + return completion_logps, KL_logps + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: + Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: + Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored. + Shape: (batch_size, sequence_length) + average_log_prob: + If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the + log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the + given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + # For causal LM, shift labels and logits by one position + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + KL_logps = self._compute_kl_logps(model, batch) + + model_kwargs = {} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + **model_kwargs, + ) + completion_logits = outputs.logits + + completion_logps = self.get_batch_logps( + completion_logits, + batch["completion_labels"], + average_log_prob=False, + ) + + if completion_logps.shape[0] != len(batch["label"]): + raise ValueError( + "There is a mismatch between the number of examples in this batch and the number of " + "examples for which an output sequence was predicted." + ) + + # Use torch.nonzero for efficient tensor index selection + device = completion_logits.device + labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device) + chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1) + rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1) + + # Use index_select for efficient CUDA operations + chosen_logps = completion_logps.index_select(0, chosen_idx) + rejected_logps = completion_logps.index_select(0, rejected_idx) + + chosen_logits = completion_logits.index_select(0, chosen_idx) + rejected_logits = completion_logits.index_select(0, rejected_idx) + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss) + else: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps) + + def kto_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + policy_KL_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_KL_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the KTO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: + Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) + policy_rejected_logps: + Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) + policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,) + reference_chosen_logps: + Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) + reference_rejected_logps: + Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in + batch_size,) + reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,) + + Returns: + A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). The losses tensor contains the KTO + loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for + the chosen and rejected responses, respectively. The KL tensor contains the detached KL divergence estimate + between the policy and reference models. + """ + if self.calculate_KL: + kl = (policy_KL_logps - reference_KL_logps).mean().detach() + kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) + else: + kl = torch.zeros(1).to(policy_chosen_logps.device) + + # Chosen losses + if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0: + chosen_logratios = policy_chosen_logps - reference_chosen_logps + + if self.loss_type == "kto": + # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + elif self.loss_type == "apo_zero_unpaired": + # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are better than your model's default output + chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios) + + chosen_rewards = self.beta * chosen_logratios.detach() + + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + chosen_losses = torch.Tensor([]).to(self.accelerator.device) + chosen_rewards = torch.Tensor([]).to(self.accelerator.device) + + # Rejected losses + if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0: + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + if self.loss_type == "kto": + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + elif self.loss_type == "apo_zero_unpaired": + rejected_losses = F.sigmoid(self.beta * rejected_logratios) + + rejected_rewards = self.beta * rejected_logratios.detach() + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + rejected_losses = torch.Tensor([]).to(self.accelerator.device) + rejected_rewards = torch.Tensor([]).to(self.accelerator.device) + + losses = torch.cat( + (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), + 0, + ) + + return losses, chosen_rewards, rejected_rewards, kl + + def _compute_kl_logps(self, model, batch): + """Compute KL log probabilities for a given batch.""" + KL_logps = None + if self.calculate_KL: + KL_model_kwargs = { + "input_ids": batch["KL_completion_input_ids"], + "attention_mask": batch["KL_completion_attention_mask"], + } + + with torch.no_grad(): + KL_logits = model(**KL_model_kwargs).logits + + KL_logps = self.get_batch_logps( + KL_logits, + batch["KL_completion_labels"], + average_log_prob=False, + ) + return KL_logps + + def _compute_loss_liger(self, model, batch): + """ + Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss. + + Args: + model: + The policy model used for generating log probabilities and outputs. It could be an encoder-decoder + model or a regular language model. + batch: A dictionary containing the input data and labels for the batch. + + Returns: + A dictionary containing the following keys: + - "loss": The computed KTO loss for the batch. + - "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model. + - "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model. + - "chosen_logps": Log probabilities of the chosen responses from the policy model. + - "rejected_logps": Log probabilities of the rejected responses from the policy model. + - "chosen_rewards": Rewards for the chosen responses. + - "rejected_rewards": Rewards for the rejected responses. + - "kl": The KL divergence between the policy and reference models (detached). + + If auxiliary loss is enabled, the dictionary will also include: + - "aux_loss": The auxiliary loss from the model outputs. + """ + policy_KL_logps = self._compute_kl_logps(model, batch) + reference_KL_logps = self._compute_kl_logps(self.ref_model, batch) + if self.calculate_KL: + kl = (policy_KL_logps - reference_KL_logps).mean().detach() + kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) + else: + kl = torch.zeros(1).to(self.accelerator.device) + + model_kwargs = {} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # skip the lm head and get the last hidden state + base_model = model.get_decoder() + outputs = base_model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + use_cache=False, + **model_kwargs, + ) + + # reference model + ref_base_model = self.ref_model.get_decoder() + ref_outputs = ref_base_model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + use_cache=False, + **model_kwargs, + ) + lm_head = model.get_output_embeddings() + ref_lm_head = self.ref_model.get_output_embeddings() + + ( + loss, + ( + chosen_logps_sum, + rejected_logps_sum, + chosen_logits_sum, + rejected_logits_sum, + chosen_rewards_sum, + rejected_rewards_sum, + ), + ) = self.kto_loss_fn( + _input=outputs.last_hidden_state[:, :-1], + lin_weight=lm_head.weight, + target=batch["completion_labels"][:, 1:], + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device), + ref_input=ref_outputs.last_hidden_state[:, :-1], + ref_weight=ref_lm_head.weight, + ref_bias=ref_lm_head.bias if hasattr(lm_head, "bias") else None, + kl=kl, + ) + + output = { + "loss": loss, + "chosen_logits_sum": chosen_logits_sum, + "rejected_logits_sum": rejected_logits_sum, + "chosen_logps_sum": chosen_logps_sum, + "rejected_logps_sum": rejected_logps_sum, + "chosen_rewards_sum": chosen_rewards_sum, + "rejected_rewards_sum": rejected_rewards_sum, + "kl": kl, + } + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, list | torch.LongTensor], + ): + """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + + labels = torch.tensor(batch["label"]) + num_chosen = labels.sum().to(self.accelerator.device) + num_rejected = (len(labels) - num_chosen).to(self.accelerator.device) + + if self.args.use_liger_kernel: + model_output = self._compute_loss_liger(model, batch) + losses = model_output["loss"] + policy_chosen_logits = model_output["chosen_logits_sum"] + policy_rejected_logits = model_output["rejected_logits_sum"] + policy_chosen_logps = model_output["chosen_logps_sum"] + policy_rejected_logps = model_output["rejected_logps_sum"] + chosen_rewards = model_output["chosen_rewards_sum"] + rejected_rewards = model_output["rejected_rewards_sum"] + kl = model_output["kl"] + if self.aux_loss_enabled: + aux_loss = model_output["aux_loss"] + else: + forward_output = self.forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_KL_logps, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + # if reference_logps in batch use them, otherwise use the reference model + if "reference_logps" in batch: + # Convert Python lists to tensor indices for efficient CUDA operations + device = batch["reference_logps"].device + labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device) + chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1) + rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1) + + # Use index_select for efficient CUDA operations + reference_chosen_logps = batch["reference_logps"].index_select(0, chosen_idx) + reference_rejected_logps = batch["reference_logps"].index_select(0, rejected_idx) + if self.calculate_KL: + reference_KL_logps = batch["reference_KL_logps"] + else: + reference_KL_logps = None + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + reference_KL_logps, + ) = self.forward(self.model, batch)[:5] + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + reference_KL_logps, + ) = self.forward(self.ref_model, batch)[:5] + + losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( + policy_chosen_logps, + policy_rejected_logps, + policy_KL_logps, + reference_chosen_logps, + reference_rejected_logps, + reference_KL_logps, + ) + + metrics["kl"] = kl.item() + + all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() + + if all_num_chosen > 0: + metrics["rewards/chosen_sum"] = ( + self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() + ) + metrics["logps/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() + ) + metrics["count/chosen"] = all_num_chosen + + if all_num_rejected > 0: + metrics["rewards/rejected_sum"] = ( + self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() + ) + metrics["logps/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() + ) + metrics["count/rejected"] = all_num_rejected + + loss = losses.nanmean() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs) + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None: + if dataset is None: + dataset = self.train_dataset + if dataset is None or not has_length(dataset): + return None + return SequentialSampler(dataset) + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) + reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs) + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = {} + if "logits/chosen_sum" in metrics: + logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"] + if "logits/rejected_sum" in metrics: + logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"] + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device) + target_indices = torch.where(~target_labels)[0] + target_batch = { + "prompt_input_ids": random_batch["prompt_input_ids"][target_indices], + "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices], + "prompt": itemgetter(*target_indices)(random_batch["prompt"]), + } + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy", "Ref Model"], + data=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True + ) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # train metrics should have no prefix, eval should have 'eval_' + prefix = "eval_" if train_eval == "eval" else "" + # accumulate average metrics from sums and lengths + for split in ["chosen", "rejected"]: + if f"count/{split}" in self._stored_metrics[train_eval]: + count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() + for metric in ["rewards", "logps", "logits"]: + logs[f"{prefix}{metric}/{split}"] = ( + torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() + / count_sum + ) + # delete obsolete metric + del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] + del self._stored_metrics[train_eval][f"count/{split}"] + # calculate reward margin + if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: + logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/experimental/merge_model_callback.py b/ICL/RL/trl_source/trl/experimental/merge_model_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8478759204105aa02eea715b884c755bfc9757 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/merge_model_callback.py @@ -0,0 +1,352 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch +from huggingface_hub import HfApi +from transformers import TrainerCallback + +from ..import_utils import is_mergekit_available +from ..trainer.utils import get_config_model_id + + +if is_mergekit_available(): + from mergekit.config import MergeConfiguration + from mergekit.merge import MergeOptions, run_merge + + +# Logger for module-level logging +logger = logging.getLogger(__name__) + + +def upload_model_to_hf(folder_path: str, repo_id: str): + api = HfApi() + # Create the repository if it doesn't exist + repo = api.create_repo(repo_id, repo_type="model") + + # Upload the folder to the specified repository + api.upload_folder( + folder_path=folder_path, + repo_id=repo.repo_id, + repo_type=repo.repo_type, + ) + + +class MergeConfig: + r""" + Configuration class for merging two models using `mergekit`. + + This class provides a structured way to configure and generate merge configurations for various merge methods, such + as `linear`, `ties`, `dare_ties`, and `slerp`. + + Args: + method (`str`, *optional*, defaults to `"linear"`): + Merge method to use. Supported methods include: + + - `"linear"`: Linearly combines two models with specified weights. + - `"ties"`: Combines two models using the TIES method with density parameters. + - `"dare_ties"`: A variant of TIES for domain adaptation. + - `"slerp"`: Combines models using spherical linear interpolation. + + Note: + + For more details about the merge methods and how they are implemented, see the [MergeKit GitHub + repository](https://github.com/arcee-ai/mergekit?tab=readme-ov-file#merge-methods). + + Attributes: + method (`str`): The merge method to use. + policy_model_path (`str` or `None`): Path to the policy model. + target_model_path (`str` or `None`): Path to the target model. + policy_model_weight (`float`): Weight for the policy model (for `linear` and `ties` methods). + target_model_weight (`float`): Weight for the target model (for `linear` and `ties` methods). + policy_model_density (`list[float]`): Density parameters for the policy model (for `ties` and `dare_ties`). + target_model_density (`list[float]`): Density parameters for the target model (for `ties` and `dare_ties`). + normalize (`float` or `None`): Normalization factor for the TIES method. + t_values (`float` or `None`): Interpolation factor for the SLERP method. + dtype (`str`): Data type to use for merging, e.g., `"float16"`. + """ + + def __init__(self, method: str = "linear"): + if not is_mergekit_available(): + raise ImportError("MergeConfig requires the `mergekit` extra. To install, run `pip install mergekit`.") + self.method = method + self.policy_model_path = None + self.target_model_path = None + + # Initialize relevant parameters based on the method + if method == "linear": + self.policy_model_weight = 0.5 + self.target_model_weight = 0.5 + self.dtype = "float16" + elif method == "ties": + self.policy_model_weight = 1.0 + self.policy_model_density = [1.0, 0.7, 0.1] + self.target_model_weight = 1.0 + self.target_model_density = [1.0] + self.normalize = 1.0 + self.dtype = "float16" + elif method == "dare_ties": + self.policy_model_weight = 1.0 + self.policy_model_density = [1.0, 0.7, 0.1] + self.target_model_weight = 1.0 + self.target_model_density = [1.0] + self.normalize = 1.0 + self.dtype = "float16" + elif method == "slerp": + self.t_values = 0.5 + self.dtype = "float16" + else: + raise ValueError(f"Unsupported merge method: {method}") + + def create_merge_config_linear(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a linear merge of two models with specified weights. + """ + # Create the merge configuration dictionary + merge_config_dict = { + "dtype": self.dtype, + "merge_method": "linear", + "models": [ + {"model": self.policy_model_path, "parameters": {"weight": self.policy_model_weight}}, + {"model": self.target_model_path, "parameters": {"weight": self.target_model_weight}}, + ], + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create_merge_config_ties(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a TIES merge of two models, with specified weights and densities. + """ + # Create the TIES merge configuration dictionary + merge_config_dict = { + "merge_method": "ties", + "slices": None, # Optional slices if needed + "models": [ + { + "model": { + "model": {"path": self.target_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.target_model_density, "weight": self.target_model_weight}, + }, + { + "model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight}, + }, + ], + "parameters": {"normalize": self.normalize}, + "base_model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "dtype": self.dtype, + "tokenizer_source": None, + "tokenizer": None, + "chat_template": None, + "out_dtype": None, + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create_merge_config_dare_ties(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a DARE TIES merge of two models, with specified weights and densities. + """ + # Create the DARE TIES merge configuration dictionary + merge_config_dict = { + "merge_method": "dare_ties", + "slices": None, # Optional slices if needed + "models": [ + { + "model": { + "model": {"path": self.target_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.target_model_density, "weight": self.target_model_weight}, + }, + { + "model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight}, + }, + ], + "parameters": {"normalize": self.normalize}, + "base_model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "dtype": self.dtype, + "tokenizer_source": None, + "tokenizer": None, + "chat_template": None, + "out_dtype": None, + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create_merge_config_slerp(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a SLERP merge of a model with a base model. + """ + + # Create the SLERP merge configuration dictionary + merge_config_dict = { + "merge_method": "slerp", + "slices": None, # Optional slices if needed + "models": [ + { + "model": { + "model": {"path": self.target_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": None, # No specific parameters for SLERP model + } + ], + "parameters": { + "t": self.t_values # Set the t values for SLERP + }, + "base_model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "dtype": self.dtype, + "tokenizer_source": None, + "tokenizer": None, + "chat_template": None, + "out_dtype": None, + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create(self) -> "MergeConfiguration": + if self.method == "linear": + return self.create_merge_config_linear() + elif self.method == "ties": + return self.create_merge_config_ties() + elif self.method == "dare_ties": + return self.create_merge_config_dare_ties() + elif self.method == "slerp": + return self.create_merge_config_slerp() + + +def merge_models(config: "MergeConfiguration", out_path: str): + """ + Merge two models using mergekit + + Args: + config (`MergeConfiguration`): The merge configuration. + out_path (`str`): The output path for the merged model. + """ + if not is_mergekit_available(): + raise ImportError("merge_models requires the `mergekit` extra. To install, run `pip install mergekit`.") + run_merge( + config, + out_path=out_path, + options=MergeOptions( + device="auto", + cuda=torch.cuda.is_available(), + copy_tokenizer=True, + lazy_unpickle=False, + low_cpu_memory=False, + ), + ) + + +class MergeModelCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based + on a merge configuration. + + Args: + merge_config ([`experimental.merge_model_callback.MergeConfig`], *optional*): + Configuration used for the merging process. If not provided, the default + [`~experimental.merge_model_callback.MergeConfig`] is used. + merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`): + Whether to merge the model at every checkpoint. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the merged model to the Hub after merging. + + Example: + + ```python + from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback + + config = MergeConfig() + merge_callback = MergeModelCallback(config) + trainer = DPOTrainer(..., callbacks=[merge_callback]) + ``` + """ + + def __init__( + self, + merge_config: "MergeConfig | None" = None, + merge_at_every_checkpoint: bool = False, + push_to_hub: bool = False, + ): + if not is_mergekit_available(): + raise ImportError( + "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`." + ) + self.merge_config = merge_config or MergeConfig() + self.merge_at_every_checkpoint = merge_at_every_checkpoint + self.push_to_hub = push_to_hub + + def _merge_and_maybe_push(self, output_dir, global_step, model): + checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}") + self.merge_config.policy_model_path = checkpoint_path + if self.merge_config.target_model_path is None: + self.merge_config.target_model_path = get_config_model_id(model.config) + merge_path = os.path.join(checkpoint_path, "merged") + + merge_models(self.merge_config.create(), merge_path) + + if self.push_to_hub: + repo_name = f"{output_dir}_checkpoint-{global_step}_merged" + upload_model_to_hf(merge_path, repo_name) + + def on_save(self, args, state, control, model=None, **kwargs): + if self.merge_at_every_checkpoint: + self._merge_and_maybe_push(args.output_dir, state.global_step, model) + + def on_train_end(self, args, state, control, model=None, **kwargs): + if not self.merge_at_every_checkpoint: + self._merge_and_maybe_push(args.output_dir, state.global_step, model) diff --git a/ICL/RL/trl_source/trl/experimental/minillm/__init__.py b/ICL/RL/trl_source/trl/experimental/minillm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07bc38f0ef9973010a3eed3c9adf4ef5137ec6b1 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/minillm/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .minillm_config import MiniLLMConfig +from .minillm_trainer import MiniLLMTrainer + + +__all__ = ["MiniLLMConfig", "MiniLLMTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/minillm/minillm_config.py b/ICL/RL/trl_source/trl/experimental/minillm/minillm_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e66afb118637b610f3a7182b559a753ed5484e4d --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/minillm/minillm_config.py @@ -0,0 +1,138 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + +from ...trainer.grpo_config import GRPOConfig + + +@dataclass +class MiniLLMConfig(GRPOConfig): + """ + Configuration class for [`MiniLLMTrainer`]. + + This class includes only the parameters that are specific to MiniLLM training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] and [`GRPOConfig`] documentation. + + Args: + teacher_model_init_kwargs (`dict[str, Any]]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model + from a string. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + rkl_advantage (`bool`, *optional*, defaults to `True`): + Whether to add the reverse KL advantage to the reward advantage. + single_step_decomposition (`bool`, *optional*, defaults to `True`): + Whether to use single-step decomposition for the KL divergence computation. + kd_temperature (`float`, *optional*, defaults to `1.0`): + Temperature for knowledge distillation. Higher temperatures produce softer probability distributions over + classes. + gamma (`float`, *optional*, defaults to `0.0`): + Discount factor for future rewards in reinforcement learning. + length_normalization (`bool`, *optional*, defaults to `True`): + Whether to apply length normalization to the rewards. + """ + + teacher_model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "teacher model from a string." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropouts in `model`."}, + ) + rkl_advantage: bool = field( + default=True, + metadata={"help": "Whether to add the reverse KL advantage to the reward advantage."}, + ) + single_step_decomposition: bool = field( + default=True, + metadata={"help": "Whether to use single-step decomposition for the KL divergence computation."}, + ) + kd_temperature: float = field( + default=1.0, + metadata={ + "help": "Temperature for knowledge distillation. Higher temperatures produce softer probability " + "distributions over classes." + }, + ) + gamma: float = field( + default=0.0, + metadata={"help": "Discount factor for future rewards in reinforcement learning."}, + ) + length_normalization: bool = field( + default=True, + metadata={"help": "Whether to apply length normalization to the rewards."}, + ) + + def __post_init__(self): + # We do not use the post_init of GRPOConfig because: + # 1. num_generations can be < 2 in MiniLLMConfig. Scale_rewards must be set to "none" to avoid nan. + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + TrainingArguments.__post_init__(self) + + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + if self.num_generations == 1: + self.scale_rewards = "none" + + num_processes = self.world_size + # The current default effective batch size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + self.steps_per_generation = self.generation_batch_size // ( + self.per_device_train_batch_size * num_processes + ) + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.do_eval and self.eval_strategy != "no": + # Determine the number of generations to use for evaluation + num_generations = self.num_generations_eval or self.num_generations + + # Just ensure the value is divisible by the global batch size + if (self.per_device_eval_batch_size * num_processes) % num_generations != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by the number of generations used for evaluation ({num_generations})." + ) + + # The generation batch must contain full prompt groups (no partials), so it must be divisible by + # num_generations. + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) + + if self.delta is not None and self.use_liger_kernel: + raise ValueError("Liger kernel does not support two-sided GRPO loss yet.") diff --git a/ICL/RL/trl_source/trl/experimental/minillm/minillm_trainer.py b/ICL/RL/trl_source/trl/experimental/minillm/minillm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b8d34c02d288930dc903d50299b7dc8fd19d85f --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/minillm/minillm_trainer.py @@ -0,0 +1,411 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from datasets import Dataset, IterableDataset +from packaging.version import Version +from transformers import ( + AutoModelForCausalLM, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.utils import is_peft_available + +from ...models import prepare_deepspeed +from ...trainer.grpo_trainer import GRPOTrainer, RewardFunc, RolloutFunc +from ...trainer.utils import disable_dropout_in_model, empty_cache, get_config_model_id +from .minillm_config import MiniLLMConfig + + +if is_peft_available(): + from peft import PeftConfig + + +def dummy_reward_func(completions: list, **kwargs): + # placeholder reward function when no reward function is provided + return [1.0 for _ in completions] + + +class MiniLLMTrainer(GRPOTrainer): + """ + Trainer for the Knowledge Distillation of Language Models (MiniLLM) method. This algorithm was initially proposed + in the paper [Knowledge Distillation of Large Language Models](https://huggingface.co/papers/2306.08543). + + Example: + + ```python + from datasets import load_dataset + from trl.experimental.minillm import MiniLLMTrainer + + dataset = load_dataset("trl-lib/tldr", split="train") + + trainer = MiniLLMTrainer( + model="Qwen/Qwen3-0.6B", + teacher_model="Qwen/Qwen3-1.7B", + train_dataset=dataset, + ) + trainer.train() + ``` + + Args: + model (`str | PreTrainedModel`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in + `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + teacher_model (`PreTrainedModel | nn.Module | str`): + Teacher model used for knowledge distillation. Instantiated similarly to `model`. + reward_funcs (`RewardFunc | list[RewardFunc]`, *optional*): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function, such as: + - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the + keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. + - A custom reward function: The function is provided with the prompts and the generated completions, + plus any additional columns in the dataset. It should return a list of rewards. Custom reward + functions can also return `None` when the reward is not applicable to those samples. This is useful + for multi-task training where different reward functions apply to different types of samples. When a + reward function returns `None` for a sample, that reward function is excluded from the reward + calculation for that sample. For more details, see [Using a custom reward + function](#using-a-custom-reward-function). + + The trainer's state is also passed to the reward function. The trainer's state is an instance of + [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the + reward function's signature. + - A list of reward functions, where each item can independently be any of the above types. Mixing different + types within the list (e.g., a string model ID and a custom reward function) is allowed. + args ([`experimental.minillm.MiniLLMConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is + ignored. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. The padding side must be set to "left". If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A + padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, + `tokenizer.eos_token` will be used as the default. + reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is + `None`, the tokenizer for the model is automatically loaded using + [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward + functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes` + are ignored. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + rollout_func (`RolloutFunc`, *optional*): + Function to use for generating completions. It must take prompts, args, and processing_class as parameters + and return a dict with `"prompt_ids"`, `"completion_ids"`, and `"logprobs"` fields. Any other fields that + are forwarded to the reward functions. This feature is experimental and may change or be removed at any + time without prior notice. + """ + + _tag_names = ["trl", "minillm"] + _name = "MiniLLM" + _paper = { + "title": "MiniLLM: Knowledge Distillation of Large Language Models", + "id": "2306.08543", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{ + gu2024minillm, + title={{MiniLLM: Knowledge Distillation of Large Language Models}}, + author={Yuxian Gu and Li Dong and Furu Wei and Minlie Huang}, + booktitle={The Twelfth International Conference on Learning Representations}, + year={2024}, + url={https://openreview.net/forum?id=5h0qf7IBZZ} + }"""), + } + + def __init__( + self, + model: str | PreTrainedModel, + teacher_model: PreTrainedModel | nn.Module | str, + reward_funcs: RewardFunc | list[RewardFunc] | None = None, + args: MiniLLMConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: "PeftConfig | None" = None, + rollout_func: RolloutFunc | None = None, + ): + if reward_funcs is None: + reward_funcs = [dummy_reward_func] + + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = MiniLLMConfig(f"{model_name}-MiniLLM") + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model, + reward_funcs, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + rollout_func=rollout_func, + ) + + if args.teacher_model_init_kwargs is None: + teacher_model_init_kwargs = {} + elif not isinstance(teacher_model, str): + raise ValueError( + "You passed teacher_model_init_kwargs to the MiniLLMConfig, but your teacher_model is already instantiated." + ) + else: + teacher_model_init_kwargs = args.teacher_model_init_kwargs + teacher_model_init_kwargs["dtype"] = ( + teacher_model_init_kwargs["dtype"] + if teacher_model_init_kwargs["dtype"] in ["auto", None] + else getattr(torch, teacher_model_init_kwargs["dtype"]) + ) + + if isinstance(teacher_model, str): + teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(self.model) + + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + + self.temperature = args.temperature + self.kd_temperature = args.kd_temperature + self.single_step_decomposition = args.single_step_decomposition + self.rkl_advantage = args.rkl_advantage + self.gamma = args.gamma + self.length_normalization = args.length_normalization + + def _single_step_decomposition_loss( + self, + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + mask: torch.Tensor | None = None, + reduction: str = "batchmean", + ): + """ + Compute the MiniLLM loss for knowledge distillation using F.kl_div. See Eq. (1) of + https://huggingface.co/papers/2306.08543 for the definition. + + Args: + student_logits: + Tensor of shape (batch_size, sequence_length, vocab_size) + teacher_logits: + Tensor of shape (batch_size, sequence_length, vocab_size) + labels: + Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing + loss + beta: + Interpolation coefficient between 0 and 1 (default: 0.5) + temperature: + Softmax temperature (default: 1.0) + reduction: + Specifies the reduction to apply to the output (default: 'batchmean') + + Returns: + loss: Scalar tensor with the generalized JSD loss + """ + reg_loss = F.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ) # (batch_size, sequence_length) + + # Masking + if mask is not None: + reg_loss = reg_loss[mask] + + # Apply reduction + if reduction == "batchmean": + return reg_loss.sum() / mask.sum() if mask is not None else reg_loss.sum() / reg_loss.size(0) + elif reduction == "sum": + return reg_loss.sum() + elif reduction == "mean": + return reg_loss.mean() + else: + return reg_loss + + def _compute_advantage( + self, + student_log_probs_on_labels: torch.Tensor, + teacher_log_probs_on_labels: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> torch.Tensor: + r"""Compute the advantage for Reverse KL Divergence. + + Mostly following [this + implementation](https://github.com/microsoft/LMOps/blob/e210d2c026b9958617887762400778ace81172e6/minillm/minillm/losses.py#L37-L49). + + $$ \text{rewards}_t = \text{teacher\_log\_probs\_on\_labels}_t - \text{student\_log\_probs\_on\_labels}_t $$ + + If length normalization is enabled: + + $$ \text{lengths}_t = \sum_{i=t}^{T} \gamma^{i-t} $$ + + $$ \text{advantages}_t = \frac{\sum_{i=t}^{T} \gamma^{i-t} R_i}{\text{lengths}_t} $$ + + Otherwise: + + $$ \text{advantages}_t = \sum_{i=t}^{T} \gamma^{i-t} R_i $$ + + Args: + student_log_probs_on_labels: Log probabilities of the student model on the labels. + Shape: (batch_size, sequence_length) + teacher_log_probs_on_labels: Log probabilities of the teacher model on the labels. + Shape: (batch_size, sequence_length) + mask: Optional mask to apply to the log probabilities. Shape: (batch_size, sequence_length) + Returns: + advantage: Computed advantage. Shape: (batch_size, sequence_length) + """ + response_length = student_log_probs_on_labels.size(1) + if mask is None: + mask = torch.ones_like(student_log_probs_on_labels) + mask = mask.float() + student_log_probs_on_labels = student_log_probs_on_labels * mask + teacher_log_probs_on_labels = teacher_log_probs_on_labels * mask + + rewards = teacher_log_probs_on_labels - student_log_probs_on_labels # (batch_size, sequence_length) + + if self.gamma > 0.0: + gamma_pow = torch.pow(self.gamma, torch.arange(response_length, device=rewards.device)) + + advantages = rewards * gamma_pow + advantages = advantages.flip(1).cumsum(dim=1).flip(1) + + if self.length_normalization: + mask = torch.where(mask < 0.5, 1e-4, mask) + lengths = mask * gamma_pow + lengths = lengths.flip(1).cumsum(dim=1).flip(1) + advantages = advantages / lengths + else: + advantages = rewards + + return advantages + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + input_ids = torch.cat([inputs["prompt_ids"], inputs["completion_ids"]], dim=1) + attention_mask = torch.cat([inputs["prompt_mask"], inputs["completion_mask"]], dim=1) + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + + # Compute student output + student_outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) + + # Compute teacher output in eval mode + self.teacher_model.eval() + with torch.no_grad(): + teacher_outputs = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) + + # Slice the logits for the generated tokens using the inputs["prompts"] lengths + prompt_lengths = inputs["prompt_ids"].shape[1] + student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :] + teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :] + shifted_labels = input_ids[:, prompt_lengths:] + + # Apply temperature scaling + student_logits = student_logits / self.kd_temperature + teacher_logits = teacher_logits / self.kd_temperature + + # Compute log probabilities for student and probabilities for teacher + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + student_log_probs_on_labels = torch.gather( + student_log_probs, dim=-1, index=shifted_labels.unsqueeze(-1) + ).squeeze(-1) + teacher_log_probs_on_labels = torch.gather( + teacher_log_probs, dim=-1, index=shifted_labels.unsqueeze(-1) + ).squeeze(-1) + + mask = shifted_labels != -100 + + if self.rkl_advantage: + reverse_kl_advantage = self._compute_advantage( + student_log_probs_on_labels=student_log_probs_on_labels, + teacher_log_probs_on_labels=teacher_log_probs_on_labels, + mask=mask, + ) + + inputs["advantages"] = inputs["advantages"].unsqueeze(1) + reverse_kl_advantage + + # Compute GRPO loss on verifiable reward + loss = self._compute_loss(model, inputs) + + # Compute loss + if self.single_step_decomposition: + single_step_decomposition_loss = self._single_step_decomposition_loss( + student_log_probs=student_log_probs, + teacher_log_probs=teacher_log_probs, + mask=mask, + ) + + loss += single_step_decomposition_loss + + # Empty cache + empty_cache() + + # Return loss + return (loss, student_outputs) if return_outputs else loss diff --git a/ICL/RL/trl_source/trl/experimental/nash_md/__init__.py b/ICL/RL/trl_source/trl/experimental/nash_md/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b80c1515bed13913e9497e35e6f79552deb40dc --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/nash_md/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .nash_md_config import NashMDConfig +from .nash_md_trainer import NashMDTrainer + + +__all__ = ["NashMDConfig", "NashMDTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/nash_md/nash_md_config.py b/ICL/RL/trl_source/trl/experimental/nash_md/nash_md_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2d52d08ecbe7bde752f9b1ae3455c56bf31e445d --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/nash_md/nash_md_config.py @@ -0,0 +1,46 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..online_dpo import OnlineDPOConfig + + +@dataclass +class NashMDConfig(OnlineDPOConfig): + r""" + Configuration class for the [`experimental.nash_md.NashMDTrainer`]. + + Subclass of [`experimental.online_dpo.OnlineDPOConfig`] we can use all its arguments and add the following: + + Parameters: + mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): + Logit mixture coefficient for the model and reference model. If a list of floats is provided then the + mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the + epochs. + """ + + mixture_coef: list[float] = field( + default_factory=lambda: [0.5], + metadata={ + "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided " + "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the " + "rest of the epochs." + }, + ) + + def __post_init__(self): + super().__post_init__() + if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1: + self.mixture_coef = self.mixture_coef[0] diff --git a/ICL/RL/trl_source/trl/experimental/nash_md/nash_md_trainer.py b/ICL/RL/trl_source/trl/experimental/nash_md/nash_md_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b89021873c71b7b15f100dbe25aa176b73be79c7 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/nash_md/nash_md_trainer.py @@ -0,0 +1,555 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap +from collections.abc import Callable +from typing import Any + +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + GenerationMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import EvalPrediction +from transformers.training_args import OptimizerNames +from transformers.utils import is_peft_available + +from ...data_utils import is_conversational, maybe_apply_chat_template +from ...models.utils import unwrap_model_for_generation +from ...trainer.utils import empty_cache, selective_log_softmax +from ..judges import BasePairwiseJudge +from ..online_dpo import OnlineDPOTrainer +from ..utils import SIMPLE_CHAT_TEMPLATE, get_reward, truncate_right +from .nash_md_config import NashMDConfig + + +if is_peft_available(): + from peft import PeftModel + + +class GeometricMixtureWrapper(GenerationMixin): + """ + Geometric Mixture generation wrapper that samples from the logits of two model's geometric mixture. + + Args: + model ([`~transformers.PreTrainedModel`]): The model to be wrapped. + ref_model ([`~transformers.PreTrainedModel`]): The reference model. + generation_config ([`~transformers.GenerationConfig`]): The generation config. + mixture_coef (`float`, *optional* - default: 0.5): The mixture coefficient. + """ + + main_input_name = "input_ids" + _supports_cache_class = False + _supports_static_cache = False + _is_stateful = False + + def __init__(self, model, ref_model, generation_config, mixture_coef=0.5, device=None): + super().__init__() + + self.model = model + self.config = model.config + self.ref_model = ref_model + self.generation_config = generation_config + self.mixture_coef = mixture_coef + self.device = device + if hasattr(self.model, "_is_stateful"): + self._is_stateful = self.model._is_stateful + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.inference_mode() + def forward(self, *args, **kwargs): + model_outputs = self.model(*args, **kwargs) + model_logits = model_outputs.logits + ref_model_logits = self.ref_model(*args, **kwargs).logits + + model_outputs.logits = torch.nn.functional.log_softmax( + self.mixture_coef * ref_model_logits + (1 - self.mixture_coef) * model_logits, dim=-1 + ) + + return model_outputs + + def prepare_inputs_for_generation(self, *args, **kwargs): + # turn off cache in the generation config + kwargs["use_cache"] = False + model_inputs = self.model.prepare_inputs_for_generation(*args, **kwargs) + _ = self.ref_model.prepare_inputs_for_generation(*args, **kwargs) + + return model_inputs + + def _validate_model_class(self): + self.model._validate_model_class() + + def _validate_model_kwargs(self, model_kwargs): + return self.model._validate_model_kwargs(model_kwargs) + + +class NashMDTrainer(OnlineDPOTrainer): + """ + Trainer for the Nash-MD method. + + It is implemented as a subclass of [`experimental.online_dpo.OnlineDPOTrainer`]. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model ([`~transformers.PreTrainedModel`]): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation + and loss. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized. + reward_funcs ([`~transformers.PreTrainedModel`]): + The reward model to score completions with, preferably an + [`~transformers.AutoModelForSequenceClassification`]. + judge ([`experimental.judges.BasePairwiseJudge`]): + The judge to use for pairwise comparison of model completions. + args ([`experimental.nash_md.NashMDConfig`]): + The NashMD config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "nash-md"] + _name = "Nash-MD" + _paper = { + "title": "Nash Learning from Human Feedback", + "id": "2312.00886", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{munos2024nash, + title = {{Nash Learning from Human Feedback}}, + author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=Y5AmNYiyCQ} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module = None, + ref_model: PreTrainedModel | nn.Module = None, + reward_funcs: PreTrainedModel | nn.Module | None = None, + judge: BasePairwiseJudge | None = None, + args: NashMDConfig | None = None, + data_collator: Callable | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + ) -> None: + super().__init__( + model=model, + ref_model=ref_model, + reward_funcs=reward_funcs, + judge=judge, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=processing_class, + peft_config=peft_config, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self._mixture_coef = self.args.mixture_coef + + # Overwrite the stats dictionary to include NashMD specific statistics + self.stats = { + # Remove "non_score_reward", "rlhf_reward", "scores_margin" + # Add "mixture_coef" + "loss/kl": [], + "objective/entropy": [], + "loss/score": [], + "rewards/probabilities": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + "val/model_contain_eos_token": [], + "val/ref_contain_eos_token": [], + "beta": [], + "mixture_coef": [], + } + if self.reward_funcs is not None: + if len(self.reward_funcs) != 1: + raise ValueError("NashMDTrainer only supports one reward function/model.") + self.reward_funcs = self.reward_funcs[0] + self.stats["rewards/chosen"] = [] + self.stats["rewards/rejected"] = [] + + @property + def mixture_coef(self): + if isinstance(self._mixture_coef, list): + epoch = self.state.epoch + return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1] + else: + return self._mixture_coef + + def _generate_completions(self, model, prompts): + # Generate completions from the policy model. + with ( + unwrap_model_for_generation( + model, + self.accelerator, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_policy_for_gen_ctx + ): + model_output = unwrapped_policy_for_gen_ctx.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + # Get the DDP/FSDP unwrapped version of the main model. + # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used). + policy_model_for_gmw = self.accelerator.unwrap_model(model) + + # Determine the correct reference model for GeometricMixtureWrapper. + # This also needs to be DDP/FSDP unwrapped. + ref_model_for_gmw: torch.nn.Module + if self.ref_model is None: + # No explicit ref_model is provided. + # Use the base of the main `model` if it's a PEFT model. + # policy_model_for_gmw is already DDP-unwrapped. + if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel): + ref_model_for_gmw = policy_model_for_gmw.get_base_model() + else: + # Not a PEFT model (or PEFT not available), or already a base model. + # Use the DDP-unwrapped policy model itself as the reference. + ref_model_for_gmw = policy_model_for_gmw + else: + # An explicit ref_model is provided. Unwrap it for DDP/FSDP. + ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model) + + # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped. + with torch.no_grad(): # Ensure no_grad context for mixture model generation + mixture_model = GeometricMixtureWrapper( + model=policy_model_for_gmw, + ref_model=ref_model_for_gmw, + generation_config=self.generation_config, + mixture_coef=self.mixture_coef, + device=self.accelerator.device, + ) + + # TODO: use self._override_model_generation_config for both models? + mixture_output = mixture_model.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + return model_output, mixture_output + + def _process_completions(self, model_output, mixture_output, prompts): + context_length = prompts["input_ids"].shape[1] + + # Process model completions + model_completion_ids = model_output[:, context_length:] + model_completion_ids, model_completion_mask = truncate_right( + model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + model_data = { + "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), + "raw": prompts["raw"], + } + + # Process reference model completions + mixture_completion_ids = mixture_output[:, context_length:] + mixture_completion_ids, mixture_completion_mask = truncate_right( + mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + mixture_data = { + "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1), + "raw": prompts["raw"], + } + + return model_data, mixture_data + + def _compute_rewards(self, model_data, mixture_data, context_length): + with torch.no_grad(): + _, model_scores, _ = get_reward( + self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + _, mixture_scores, _ = get_reward( + self.reward_funcs, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + + # Apply EOS penalty if needed + if self.args.missing_eos_penalty is not None: + model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + model_scores[~model_contain_eos] -= self.args.missing_eos_penalty + mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty + + return model_scores, mixture_scores + + def _compute_judge(self, model_data, mixture_data, context_length): + prompts = model_data["raw"] + model_data_completions = self.processing_class.batch_decode( + model_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + model_data_completions = [completion.strip() for completion in model_data_completions] + + mixture_data_completions = self.processing_class.batch_decode( + mixture_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + mixture_data_completions = [completion.strip() for completion in mixture_data_completions] + if is_conversational({"prompt": prompts[0]}): + model_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in model_data_completions + ] + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=message) for message in prompts] + model_data_completions = [template.render(messages=completion) for completion in model_data_completions] + + mixture_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in mixture_data_completions + ] + mixture_data_completions = [ + template.render(messages=completion) for completion in mixture_data_completions + ] + + probability = self.judge.judge( + prompts, + list(zip(model_data_completions, mixture_data_completions, strict=True)), + return_scores=True, + ) + return torch.tensor(probability, device=model_data["input_ids"].device) + + def _compute_logprobs(self, model, model_data, context_length): + def compute_logprobs_for_data(m, data): + output = m(data["input_ids"], attention_mask=data["attention_mask"]) + logits = output.logits[:, context_length - 1 : -1] + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) + return token_logprobs + + # Compute logprobs for model completions under the model + model_logprobs_model_data = compute_logprobs_for_data(model, model_data) + + # Compute logprobs of model completions under the reference model + with torch.no_grad(): + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + + # Mask padding tokens + model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 + model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + + return (model_logprobs_model_data, ref_logprobs_model_data) + + def _compute_losses( + self, + model_logprobs_model_data, + ref_logprobs_model_data, + probability, + ): + # reinforce score where 0.5 is a control variate + score = (probability - 0.5) * model_logprobs_model_data.sum(1) + + # kl divergence via reinforce + with torch.no_grad(): + log_ratio = model_logprobs_model_data - ref_logprobs_model_data + kl_div_log = log_ratio.sum(1) + kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1) + + # final loss + loss = self.beta * kl_div_loss - score + + return loss.mean(), score, kl_div_log + + def _log_statistics( + self, + model_data, + mixture_data, + model_logprobs_model_data, + ref_logprobs_model_data, + probability, + score, + kl_div, + context_length, + model_scores=None, + mixture_scores=None, + ): + # Helper function to gather and compute mean + def gather_mean(tensor): + return self.accelerator.gather_for_metrics(tensor).mean().item() + + # Log score + self.stats["loss/score"].append(gather_mean(score)) + # Log KL divergence + self.stats["loss/kl"].append(gather_mean(kl_div)) + + # Log logprobs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum)) + self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum)) + + # Log rewards + if self.reward_funcs is not None: + self.stats["rewards/chosen"].append(gather_mean(model_scores)) + self.stats["rewards/rejected"].append(gather_mean(mixture_scores)) + + # Log probabilities + self.stats["rewards/probabilities"].append(gather_mean(probability)) + + # Calculate entropy for model data + entropy_model_data = -model_logprobs_model_data.sum(1) + self.stats["objective/entropy"].append(gather_mean(entropy_model_data)) + + # Calculate margins + margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum + self.stats["rewards/margins"].append(gather_mean(margin)) + + # Calculate accuracy + accuracy = (margin > 0).float() + self.stats["rewards/accuracies"].append(gather_mean(accuracy)) + + # Log EOS token statistics + model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) + self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float())) + + # Log beta and mixture coef + self.stats["beta"].append(self.beta) + self.stats["mixture_coef"].append(self.mixture_coef) + + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + model.train() + + # Apply chat template and tokenize the input + batch_size = len(next(iter(inputs.values()))) + prompts = inputs["prompt"] + inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] + inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] + inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] + inputs = self.data_collator(inputs) + + # need the prompt_ only + inputs = self._prepare_inputs(inputs) + context_length = inputs["prompt_input_ids"].shape[1] + prompts = { + "input_ids": inputs["prompt_input_ids"], + "attention_mask": inputs["prompt_attention_mask"], + "raw": prompts, + } + del inputs + + # Sample completions from both the model and the reference model + model_output, mixture_output = self._generate_completions(model, prompts) + + # Process model completions + model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts) + + # Compute rewards + if self.reward_funcs is not None: + model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length) + # probability of the model data vs the mixture data + probability = F.sigmoid(model_scores - mixture_scores) + else: + model_scores, mixture_scores = None, None + probability = self._compute_judge(model_data, mixture_data, context_length) + + # Compute logprobs + model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length) + + # Compute loss + loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability) + + # Log everything + self._log_statistics( + model_data, + mixture_data, + model_logprobs_model_data.detach(), + ref_logprobs_model_data, + probability, + score.detach(), + kl_div.detach(), + context_length, + model_scores, + mixture_scores, + ) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps diff --git a/ICL/RL/trl_source/trl/experimental/online_dpo/__init__.py b/ICL/RL/trl_source/trl/experimental/online_dpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a109cabc240f871544e5ccc56ee5616794c0ba9 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/online_dpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .online_dpo_config import OnlineDPOConfig +from .online_dpo_trainer import OnlineDPOTrainer + + +__all__ = ["OnlineDPOConfig", "OnlineDPOTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/online_dpo/online_dpo_config.py b/ICL/RL/trl_source/trl/experimental/online_dpo/online_dpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..31418b2bdf082175f92d2db7cca69bc9265bddd5 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/online_dpo/online_dpo_config.py @@ -0,0 +1,416 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class OnlineDPOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.online_dpo.OnlineDPOTrainer`]. + + This class includes only the parameters that are specific to Online DPO training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + reward_model_path (`str`, *optional*): + Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both. + judge (`str`, *optional*): + Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both. + max_new_tokens (`int`, *optional*, defaults to `64`): + Maximum number of tokens to generate per completion. + max_length (`int`, *optional*, defaults to `256`): + Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the + sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as + possible. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + missing_eos_penalty (`float`, *optional*): + Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to + generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive + value. This parameter only works when using `reward_funcs` and not when using `judge`. + beta (`float` or `list[float]`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), ฮฒ is the regularization parameter denoted by ฯ„ in + the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the ฮฒ is + selected for each new epoch and the last ฮฒ is used for the rest of the epochs. + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + + > Parameters that control generation + + top_p (`float`, *optional*, defaults to `1.0`): + Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to + `1.0` to consider all tokens. + top_k (`int`, *optional*, defaults to `0`): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, top-k-filtering is + disabled and all tokens are considered. + min_p (`float`, *optional*): + Minimum token probability, which will be scaled by the probability of the most likely token. It must be a + value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. + Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat + tokens. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` + paged implementation will be used for generation instead of the default padded implementation. This + parameter is only effective when `use_vllm` is set to `False`. + cache_implementation (`str`, *optional*): + Implementation of the cache method for faster generation when `use_vllm` is set to `False`. + generation_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or + `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the + generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict + with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. + + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation + instead of the default model.generate(). Requires `vllm` to be installed. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use + the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model + implementation. + vllm_mode (`str`, *optional*, defaults to `"server"`): + Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. + + > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + + vllm_server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port + is occupied, there is no need to change it. + + > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.55`): + Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory usage low, but + waking the engine adds hostโ€“device transfer latency. + + > Other parameters + + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible + with vLLM generation. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + """ + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=5e-7, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + reward_model_path: str | None = field( + default=None, + metadata={ + "help": "Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both." + }, + ) + judge: str | None = field( + default=None, + metadata={ + "help": "Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both." + }, + ) + max_new_tokens: int = field( + default=64, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + max_length: int = field( + default=512, + metadata={ + "help": "Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If " + "the sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the " + "completion as possible." + }, + ) + temperature: float = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: int = field( + default=0, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + min_p: float | None = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) + generation_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or " + "`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the " + "generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that " + "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them." + }, + ) + use_transformers_paged: bool = field( + default=False, + metadata={ + "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the " + "`transformers` paged implementation will be used for generation instead of the default padded " + "implementation. This parameter is only effective when `use_vllm` is set to `False`." + }, + ) + cache_implementation: str | None = field( + default=None, + metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, + ) + missing_eos_penalty: float | None = field( + default=None, + metadata={ + "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to " + "encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be " + "a positive value." + }, + ) + beta: list[float] = field( + default_factory=lambda: [0.1], + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from " + "the reference model. For the IPO loss (`loss_type='ipo'`), ฮฒ is the regularization parameter denoted by " + "ฯ„ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the ฮฒ " + "is selected for each new epoch and the last ฮฒ is used for the rest of the epochs." + }, + ) + loss_type: str = field( + default="sigmoid", + metadata={ + "help": "Type of loss to use.", + "choices": ["sigmoid", "ipo"], + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed " + "(`pip install trl[vllm]`)." + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={ + "help": "Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: " + "Use the `transformers` backend for model implementation. `vllm`: Use the `vllm` library for " + "model implementation." + }, + ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."}, + ) + vllm_gpu_memory_utilization: float | None = field( + default=0.55, + metadata={ + "help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.", + }, + ) + vllm_mode: str = field( + default="server", + metadata={ + "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `'server'` or " + "`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure " + "a TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same " + "process and share the training GPUs. This avoids the need for a separate server but may cause resource " + "contention with training.", + }, + ) + vllm_server_base_url: str | None = field( + default=None, + metadata={ + "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " + "and `vllm_server_port` are ignored.", + }, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_port: int = field( + default=8000, + metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={ + "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " + "after the timeout, a `ConnectionError` is raised.", + }, + ) + vllm_group_port: int = field( + default=51216, + metadata={ + "help": "Port number for the weight update group. This is used to communicate with the vLLM server. " + "Unless the port is occupied, there is no need to change it.", + }, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={ + "help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_tensor_parallel_size` flag.", + }, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={ + "help": "Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory " + "usage low, but waking the engine adds hostโ€“device transfer latency." + }, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={ + "help": "Weights for combining multiple reward functions. Must match the number of reward functions. " + "If None, all reward functions are equally weighted." + }, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() + + if hasattr(self.beta, "__len__") and len(self.beta) == 1: + self.beta = self.beta[0] + + if self.max_new_tokens >= self.max_length: + warnings.warn( + f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). " + "This will cause prompts to be truncated or completely removed in the forward pass. " + "To preserve prompts, ensure e.g. `max_length > max_new_tokens + 512`. ", + stacklevel=3, + ) diff --git a/ICL/RL/trl_source/trl/experimental/online_dpo/online_dpo_trainer.py b/ICL/RL/trl_source/trl/experimental/online_dpo/online_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..af22577cc116eb767985189763776b21f10f0562 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/online_dpo/online_dpo_trainer.py @@ -0,0 +1,1499 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import textwrap +from collections.abc import Callable +from contextlib import nullcontext +from functools import wraps +from pathlib import Path +from typing import Any + +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +import transformers +from accelerate import logging +from accelerate.utils import broadcast_object_list, gather_object, is_peft_model +from datasets import Dataset +from packaging.version import Version +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, IterableDataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollator, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + is_bitsandbytes_available, +) +from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES +from transformers.trainer_utils import EvalPrediction, seed_worker +from transformers.training_args import OptimizerNames +from transformers.utils import is_flash_attn_2_available, is_peft_available, is_sagemaker_mp_enabled + +from ...data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ...extras.profiling import profiling_context +from ...generation.vllm_client import VLLMClient +from ...import_utils import is_vllm_available +from ...models.utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import disable_dropout_in_model, empty_cache, ensure_master_addr_port, get_config_model_id, pad +from ..judges import BasePairwiseJudge +from ..utils import SIMPLE_CHAT_TEMPLATE, DPODataCollatorWithPadding, prepare_peft_model, truncate_right +from .online_dpo_config import OnlineDPOConfig + + +if is_peft_available(): + from peft import PeftConfig, PeftModel + + +if is_sagemaker_mp_enabled(): + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = Version(SMP_VERSION) >= Version("1.10") + +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +if Version(transformers.__version__) >= Version("5.2.0.dev0"): + from transformers.trainer_pt_utils import nested_gather + + +if is_vllm_available(): + from vllm import LLM, SamplingParams + from vllm.sampling_params import StructuredOutputsParams + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +logger = logging.get_logger(__name__) + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] + + +class OnlineDPOTrainer(BaseTrainer): + r""" + Initialize OnlineDPOTrainer. + + Args: + model (`str | nn.Module | PreTrainedModel`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in + `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + ref_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `None`): + The reference model to use for training. If None is specified, the reference model will be created from the + model. + judge ([`experimental.judges.BasePairwiseJudge`]): + The judge to use for pairwise comparison of model completions. + reward_funcs (`RewardFunc | list[RewardFunc]`, *optional*): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function: Can be a string (path to model), a [`~transformers.PreTrainedModel`], or a + custom callable function. + - A list of reward functions: Must all be of compatible types. + + Note: Only one of `judge`, or `reward_funcs` should be provided. + args ([`experimental.online_dpo.OnlineDPOConfig`]): + The online DPO config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + + If set to `None`, the tokenizer for each model-based reward function is automatically loaded using + [`~transformers.AutoTokenizer.from_pretrained`]. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "online-dpo"] + _name = "Online DPO" + _paper = { + "title": "Direct Language Model Alignment from Online AI Feedback", + "id": "2402.04792", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{guo2024direct, + title = {{Direct Language Model Alignment from Online AI Feedback}}, + author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel}, + year = 2024, + eprint = {arXiv:2402.04792} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str, + ref_model: PreTrainedModel | nn.Module | None = None, + reward_funcs: RewardFunc | list[RewardFunc] | None = None, + judge: BasePairwiseJudge | None = None, + args: OnlineDPOConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + peft_config: "PeftConfig | None" = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + ) -> None: + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, either omit the `ref_model` argument or pass `None`." + ) + + self.ref_model = ref_model + + # Validate reward configuration - must have exactly one of: judge, or reward_funcs + reward_configs = sum(x is not None for x in [judge, reward_funcs]) + if reward_configs == 0: + raise ValueError("One of `judge` or `reward_funcs` must be provided.") + elif reward_configs > 1: + if judge is not None: + logger.warning( + "Both `judge` and `reward_funcs` are provided. Using `judge` and ignoring `reward_funcs`.", + UserWarning, + ) + reward_funcs = None + self.judge = judge + + # Handle reward_funcs + if reward_funcs is not None: + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + + # Process reward functions (convert strings to models, collect names) + model_init_kwargs = args.model_init_kwargs or {} + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + # Load model from string path + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + if isinstance(reward_funcs[i], nn.Module): + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + # Handle reward processing classes for reward_funcs + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + else: + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError( + "The number of reward processing classes must match the number of reward functions." + ) + + self.reward_processing_classes = [] + for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs, strict=True): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class_i is None: + reward_processing_class_i = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) + if reward_processing_class_i.pad_token_id is None: + reward_processing_class_i.pad_token = reward_processing_class_i.eos_token + # Set pad token ID on reward model config + reward_func.config.pad_token_id = reward_processing_class_i.pad_token_id + self.reward_processing_classes.append(reward_processing_class_i) + else: + self.reward_funcs = None + self.reward_func_names = [] + self.reward_processing_classes = [] + + # Handle reward_weights + if reward_funcs is not None: + if args.reward_weights is not None: + if len(args.reward_weights) != len(self.reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(self.reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(self.reward_funcs), dtype=torch.float32) + else: + self.reward_weights = None + + if args.missing_eos_penalty is not None and reward_funcs is None and judge is None: + raise ValueError("`missing_eos_penalty` is only supported when `reward_funcs` is provided.") + + if args is None: + raise ValueError("`args` must be provided.") + + # Check that the processing_class is provided + if processing_class is None: + raise ValueError("`processing_class` must be provided.") + + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + + # Handle dtype in model_init_kwargs + dtype = model_init_kwargs.get("dtype", "auto") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass + elif isinstance(dtype, str): + dtype = getattr(torch, dtype) + model_init_kwargs["dtype"] = dtype + else: + raise ValueError( + "Invalid `dtype` passed to `OnlineDPOConfig`. Expected either 'auto' or a string " + f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") + + model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + raise ValueError( + "You passed `model_init_kwargs` to the `OnlineDPOConfig`, but your model is already instantiated. " + "This argument can only be used when the `model` argument is a string." + ) + self.is_encoder_decoder = model.config.is_encoder_decoder + self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys() + + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + + # Enable gradient checkpointing if requested + if args.gradient_checkpointing: + model = self._enable_gradient_checkpointing(model, args) + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Handle the ref_model + # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to + # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create + # the ref model from the model by copying it and disable the gradients and set it in evaluation mode. + if ref_model is None: # No ref model provided, the most common case + if peft_config is None: + self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode + else: + self.ref_model = None # we don't need a ref model here, we can just disable the adapter. + else: # rare case, the user provided a ref model + self.ref_model = ref_model + self.ref_model.eval() + + # Disable the gradient and set the reward model in eval mode + if reward_funcs is not None: + for reward_func in reward_funcs: + if isinstance(reward_func, PreTrainedModel): + reward_func.eval() + + self.max_length = args.max_length + + self.stats = { + "objective/kl": [], + "objective/entropy": [], + "objective/non_score_reward": [], + "rewards/chosen": [], + "rewards/rejected": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + "val/contain_eos_token": [], + "beta": [], + } + if self.reward_funcs is not None: + self.stats["objective/rlhf_reward"] = [] + self.stats["objective/scores_margin"] = [] + self.stats["objective/scores"] = [] + + # Store generation parameters for later use + self.use_vllm = args.use_vllm + self.num_generations = 2 # Generate 2 completions per prompt for Online DPO + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.vllm_mode = args.vllm_mode if args.use_vllm else None + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size + self.vllm_model_impl = args.vllm_model_impl + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + # Vision tokens for VLM support + self.image_token_id = getattr(processing_class, "image_token_id", None) + self.vision_start_token_id = getattr(processing_class, "vision_start_token_id", None) + self.vision_end_token_id = getattr(processing_class, "vision_end_token_id", None) + # Get the image token string for token collapsing + self.image_token = None + if self.image_token_id is not None: + self.image_token = tokenizer.decode([self.image_token_id]) + + # Define the collator if not provided + if data_collator is None: + data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id) + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + self._beta = args.beta + + # Set up generation configuration and vLLM after super().__init__ + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install trl[vllm]` to use it." + ) + + if self.vllm_mode == "server": + if self.accelerator.is_main_process: + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + self.vllm_client = VLLMClient( + base_url=base_url, group_port=args.vllm_group_port, connection_timeout=args.vllm_server_timeout + ) + + # Determine device type (supports cuda, xpu, etc.) + accelerator_type = torch.accelerator.current_accelerator().type + current_device = getattr(torch, accelerator_type).current_device() + self.vllm_client.init_communicator(device=current_device) + else: + self.vllm_client = None + elif self.vllm_mode == "colocate": + # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instantiation. + # A larger cache size improves speed, so we would expect gpu_memory_utilization=1. + # However, at this stage, the optimizer's weights are not yet loaded onto the GPU; they will be loaded + # after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough + # space for them. + # Configure vLLM parameters + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") + vllm_kwargs = { + "model": model.name_or_path, + "tensor_parallel_size": self.vllm_tensor_parallel_size, + "gpu_memory_utilization": self.vllm_gpu_memory_utilization, + "model_impl": self.vllm_model_impl, + "max_num_seqs": self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size, + "max_model_len": args.max_length + args.max_new_tokens, # max_length includes prompt + completion + "distributed_executor_backend": "external_launcher", + # Feed identical seed for tp groups to ensure sampling results are the same across workers + "seed": self.accelerator.process_index // self.vllm_tensor_parallel_size, + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) + "max_num_batched_tokens": 4096, + "enable_sleep_mode": self.args.vllm_enable_sleep_mode, + "quantization": vllm_quantization, + } + + # vLLM requires the environment variables to be set for distributed training. + os.environ["RANK"] = str(self.accelerator.process_index) + os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) + os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() + + self.llm = LLM(**vllm_kwargs) + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=2) + else: + raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.vllm_mode}'.") + # vLLM specific sampling arguments + self.structured_outputs_regex = args.vllm_structured_outputs_regex + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # Set up vLLM generation config + generation_params = { + "n": 2, # 2 generations per prompt for Online DPO + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": args.max_new_tokens, + "detokenize": False, # to avoid vllm to decode (we don't need it) + } + if args.generation_kwargs is not None: + generation_params.update(args.generation_kwargs) + if self.structured_outputs_regex: + generation_params["structured_outputs"] = StructuredOutputsParams(regex=self.structured_outputs_regex) + self.generation_config = SamplingParams(**generation_params) + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + # Set up transformers generation config + generation_kwargs = { + "max_new_tokens": args.max_new_tokens, + "do_sample": True, + "pad_token_id": self.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": self.eos_token_id, + "temperature": self.temperature, + "top_k": self.top_k, + "top_p": self.top_p, + "repetition_penalty": self.repetition_penalty, + "use_cache": True if not self.args.gradient_checkpointing else False, + } + # Add min_p if supported + if self.min_p is not None: + generation_kwargs["min_p"] = self.min_p + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + # Remove None values + generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} + self.generation_config = GenerationConfig(**generation_kwargs) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + if self.reward_funcs is not None: + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + @property + def beta(self): + if isinstance(self._beta, list): + epoch = self.state.epoch + return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1] + else: + return self._beta + + @staticmethod + def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]: + """Tokenize a single row from a DPO specific dataset.""" + if not is_encoder_decoder: + batch = tokenizer(feature["prompt"], add_special_tokens=False) + # Add BOS token to head of prompt. Avoid adding if it's already there + if tokenizer.bos_token_id is not None: + prompt_len_input_ids = len(batch["input_ids"]) + if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]: + batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"] + batch["attention_mask"] = [1] + batch["attention_mask"] + else: + batch = tokenizer(feature["prompt"], add_special_tokens=True) + batch = {f"prompt_{key}": value for key, value in batch.items()} + return batch + + # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns". + @wraps(Trainer.get_train_dataloader) + def get_train_dataloader(self) -> DataLoader: + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns". + @wraps(Trainer.get_eval_dataloader) + def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader: + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) + data_collator = self.data_collator + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = eval_dataloader + else: + self._eval_dataloaders = {dataloader_key: eval_dataloader} + + return self.accelerator.prepare(eval_dataloader) + + def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPOConfig) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + # Ensure use_cache is disabled + model.config.use_cache = False + + # Enable gradient checkpointing on the base model for PEFT + if is_peft_model(model): + model.base_model.gradient_checkpointing_enable() + # Enable gradient checkpointing for non-PEFT models + else: + model.gradient_checkpointing_enable() + + model.enable_input_require_grads() + return model + + def _generate_vllm(self, prompts, images=None): + eos_token_id = self.eos_token_id + pad_token_id = self.pad_token_id + + # Generate completion_ids and prompt_ids based on mode + if self.vllm_mode == "server": + completion_ids, prompt_ids = self._generate_vllm_server(prompts, images) + elif self.vllm_mode == "colocate": + completion_ids, prompt_ids = self._generate_vllm_colocate(prompts, images) + + # Shared padding, masking, and tensor conversion logic + max_prompt_length = max(len(ids) for ids in prompt_ids) + prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids] + prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids] + max_tokens = self.generation_config.max_tokens + completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids] + completion_ids = [ + ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids + for ids in completion_ids + ] + completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids] + + # Convert to tensors + prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device) + prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device) + completion_ids = torch.tensor(completion_ids, device=self.accelerator.device) + completion_mask = torch.tensor(completion_mask, device=self.accelerator.device) + + return prompt_ids, prompt_mask, completion_ids, completion_mask + + def _generate_vllm_server(self, prompts, images=None): + """Generate completions using vLLM server mode""" + has_images = images is not None + + # Update vLLM server weights if needed + if hasattr(self, "_last_loaded_step") and self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + elif not hasattr(self, "_last_loaded_step"): + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Apply chat template if conversational + if is_conversational({"prompt": prompts[0]}): + prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts] + else: + prompts_text = prompts + # Gather all prompts to main process + all_prompts = gather_object(prompts_text) + if has_images: + all_images = gather_object(images) + + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts[:: self.num_generations] + if has_images: + ordered_set_of_images = all_images[:: self.num_generations] + else: + ordered_set_of_images = None + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.generation_config.max_tokens, + structured_outputs_regex=self.structured_outputs_regex + if hasattr(self, "structured_outputs_regex") + else None, + generation_kwargs=self.args.generation_kwargs, + )["completion_ids"] + # Flatten: each prompt generates 2 completions + completion_ids = [[comp_id] for prompt_completions in completion_ids for comp_id in prompt_completions] + else: + completion_ids = [None] * (len(all_prompts) * 2) + + # Broadcast completions to all processes + completion_ids = broadcast_object_list(completion_ids, from_process=0) + + # Each process takes its slice + process_slice = slice( + self.accelerator.process_index * len(prompts) * 2, + (self.accelerator.process_index + 1) * len(prompts) * 2, + ) + completion_ids = completion_ids[process_slice] + + # Create prompt_ids by tokenizing locally + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + ) + prompt_ids = [] + for prompt_tokens in prompt_inputs["input_ids"]: + prompt_ids.extend([prompt_tokens.tolist(), prompt_tokens.tolist()]) # 2 copies for 2 completions + return completion_ids, prompt_ids + + def _generate_vllm_colocate(self, prompts, images=None): + """Generate completions using vLLM colocate mode""" + if self.args.vllm_enable_sleep_mode: + # wake up colocated vLLM instances if needed + torch.cuda.empty_cache() # required to avoid OOM in some cases + self.llm.wake_up(tags=["weights"]) + + # Update model weights if needed - only after gradient accumulation completes + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Apply chat template if conversational + if is_conversational({"prompt": prompts[0]}): + prompts_text = [apply_chat_template({"prompt": p}, self.processing_class)["prompt"] for p in prompts] + else: + prompts_text = prompts + + # Prepare vLLM inputs with images if available + if images is not None: + vllm_inputs = [] + for prompt, image in zip(prompts_text, images, strict=True): + if image is not None: + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) + else: + vllm_inputs.append(prompt) + else: + vllm_inputs = prompts_text + + if self.args.vllm_enable_sleep_mode: + self.llm.wake_up(tags=["kv_cache"]) + + outputs = self.llm.generate(vllm_inputs, self.generation_config, use_tqdm=False) + + completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs] + prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs] + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=2) + + return completion_ids, prompt_ids + + def _sync_fsdp2_params_to_vllm(self, module: nn.Module): + # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion + for name, param in module.state_dict().items(): + # When using PEFT, we need to recover the original parameter name + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + # Skip PEFT layers: they donโ€™t exist in vLLM, and they are merged already. + if is_peft_model(module) and module.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if param.is_cpu: + param = param.to(torch.device("cuda")) + param = param.full_tensor() + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param)]) + + def _move_model_to_vllm(self): + # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + + if is_peft_model(self.model): + # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as + # merging adapters in a sharded manner is not supported. + # TODO: does this work with FSDP? + with gather_if_zero3(list(self.model.parameters())): + self.model.merge_adapter() + + # Update vLLM weights while parameters are gathered + if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext + # Update vLLM weights while parameters are gathered + # For PEFT with FSDP we need to use the memory efficient post-order traversal + fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm( + self.model + ) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(self.model) + else: + # DeepSpeed ZeRO-3 with PEFT + for name, param in self.model.named_parameters(): + # When using PEFT, we need to recover the original parameter name + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + # Skip PEFT layers: they donโ€™t exist in vLLM, and they are merged already. + if self.model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + self.model.unmerge_adapter() + # Parameters will automatically be repartitioned when exiting the context + else: + # For non-PEFT models, simply gather (if needed) and update each parameter individually. + if self.is_fsdp_enabled: + fsdp_plugin = getattr(self.accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(self.model) + else: + for name, param in self.model.named_parameters(): + name = self._fix_param_name_to_vllm(name) + with gather_if_zero3([param]): + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == "colocate": + self.llm.reset_prefix_cache() + + def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): + """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" + # For FSDP1, we need to recurse into children and also use summon_full_params + if visited is None: + visited = set() + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + self._sync_fsdp1_params_to_vllm( + child_module, prefix=child_prefix, visited=visited + ) # recurse into the child + + if isinstance(module, FSDP): + with FSDP.summon_full_params(module, recurse=False, writeback=False): + for param_name, param in module.named_parameters(): + full_name = f"{prefix}.{param_name}" if prefix else param_name + full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) + + if full_name in visited: + continue # skip FSDP subtrees already traversed + visited.add(full_name) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(full_name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(full_name, param.data)]) + + def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None): + """Clean parameter names for vLLM compatibility""" + extra_prefixes = extra_prefixes or [] + prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes + for prefix in prefixes: + name = name.replace(prefix, "") + return name + + def process_vision_row( + self, features: dict[str, list | torch.Tensor], processing_class=None + ) -> dict[str, list[int]]: + """ + Process a vision row for VLM models (adapted from DPO trainer) + """ + processor = processing_class or self.processing_class + processed_features = processor(images=[features["image"]], text=features["prompt"], add_special_tokens=False) + + prompt_input_ids = processed_features["input_ids"][0] + + # Create the output dict with required fields + output = { + "prompt_input_ids": prompt_input_ids, + "prompt_attention_mask": processed_features["attention_mask"][0], + } + + # Add vision-specific fields + if "pixel_values" in processed_features: + output["pixel_values"] = processed_features["pixel_values"][0] + if "pixel_attention_mask" in processed_features: + output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] + if "image_sizes" in processed_features: + output["image_sizes"] = processed_features["image_sizes"][0] + + return output + + def _generate(self, model, prompts, images=None): + """Generate completions using the model""" + device = next(model.parameters()).device + eos_token_id = self.eos_token_id + pad_token_id = self.pad_token_id + + # Apply chat template and tokenize the input + inputs = [{"prompt": prompt} for prompt in prompts] + + # Add images if provided (VLM support) + if images is not None: + for i, image in enumerate(images): + inputs[i]["image"] = image + + # Apply chat template to get text prompts + prompts_text = [maybe_apply_chat_template(x, self.processing_class)["prompt"] for x in inputs] + + # Handle image token collapsing/removal + # The chat template sometimes inserts a single image token into the prompt text. However, when this text is + # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the + # image size. We need to handle this properly. + if self.image_token is not None and images is not None: + escaped_img_token = re.escape(self.image_token) + # Search for the image token in the chat template + if hasattr(self.processing_class, "chat_template") and self.processing_class.chat_template: + if re.search(escaped_img_token, self.processing_class.chat_template): + # Collapse repeated image tokens back into a single token + prompts_text = [ + re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text + ] + else: + # If the chat template doesn't use the image token, remove all instances + if self.vision_end_token_id is not None: + escaped_eoi_token = re.escape( + self.processing_class.tokenizer.decode([self.vision_end_token_id]) + ) + prompts_text = [ + re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text + ] + else: + # If vision_end_token_id is None, just remove the image tokens + prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] + + # Prepare kwargs for processing class + kwargs = {} + if images is not None: + kwargs = {"images": [[img] for img in images]} + + # Process inputs using the processing class (handles both VLM and LLM) + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + **kwargs, + ) + + prompt_inputs = {k: v.to(device) for k, v in prompt_inputs.items()} + # Convert vision inputs to model's dtype for proper computation + if "pixel_values" in prompt_inputs: + # Handle DataParallel wrapped models + model_dtype = getattr(model, "dtype", None) + if model_dtype is None and hasattr(model, "module"): + model_dtype = model.module.dtype + if model_dtype is not None: + prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].to(model_dtype) + + # Sample 2 completions per prompt of size `max_new_tokens` from the model + prompt_ids = prompt_inputs["input_ids"].repeat(2, 1) + prompt_mask = prompt_inputs["attention_mask"].repeat(2, 1) + + # Prepare vision inputs if available + vision_generation_kwargs = {} + if self.is_vision_model and images is not None: + if "pixel_values" in prompt_inputs: + vision_generation_kwargs["pixel_values"] = prompt_inputs["pixel_values"].repeat(2, 1, 1, 1) + if "pixel_attention_mask" in prompt_inputs: + vision_generation_kwargs["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"].repeat(2, 1) + if "image_sizes" in prompt_inputs: + vision_generation_kwargs["image_sizes"] = prompt_inputs["image_sizes"].repeat(2, 1) + if "image_grid_thw" in prompt_inputs: + vision_generation_kwargs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(2, 1) + + if self.use_transformers_paged: + previous_attn = self.model_wrapped.config._attn_implementation + + if Version(transformers.__version__).release >= Version("5.0.0").release: + new_attn = "paged|flash_attention_2" if is_flash_attn_2_available() else "paged|sdpa" + else: + new_attn = "paged_attention" if is_flash_attn_2_available() else "sdpa_paged" + self.model_wrapped.config._attn_implementation = new_attn + with ( + profiling_context(self, "transformers.generate_batch"), + unwrap_model_for_generation( + model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Cast to the appropriate dtype based on training configuration + if self.args.bf16: + unwrapped_model.to(torch.bfloat16) + elif self.args.fp16: + unwrapped_model.to(torch.float16) + with torch.inference_mode(): + all_outputs = unwrapped_model.generate_batch( + prompt_ids.tolist(), + generation_config=self.generation_config, + progress_bar=False, + ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode + completion_ids = [output.generated_tokens for output in all_outputs.values()] + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + # Restore the original attention implementation, training mode + self.model_wrapped.config._attn_implementation = previous_attn + + # Extract completion_ids and create completion_mask + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id) + + return prompt_ids, prompt_mask, completion_ids, completion_mask + else: + # Regular generation path + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Setup cache implementation if specified + if self.args.cache_implementation is not None: + unwrapped_model.generation_config.cache_implementation = self.args.cache_implementation + + # Standard generation + output = unwrapped_model.generate( + input_ids=prompt_ids, + attention_mask=prompt_mask, + generation_config=self.generation_config, + **vision_generation_kwargs, + ) + + completion_ids = output[:, prompt_ids.size(1) :] + completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id) + + return prompt_ids, prompt_mask, completion_ids, completion_mask + + def _calculate_rewards_from_functions(self, prompts, completions, completion_ids_list, **reward_kwargs): + """ + Calculate rewards using reward functions + """ + device = self.accelerator.device + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + + # Add trainer state to reward kwargs for dynamic reward shaping + reward_kwargs["trainer_state"] = self.state + + for i, (reward_func, reward_processing_class) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, strict=True) + ): + if isinstance(reward_func, nn.Module): # Model-based reward function + # Handle conversational vs text input + if is_conversational({"prompt": prompts[0]}): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + + # Tokenize and get reward scores + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = {k: v.to(device) for k, v in reward_inputs.items()} + + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + else: + # Custom reward function + output_reward_func = reward_func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # Weight and sum across all reward functions + if self.reward_weights is not None: + total_rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + else: + total_rewards = rewards_per_func.nansum(dim=1) + + return total_rewards + + def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs=None): + # Get the number of tokens to truncate from prompt + num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0) + + # Truncate left to avoid oom + prompt_ids = prompt_ids[:, num_tokens_to_truncate:] + prompt_mask = prompt_mask[:, num_tokens_to_truncate:] + + # Concat the prompt and completion + prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1) + prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1) + + # Prepare model kwargs with vision inputs if available + model_kwargs = {"attention_mask": prompt_completion_mask} + if vision_inputs is not None: + if "pixel_values" in vision_inputs: + model_kwargs["pixel_values"] = vision_inputs["pixel_values"] + if "pixel_attention_mask" in vision_inputs: + model_kwargs["pixel_attention_mask"] = vision_inputs["pixel_attention_mask"] + if "image_sizes" in vision_inputs: + model_kwargs["image_sizes"] = vision_inputs["image_sizes"] + if "image_grid_thw" in vision_inputs: + model_kwargs["image_grid_thw"] = vision_inputs["image_grid_thw"] + + # Get the logprobs of the completions from the model + output = model(prompt_completion_ids, **model_kwargs) + + # There is 1 offset, because the model predicts the next token + prompt_len = prompt_ids.size(1) + start_idx = prompt_len - 1 if prompt_len > 0 else 0 + # Only slice off the last logit when we have a prompt, otherwise we need all logits + end_idx = -1 if prompt_len > 0 else None + logits = output.logits[:, start_idx:end_idx] + + # Take the completion tokens logprob + logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1) + return logprobs + + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + model.train() + + prompts = inputs["prompt"] + batch_size = len(prompts) + + # Handle images for VLM support + has_images = "image" in inputs + images = None + if has_images: + images = inputs["image"] + # Convert conversational prompts to include image tokens + for prompt in prompts: + if isinstance(prompt, list): + for message in prompt: + if not isinstance(message, dict): + continue + content = message.get("content") + role = message.get("role") + if isinstance(content, str): + if role == "user": + message["content"] = [{"type": "image"}, {"type": "text", "text": content}] + elif role == "system": + message["content"] = [{"type": "text", "text": content}] + + if self.args.use_vllm: + prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(prompts, images) + else: + prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images) + + contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1) + + # Extract vision inputs if available for VLM support + vision_inputs = None + if has_images and self.is_vision_model and not self.args.use_vllm: + # For vision models with transformers generation, we need to prepare vision inputs + # Process the images to get vision inputs that can be passed through the forward pass + vision_inputs = {} + kwargs = {"images": [[img] for img in images]} + processed = self.processing_class( + text=[""] * len(images), # Dummy text for vision processing + return_tensors="pt", + **kwargs, + ) + # Handle DataParallel wrapped models + model_device = getattr(model, "device", None) + model_dtype = getattr(model, "dtype", None) + if model_device is None and hasattr(model, "module"): + model_device = model.module.device + model_dtype = model.module.dtype + # Move vision tensors to device and convert to model dtype + # Need to duplicate for 2 completions per prompt + if "pixel_values" in processed: + vision_inputs["pixel_values"] = ( + processed["pixel_values"].to(model_device, dtype=model_dtype).repeat(2, 1, 1, 1) + ) + if "pixel_attention_mask" in processed: + vision_inputs["pixel_attention_mask"] = processed["pixel_attention_mask"].to(model_device).repeat(2, 1) + if "image_sizes" in processed: + vision_inputs["image_sizes"] = processed["image_sizes"].to(model_device).repeat(2, 1) + if "image_grid_thw" in processed: + vision_inputs["image_grid_thw"] = processed["image_grid_thw"].to(model_device).repeat(2, 1) + + logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs) + with torch.no_grad(): + if self.ref_model is not None: + ref_logprobs = self._forward( + self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs + ) + else: # peft case: we just need to disable the adapter + with self.model.disable_adapter(): + ref_logprobs = self._forward( + self.model, prompt_ids, prompt_mask, completion_ids, completion_mask, vision_inputs + ) + + # Decode the completions, and format them if the input is conversational + device = logprobs.device + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational({"prompt": prompts[0]}): + completions = [[{"role": "assistant", "content": completion}] for completion in completions] + + # Get the reward from reward functions or judge + if self.reward_funcs is not None: + # First create completion_ids_list for custom reward functions + completion_ids_list = [completion_ids[i].tolist() for i in range(completion_ids.shape[0])] + + # Extract additional fields from inputs for reward functions + reward_kwargs = {} + keys = [key for key in inputs if key not in ["prompt"]] + for key in keys: + if isinstance(inputs[key], (list, tuple)): + # Repeat input fields to match number of completions (2 per prompt) + reward_kwargs[key] = inputs[key] * 2 + else: + reward_kwargs[key] = inputs[key] + + # Calculate rewards using reward functions + rewards = self._calculate_rewards_from_functions( + prompts=2 * prompts, completions=completions, completion_ids_list=completion_ids_list, **reward_kwargs + ) + + # Apply missing EOS penalty if configured + if self.args.missing_eos_penalty is not None: + rewards[~contain_eos_token] -= self.args.missing_eos_penalty + + # Split rewards into chosen/rejected pairs + first_half, second_half = rewards.split(batch_size) + mask = first_half >= second_half + elif self.judge is not None: + # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not + # directly understandable by the judge and could alter its judgment. To avoid this and make the judge + # independent of the model's chat template, we use the raw conversation data, and apply our own chat + # template to it. + if is_conversational({"prompt": prompts[0]}): + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=prompt) for prompt in prompts] + completions = [template.render(messages=completion) for completion in completions] + + ranks_of_first_completion = self.judge.judge( + prompts, list(zip(completions[:batch_size], completions[batch_size:], strict=True)) + ) + + # convert ranks to a True/False mask: + # when rank == 0, it means the first completion is the best + # when rank == 1, it means the second completion is the best + mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device) + + batch_range = torch.arange(batch_size, device=device) + chosen_indices = batch_range + (~mask * batch_size) + rejected_indices = batch_range + (mask * batch_size) + + # Build tensor so that the first half is the chosen examples and the second half the rejected examples + cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected + cr_logprobs = logprobs[cr_indices] + cr_ref_logprobs = ref_logprobs[cr_indices] + + # mask out the padding tokens + padding_mask = ~completion_mask.bool() + cr_padding_mask = padding_mask[cr_indices] + + cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1) + cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1) + + # Split the chosen and rejected examples + chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size) + chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size) + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + logits = pi_logratios - ref_logratios + + if self.args.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) + elif self.args.loss_type == "ipo": + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + + loss = losses.mean() + + # Log everything + if self.reward_funcs is not None: + # When using reward_funcs, we have rewards instead of scores + scores_margin = rewards[chosen_indices] - rewards[rejected_indices] + self.stats["objective/scores_margin"].append( + self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item() + ) + self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(rewards.mean()).mean().item()) + self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item()) + self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item()) + self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item()) + + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + non_score_reward = (-self.beta * kl).sum(1) + mean_non_score_reward = non_score_reward.mean() + self.stats["objective/non_score_reward"].append( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + if self.reward_funcs is not None: + # Calculate RLHF reward by combining rewards with non_score_reward + rlhf_reward = rewards + non_score_reward + self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item()) + + mean_entropy = -logprobs.sum(1).mean() + self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item()) + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) + gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards) + self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item()) + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) + gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards) + self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item()) + margin = gathered_chosen_rewards - gathered_rejected_rewards + self.stats["rewards/margins"].append(margin.mean().item()) + accuracy = margin > 0 + self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) + self.stats["beta"].append(self.beta) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps + + # Same as Trainer._maybe_log_save_evaluate but log our metrics + def _maybe_log_save_evaluate( + self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None + ): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + logs: dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + if Version(transformers.__version__) >= Version("5.2.0.dev0"): + tr_loss_scalar = nested_gather(tr_loss, self.args.parallel_mode).mean().item() + else: + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if learning_rate is not None: + logs["learning_rate"] = learning_rate + else: + logs["learning_rate"] = self._get_learning_rate() + + # Add our metrics + for key, val in self.stats.items(): + logs[key] = sum(val) / len(val) + self.stats = {key: [] for key in self.stats} # reset stats + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + self.log(logs, start_time) + + metrics = None + if self.control.should_evaluate: + metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == "best": + self.control.should_save = is_new_best_metric + + if self.control.should_save: + self._save_checkpoint(model, trial) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/experimental/openenv/__init__.py b/ICL/RL/trl_source/trl/experimental/openenv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4325e17f284102bd02848b432cd9d9ffedd32f58 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/openenv/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils import generate_rollout_completions + + +__all__ = ["generate_rollout_completions"] diff --git a/ICL/RL/trl_source/trl/experimental/openenv/utils.py b/ICL/RL/trl_source/trl/experimental/openenv/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4c710132bc763b36b07581c4105d495b6fc1fb --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/openenv/utils.py @@ -0,0 +1,209 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...data_utils import is_conversational +from ...extras.profiling import profiling_context +from ...import_utils import is_vllm_available + + +if is_vllm_available(): + from vllm import SamplingParams + from vllm.sampling_params import StructuredOutputsParams + + +def _build_base_generation_kwargs( + trainer, + overrides: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build base generation kwargs common to both colocate and server modes.""" + generation_kwargs: dict[str, Any] = { + "n": 1, + "temperature": trainer.temperature, + "top_k": trainer.top_k, + "min_p": 0.0 if trainer.min_p is None else trainer.min_p, + "max_tokens": trainer.max_completion_length, + } + if trainer.repetition_penalty is not None: + generation_kwargs["repetition_penalty"] = trainer.repetition_penalty + if trainer.top_p is not None: + generation_kwargs["top_p"] = trainer.top_p + + if trainer.args.generation_kwargs is not None: + generation_kwargs.update(trainer.args.generation_kwargs) + + if overrides is not None: + generation_kwargs.update(overrides) + + generation_kwargs = {key: value for key, value in generation_kwargs.items() if value is not None} + + if generation_kwargs.get("n", 1) != 1: + raise ValueError("generate_rollout_completions expects n=1.") + + return generation_kwargs + + +def _build_colocate_sampling_params( + trainer, + overrides: dict[str, Any] | None = None, + *, + logprobs: bool = True, +) -> "SamplingParams": + """Build SamplingParams for colocate mode.""" + generation_kwargs = _build_base_generation_kwargs(trainer, overrides) + + # Add colocate-specific parameters + if trainer.vllm_generation.structured_outputs_regex: + generation_kwargs["structured_outputs"] = StructuredOutputsParams( + regex=trainer.vllm_generation.structured_outputs_regex + ) + if logprobs: + generation_kwargs["logprobs"] = 0 + + return SamplingParams(**generation_kwargs) + + +def _build_server_generation_kwargs( + trainer, + overrides: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build generation kwargs for server mode.""" + return _build_base_generation_kwargs(trainer, overrides) + + +def generate_rollout_completions( + trainer, + prompts: list[str], + *, + generation_overrides: dict[str, Any] | None = None, + as_chat: bool | None = None, +) -> list[dict[str, Any]]: + """ + Generate completions for custom rollouts when vLLM is running in colocate or server mode. + + Returns one result per prompt, containing prompt and completion token ids along with per-token log probabilities + and the generated text. + """ + + if not prompts: + return [] + + if not trainer.use_vllm: + raise RuntimeError("Custom rollouts require vLLM to call generate_rollout_completions.") + + if trainer.vllm_mode == "server": + return _generate_rollout_completions_server(trainer, prompts, generation_overrides, as_chat) + elif trainer.vllm_mode == "colocate": + return _generate_rollout_completions_colocate(trainer, prompts, generation_overrides, as_chat) + else: + raise ValueError(f"vllm_mode must be 'server' or 'colocate', got '{trainer.vllm_mode}'") + + +def _generate_rollout_completions_server( + trainer, + prompts: list[str], + generation_overrides: dict[str, Any] | None = None, + as_chat: bool | None = None, +) -> list[dict[str, Any]]: + """Generate completions using vLLM server mode.""" + generation_kwargs = _build_server_generation_kwargs(trainer, generation_overrides) + + if as_chat is None: + as_chat = prompts and is_conversational({"prompt": prompts[0]}) + + with profiling_context(trainer, "vLLM.generate_rollout_server"): + if as_chat: + # For chat mode, we need to pass messages format + # Since prompts are already formatted strings, we use generate instead + output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) + else: + output = trainer.vllm_generation.vllm_client.generate(prompts=prompts, **generation_kwargs) + + # Format results to match colocate output format + results: list[dict[str, Any]] = [] + for i in range(len(prompts)): + results.append( + { + "prompt_ids": output["prompt_ids"][i], + "completion_ids": list(output["completion_ids"][i]), + "logprobs": list(output["logprobs"][i]), + "text": trainer.processing_class.decode(output["completion_ids"][i], skip_special_tokens=True), + } + ) + + return results + + +def _generate_rollout_completions_colocate( + trainer, + prompts: list[str], + generation_overrides: dict[str, Any] | None = None, + as_chat: bool | None = None, +) -> list[dict[str, Any]]: + """Generate completions using vLLM colocate mode.""" + sampling_params = _build_colocate_sampling_params(trainer, generation_overrides) + prompts_for_generation = prompts + original_size = len(prompts) + + if trainer.vllm_tensor_parallel_size > 1: + gathered_prompts = [None for _ in range(trainer.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts, group=trainer.vllm_generation.tp_group) + prompts_for_generation = [prompt for group_prompts in gathered_prompts for prompt in group_prompts] + + if as_chat is None: + as_chat = prompts_for_generation and is_conversational({"prompt": prompts_for_generation[0]}) + + if trainer.args.vllm_enable_sleep_mode: + trainer.vllm_generation.llm.wake_up(tags=["kv_cache"]) + # Work around for https://github.com/vllm-project/vllm/issues/29341 + trainer.vllm_generation.llm.collective_rpc("reload_weights") + + with profiling_context(trainer, "vLLM.generate_rollout"): + if as_chat: + vllm_outputs = trainer.vllm_generation.llm.chat( + prompts_for_generation, sampling_params=sampling_params, use_tqdm=False + ) + else: + vllm_outputs = trainer.vllm_generation.llm.generate( + prompts_for_generation, sampling_params=sampling_params, use_tqdm=False + ) + + results: list[dict[str, Any]] = [] + for request in vllm_outputs: + if not request.outputs: + results.append({"prompt_ids": request.prompt_token_ids, "completion_ids": [], "logprobs": [], "text": ""}) + continue + sequence = request.outputs[0] + logprobs = [next(iter(token_logprob.values())).logprob for token_logprob in sequence.logprobs] + results.append( + { + "prompt_ids": request.prompt_token_ids, + "completion_ids": sequence.token_ids, + "logprobs": logprobs, + "text": sequence.text, + } + ) + + if trainer.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=trainer.vllm_generation.tp_group) + tp_slice = slice(local_rank_in_group * original_size, (local_rank_in_group + 1) * original_size) + results = results[tp_slice] + + if trainer.args.vllm_enable_sleep_mode: + trainer.vllm_generation.llm.sleep(level=2) + + return results diff --git a/ICL/RL/trl_source/trl/experimental/orpo/__init__.py b/ICL/RL/trl_source/trl/experimental/orpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88abd9e1826eae8a0d201b84a048614059b9e264 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/orpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .orpo_config import ORPOConfig +from .orpo_trainer import ORPOTrainer + + +__all__ = ["ORPOConfig", "ORPOTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/orpo/orpo_config.py b/ICL/RL/trl_source/trl/experimental/orpo/orpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d7a4c2996cd538ba5c0a07001e31da51f43f62 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/orpo/orpo_config.py @@ -0,0 +1,162 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class ORPOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.orpo.ORPOTrainer`]. + + This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_completion_length (`int`, *optional*): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the relative ratio loss weight in the ORPO loss. In the + [paper](https://huggingface.co/papers/2403.07691), it is denoted by ฮป. In the + [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + padding_value (`int`, *optional*): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from the model to W&B or Comet during evaluation. + is_encoder_decoder (`bool`, *optional*): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_completion_length: int | None = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " + "denoted by ฮป." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + padding_value: int | None = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, + ) + is_encoder_decoder: bool | None = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " + "argument, you need to specify if the model returned by the callable is an encoder-decoder model." + }, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/experimental/orpo/orpo_trainer.py b/ICL/RL/trl_source/trl/experimental/orpo/orpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3a39e6d344097d94fce4952ba202b44f047608d5 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/orpo/orpo_trainer.py @@ -0,0 +1,1028 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import random +import textwrap +from collections import defaultdict +from collections.abc import Callable +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from accelerate import PartialState, logging +from datasets import Dataset +from packaging.version import Version +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_comet_available, + is_torch_xla_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available, is_torch_fx_proxy + +from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt +from ...models.utils import peft_module_casting_to_bf16 +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + disable_dropout_in_model, + log_table_to_comet_experiment, + pad_to_length, + selective_log_softmax, +) +from ..utils import DPODataCollatorWithPadding, add_bos_token_if_needed, add_eos_token_if_needed +from .orpo_config import ORPOConfig + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +logger = logging.get_logger(__name__) + + +def log1mexp(x: torch.FloatTensor) -> torch.FloatTensor: + """Numerically stable computation of log(1-exp(x)).""" + # branch at -ln 2 ~ -0.693 to avoid cancellation + t = -0.6931471805599453 + return torch.where(x < t, torch.log1p(-torch.exp(x)), torch.log(-torch.expm1(x))) + + +class ORPOTrainer(BaseTrainer): + r""" + Initialize ORPOTrainer. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. + args ([`experimental.orpo.ORPOConfig`]): + The ORPO config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + """ + + _tag_names = ["trl", "orpo"] + _name = "ORPO" + _paper = { + "title": "ORPO: Monolithic Preference Optimization without Reference Model", + "id": "2403.07691", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{hong2024orpo, + title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, + author = {Jiwoo Hong and Noah Lee and James Thorne}, + year = 2024, + eprint = {arXiv:2403.07691} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + args: ORPOConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + dtype = model_init_kwargs.get("dtype", "auto") + if dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid `dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." + ) + model_init_kwargs["dtype"] = dtype + model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " + "merge and unload the existing adapter, save the resulting base model, and then pass that base " + "model along with the new `peft_config` to the trainer." + ) + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if processing_class is None: + raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") + if args.max_length is None: + logger.warning( + "`max_length` is not set in the ORPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + ) + max_length = 512 + else: + max_length = args.max_length + + if args.max_completion_length is None and self.is_encoder_decoder: + logger.warning( + "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + ) + self.max_completion_length = 128 + else: + self.max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + logger.warning( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.truncation_mode = args.truncation_mode + self.processing_class = processing_class + + self.beta = args.beta + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed, and apply the chat template if needed + train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + + b)[len(enc(a)):]`. Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict: + """Tokenize a single row from a ORPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + + chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long, + we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length + of the prompt and the chosen/rejected response, with `-100` for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.processing_class(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + a != b + for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True) + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt. Avoid adding if it's already there + prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( + self.processing_class.bos_token_id, + prompt_len_input_ids, + prompt_tokens, + chosen_prompt_len_input_ids, + chosen_tokens, + rejected_prompt_len_input_ids, + rejected_tokens, + ) + + # add EOS token to end of answer. Avoid adding if it's already there + chosen_tokens, rejected_tokens = add_eos_token_if_needed( + self.processing_class.eos_token_id, chosen_tokens, rejected_tokens + ) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - longer_response_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [-100] * len( + chosen_tokens["prompt_input_ids"] + ) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [-100] * len( + rejected_tokens["prompt_input_ids"] + ) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.processing_class( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.processing_class( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.processing_class(prompt, add_special_tokens=True) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + if is_torch_xla_available(): + # Pad the sequences to global max_length to avoid TorchXLA recompilation + for k in batch: + if "labels" in k or self.is_encoder_decoder: + pad_value = -100 + elif k.endswith("_input_ids"): + pad_value = self.padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) + return batch + + @staticmethod + def concatenated_inputs( + batch: dict[str, list | torch.LongTensor], + is_encoder_decoder: bool = False, + padding_value: int = 0, + device: torch.device | None = None, + ) -> dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: + A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors + of shape (batch_size, sequence_length). + is_encoder_decoder: + Whether the model is an encoder-decoder model. + padding_value: + The padding value to use for the concatenated inputs_ids. + device: + The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = -100 + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = -100 + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def odds_ratio_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: + Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: + Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the ORPO + loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for + the chosen and rejected responses, respectively. The log odds ratio of the chosen responses over the + rejected responses ratio for logging purposes. The `log(sigmoid(log_odds_chosen))` for logging purposes. + """ + + # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) + policy_chosen_logps = policy_chosen_logps.float() + policy_rejected_logps = policy_rejected_logps.float() + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + log1mexp(policy_chosen_logps) - log1mexp(policy_rejected_logps) + ) + ratio = F.logsigmoid(log_odds) + losses = self.beta * ratio + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: + Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored. + Shape: (batch_size, sequence_length) + average_log_prob: + If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the + log probabilities of the (non-masked) tokens. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the + given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == -100, 0, labels) + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, -100) + # orpo chosen nll loss is computed over the full prompt and response + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1, :] + rejected_logits = all_logits[len_chosen:, :-1, :] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, list | torch.LongTensor], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() + metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( + chosen_rewards - rejected_rewards + ).mean() + metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() + metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() + metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( + policy_rejected_logits.detach().mean() + ).mean() + metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( + policy_chosen_logits.detach().mean() + ).mean() + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() + metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() + metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() + if is_torch_xla_available(): + xm.mark_step() # needed because .item() calls + for k, v in metrics.items(): + metrics[k] = v.item() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ): + if not self.use_dpo_data_collator: + logger.warning( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.generate_from_model(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy"], + data=[ + [prompt, pol[len(prompt) :]] + for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/experimental/papo/__init__.py b/ICL/RL/trl_source/trl/experimental/papo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62fb8105cb446c197570023e388299e1bd96bbc4 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/papo/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .papo_config import PAPOConfig +from .papo_trainer import PAPOTrainer diff --git a/ICL/RL/trl_source/trl/experimental/papo/papo_config.py b/ICL/RL/trl_source/trl/experimental/papo/papo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e142a94a9cdc3128456b6127436c3befd6bc2a93 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/papo/papo_config.py @@ -0,0 +1,73 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Literal + +from ...trainer.grpo_config import GRPOConfig + + +@dataclass +class PAPOConfig(GRPOConfig): + """ + Configuration class for PAPOTrainer. + + PAPO (Perception-Aware Policy Optimization) extends GRPO/DAPO for multimodal reasoning by adding an implicit + perception loss and double entropy regularization. + + Args: + perception_loss_weight (`float`, *optional*, defaults to `0.1`): + gamma Weight coefficient for the perception loss term. This encourages the model to be sensitive to visual + changes. + + mask_ratio (`float`, *optional*, defaults to `0.3`): + Ratio of the image to mask when computing perception loss. + + mask_type (`Literal["random", "patch", "grid"]`, *optional*, defaults to `"random"`): + Type of masking strategy to use. + + der_loss_weight1 (`float`, *optional*, defaults to `0.03`): + eta1 Weight coefficient for the Double Entropy Regularization (DER) term. This term encourages confident + predictions with original images (low entropy) and uncertain predictions with masked images (high entropy). + + der_loss_weight2 (`float`, *optional*, defaults to `0.03`): + eta2 Weight coefficient for the Double Entropy Regularization (DER) term. This term encourages confident + predictions with original images (low entropy) and uncertain predictions with masked images (high entropy). + + loss_type (`Literal["grpo", "dapo"]`, inherited from GRPOConfig): + Base loss type to use. Set to "grpo" for PAPO-G or "dapo" for PAPO-D. + """ + + perception_loss_weight: float = 0.1 + mask_ratio: float = 0.3 + mask_type: Literal["random", "patch", "grid"] = "random" + + # Added for Double Entropy Regularization + der_loss_weight1: float = 0.03 + der_loss_weight2: float = 0.03 + + def __post_init__(self): + super().__post_init__() + + # Validation + if not 0.0 <= self.mask_ratio <= 1.0: + raise ValueError(f"mask_ratio must be between 0 and 1, got {self.mask_ratio}") + + if self.der_loss_weight1 < 0 or self.der_loss_weight2 < 0: + raise ValueError( + f"der_loss_weight1 and der_loss_weight2 must be non-negative, got {self.der_loss_weight1} and {self.der_loss_weight2}" + ) + + if self.mask_type not in ["random", "patch", "grid"]: + raise ValueError(f"mask_type must be one of ['random', 'patch', 'grid'], got {self.mask_type}") diff --git a/ICL/RL/trl_source/trl/experimental/papo/papo_trainer.py b/ICL/RL/trl_source/trl/experimental/papo/papo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..931fee3efe993e95b6564dab28987e125e8e96be --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/papo/papo_trainer.py @@ -0,0 +1,354 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import textwrap + +import torch +from datasets import Dataset, IterableDataset +from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin + +from ...trainer.grpo_trainer import GRPOTrainer, RewardFunc +from ...trainer.utils import nanmax, nanmin +from .papo_config import PAPOConfig + + +class PAPOTrainer(GRPOTrainer): + """ + Trainer for Perception-Aware Policy Optimization (PAPO). + + PAPO extends GRPO/DAPO for multimodal reasoning by adding an implicit perception loss that encourages the model to + better utilize visual information. The key innovation is computing KL divergence between model outputs on original + vs. corrupted (masked) images. + + Two variants are supported: + - PAPO-G: PAPO + GRPO (use loss_type="grpo") + - PAPO-D: PAPO + DAPO (use loss_type="dapo") + + Example: + + ```python + from datasets import load_dataset + from trl.experimental.papo import PAPOTrainer, PAPOConfig + + dataset = load_dataset("your-vlm-dataset", split="train") + + + def reward_func(completions, **kwargs): + # Your reward function for multimodal reasoning + return [compute_reward(c) for c in completions] + + + # PAPO-G + config = PAPOConfig( + loss_type="grpo", # Use GRPO as base + perception_loss_weight=0.1, + mask_ratio=0.3, + ) + + # PAPO-G + config = PAPOConfig( + loss_type="dapo", # Use DAPO as base + perception_loss_weight=0.1, + mask_ratio=0.3, + ) + + trainer = PAPOTrainer( + model="Qwen/Qwen2-VL-2B-Instruct", + reward_funcs=reward_func, + args=config, + train_dataset=dataset, + ) + + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained (must be a vision-language model). + reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): + Reward functions for computing rewards (same as GRPO). + args ([`PAPOConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. Must include "prompt" and "image" columns. + eval_dataset: Same requirements as train_dataset. + processing_class: Processing class (tokenizer/processor) for the model. + reward_processing_classes: Processing classes for reward models. + callbacks: Training callbacks. + optimizers: Optimizer and scheduler tuple. + peft_config: PEFT configuration if using parameter-efficient fine-tuning. + """ + + _tag_names = ["trl", "papo"] + _name = "PAPO" + _paper = { + "title": "Perception-Aware Policy Optimization for Multimodal Reasoning", + "id": "2507.06448", + # docstyle-ignore + "citation": textwrap.dedent( + """\ + @misc{wang2025perceptionawarepolicyoptimizationmultimodal, + title = {{Perception-Aware Policy Optimization for Multimodal Reasoning}}, + author = {Zhenhailong Wang and Xuehang Guo and Sofia Stoica and Haiyang Xu and Hongru Wang and Hyeonjeong Ha and Xiusi Chen and Yangyi Chen and Ming Yan and Fei Huang and Heng Ji}, + year = 2025, + url = {https://arxiv.org/abs/2507.06448}, + archivePrefix= {arXiv}, + eprint = {2507.06448}, + primaryClass = {cs.CL} + }""" + ), + } + + def __init__( + self, + model: str | PreTrainedModel, + reward_funcs: RewardFunc | list[RewardFunc], + args: PAPOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks=None, + optimizers=(None, None), + peft_config=None, + ): + # Initialize with default PAPO config if not provided + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = PAPOConfig(f"{model_name}-PAPO") + + # Store PAPO-specific parameters + self.perception_loss_weight = args.perception_loss_weight + self.mask_ratio = args.mask_ratio + self.mask_type = args.mask_type + self.der_loss_weight1 = args.der_loss_weight1 + self.der_loss_weight2 = args.der_loss_weight2 + + # Initialize parent GRPO trainer + super().__init__( + model=model, + reward_funcs=reward_funcs, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + + def _mask_image(self, pixel_values: torch.Tensor, mask_ratio: float = None) -> torch.Tensor: + """ + Apply masking to image pixel values. + + Args: + pixel_values: Image tensor of shape (B, C, H, W) or (B, N, C, H, W) for multi-image + mask_ratio: Ratio of image to mask (defaults to self.mask_ratio) + + Returns: + Masked pixel values tensor + """ + if mask_ratio is None: + mask_ratio = self.mask_ratio + + masked_pixel_values = pixel_values.clone() + + if self.mask_type == "random": + # Random pixel masking + mask = torch.rand_like(pixel_values) > mask_ratio + masked_pixel_values = masked_pixel_values * mask + + elif self.mask_type == "patch": + # Patch-based masking (mask contiguous regions) + B = pixel_values.shape[0] + if pixel_values.ndim == 4: # (B, C, H, W) + C, H, W = pixel_values.shape[1:] + for i in range(B): + # Calculate patch size to mask + patch_h = int(H * mask_ratio**0.5) + patch_w = int(W * mask_ratio**0.5) + # Random starting position + start_h = random.randint(0, max(0, H - patch_h)) + start_w = random.randint(0, max(0, W - patch_w)) + # Apply mask + masked_pixel_values[i, :, start_h : start_h + patch_h, start_w : start_w + patch_w] = 0 + + elif pixel_values.ndim == 5: # (B, N, C, H, W) for multi-image + N, C, H, W = pixel_values.shape[1:] + for i in range(B): + for n in range(N): + patch_h = int(H * mask_ratio**0.5) + patch_w = int(W * mask_ratio**0.5) + start_h = random.randint(0, max(0, H - patch_h)) + start_w = random.randint(0, max(0, W - patch_w)) + masked_pixel_values[i, n, :, start_h : start_h + patch_h, start_w : start_w + patch_w] = 0 + + elif self.mask_type == "grid": + # Grid-based masking (mask regular grid cells) + if pixel_values.ndim == 4: + C, H, W = pixel_values.shape[1:] + grid_size = int((1 / mask_ratio) ** 0.5) + cell_h, cell_w = H // grid_size, W // grid_size + + for i in range(grid_size): + for j in range(grid_size): + if random.random() < mask_ratio: + masked_pixel_values[:, :, i * cell_h : (i + 1) * cell_h, j * cell_w : (j + 1) * cell_w] = 0 + + return masked_pixel_values + + def _compute_loss(self, model, inputs): + # >>> 1. GRPO loss + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + ) + + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + + # Compute the loss + advantages = inputs["advantages"] + # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps + # old_per_token_logps == per_token_logps, so we can skip it's computation + # (see _generate_and_score_completions) and use per_token_logps.detach() instead. + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "token": + log_importance_weights = log_ratio + elif self.importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == "grpo": + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + loss = loss / self.current_gradient_accumulation_steps + elif self.loss_type == "dapo": + normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes + loss = (per_token_loss * completion_mask).sum() / normalizer + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + # >>> 2. Implicit Perception Loss + inputs["pixel_values"] = self._mask_image(inputs["pixel_values"], self.mask_ratio) + mask_img_per_token_logps, mask_img_entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + ) + perception_kl = ( + torch.exp(mask_img_per_token_logps - per_token_logps) - (mask_img_per_token_logps - per_token_logps) - 1 + ) + perception_kl = torch.clamp(perception_kl, min=0.0, max=0.2) + perception_loss = self.perception_loss_weight * perception_kl + + # >>> 3. Double Entropy Loss + der_loss = self.der_loss_weight1 * entropies + self.der_loss_weight2 * mask_img_entropies + + # PAPO Loss + loss = (loss - perception_loss + der_loss).mean() + # Log the metrics + mode = "train" if self.model.training else "eval" + + completion_token_count = completion_mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * completion_mask).sum() / completion_token_count + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + mean_entropy = masked_batch_mean(entropies) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + return loss diff --git a/ICL/RL/trl_source/trl/experimental/ppo/__init__.py b/ICL/RL/trl_source/trl/experimental/ppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..09451e973bea4ae1432fe16276907644cd67466c --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/ppo/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling_value_head import ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, +) +from .ppo_config import PPOConfig +from .ppo_trainer import PPOTrainer + + +__all__ = [ + "AutoModelForCausalLMWithValueHead", + "AutoModelForSeq2SeqLMWithValueHead", + "PreTrainedModelWrapper", + "PPOConfig", + "PPOTrainer", +] diff --git a/ICL/RL/trl_source/trl/experimental/ppo/modeling_value_head.py b/ICL/RL/trl_source/trl/experimental/ppo/modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..b2307b7562fbd06134958676f43687caf585cc7b --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/ppo/modeling_value_head.py @@ -0,0 +1,1007 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os + +import torch +import torch.nn as nn +from accelerate import PartialState +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, + RepositoryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + PreTrainedModel, + is_torch_npu_available, + is_torch_xpu_available, +) +from transformers.utils import is_peft_available + + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PromptLearningConfig, + get_peft_model, + prepare_model_for_kbit_training, + ) + + +class PreTrainedModelWrapper(nn.Module): + """ + Wrapper for a [`~transformers.PreTrainedModel`] implemented as a standard PyTorch [`torch.nn.Module`]. + + This class provides a compatibility layer that preserves the key attributes and methods of the original + [`~transformers.PreTrainedModel`], while exposing a uniform interface consistent with PyTorch modules. It enables + seamless integration of pretrained Transformer models into custom training, evaluation, or inference workflows. + + Attributes: + pretrained_model ([`~transformers.PreTrainedModel`]): + The model to be wrapped. + parent_class ([`~transformers.PreTrainedModel`]): + The parent class of the model to be wrapped. + supported_args (`list`): + The list of arguments that are supported by the wrapper class. + """ + + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = ( + (PreTrainedModel) + if not is_peft_available() + else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) + ) + + def __init__( + self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs + ): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable + + if hasattr(pretrained_model, "enable_input_require_grads"): + self.enable_input_require_grads = pretrained_model.enable_input_require_grads + + self.supports_rm_adapter = supports_rm_adapter + self.rm_adapter_name = rm_adapter_name + self.policy_adapter_name = "default" + if score_module is not None: + self.score = score_module + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. The pretrained model is loaded using the + `from_pretrained` method of the [`~transformers.PreTrainedModel`] class. The arguments that are specific to the + [`~transformers.PreTrainedModel`] class are passed along this method and filtered out from the `kwargs` + argument. + + Args: + pretrained_model_name_or_path (`str` or [`~transformers.PreTrainedModel`]): + The path to the pretrained model or its name. + *model_args (`list`, *optional*): + Additional positional arguments passed along to the underlying model's `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's `from_pretrained` method. We also + pre-process the kwargs to extract the arguments that are specific to the + [`~transformers.PreTrainedModel`] class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") + is_trainable = kwargs.pop("is_trainable", False) + trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError( + "The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter." + ) + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + quantization_config = pretrained_kwargs.get("quantization_config", None) + if quantization_config is not None: + is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False) + is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False) + else: + is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False + is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False) + + if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): + raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + if is_peft_available(): + try: + # If there is a trained peft adapter in the hub, load its config. + remote_adapter_config = hf_hub_download( + pretrained_model_name_or_path, + "adapter_config.json", + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + remote_adapter_config = None + else: + remote_adapter_config = None + + local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) + + if (local_adapter_present or remote_adapter_config is not None) and is_peft_available(): + if peft_config is not None: + logging.warning( + "`peft_config` argument ignored since a peft config file was found in " + f"{pretrained_model_name_or_path}" + ) + + # Load the trained peft adapter config + if local_adapter_present: + trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) + else: + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) + + # Load the pretrained base model + pretrained_model = cls.transformers_parent_class.from_pretrained( + trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs + ) + + # Wrap the pretrained model with the trained peft adapter + pretrained_model = PeftModel.from_pretrained( + pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable, token=token + ) + logging.info("Trained peft adapter loaded") + else: + pretrained_model = cls.transformers_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **pretrained_kwargs + ) + + if peft_config is not None: + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + + elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): + pretrained_model = pretrained_model_name_or_path + + if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + else: + raise ValueError( + "pretrained_model_name_or_path should be a string or a PreTrainedModel, " + f"but is {type(pretrained_model_name_or_path)}" + ) + + if is_peft_available(): + if isinstance(pretrained_model, PeftModel): + is_peft_model = True + # for backward compatibility + if hasattr(pretrained_model, "active_peft_config") and isinstance( + pretrained_model.active_peft_config, PromptLearningConfig + ): + raise ValueError("PromptLearningConfig is not supported for PPO training.") + + # Add reward modeling adapter if specified + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. ") + elif is_peft_model and reward_adapter is not None: + score_module = cls.add_and_load_reward_modeling_adapter( + pretrained_model, reward_adapter, reward_adapter_name, token=token + ) + multi_adapter_args = { + "score_module": score_module, + "supports_rm_adapter": True, + "rm_adapter_name": reward_adapter_name, + } + else: + multi_adapter_args = {"supports_rm_adapter": False} + + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + # Try with safetensors + if filename is None and files_to_download is None: + safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu", "weights_only": True} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning( + f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " + f"and no v_head weight is found. This IS expected if you are not resuming PPO training." + ) + # load json + if is_resuming_training: + with open(index_file_name) as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any(module in k for module in cls.supported_modules): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU & XPU, we return the local process index using the `accelerate.PartialState` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`int | str`): + The current device. + """ + state = PartialState() + if torch.cuda.is_available() or is_torch_xpu_available(): + return state.local_process_index + elif is_torch_npu_available(): + return f"npu:{state.local_process_index}" + else: + return "cpu" + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_kbit_training + + check_peft_kwargs = True + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + @classmethod + def add_and_load_reward_modeling_adapter( + cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None + ): + r""" + Add and load a reward modeling adapter. This method can only be used if the model is a `PeftModel` and if you + have initialized the model with the `reward_modeling_adapter_id` argument, pointing to the id of the reward + modeling adapter. The latest needs also to contain the score head in order to produce the reward. + """ + pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False) + pretrained_model.train() + + filename = os.path.join(adapter_model_id, "adapter_model.bin") + safe_loading = False + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except Exception: + filename = os.path.join(adapter_model_id, "adapter_model.safetensors") + safe_loading = True + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.safetensors", + token=token, + ) + except Exception as exc: + raise ValueError( + "Could not find adapter model in the Hub, make sure you have the correct adapter model id." + ) from exc + else: + local_filename = filename + else: + local_filename = filename + + loading_func = safe_load_file if safe_loading else torch.load + load_kwargs = {} if safe_loading else {"map_location": "cpu", "weights_only": True} + + adapter_state_dict = loading_func(local_filename, **load_kwargs) + + for score_name_candidate in cls.supported_rm_modules: + if any(score_name_candidate in name for name in adapter_state_dict.keys()): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + + for name, param in adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = param.to(cls._get_current_device()) + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any("bias" in name for name in adapter_state_dict.keys()) + + score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=cls._get_current_device(), + dtype=pretrained_model.dtype, + ) + score.load_state_dict(score_dict) + for param in score.parameters(): + param.requires_grad = False + + return score + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + [`~transformers.PreTrainedModel.push_to_hub`]. Please refer to the documentation of + [`~transformers.PreTrainedModel.push_to_hub`] for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's `push_to_hub` method. + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + [`~transformers.PreTrainedModel.save_pretrained`]. Please refer to the documentation of + [`~transformers.PreTrainedModel.save_pretrained`] for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid silent bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is instantiated and loaded from a checkpoint. + It can be used to perform additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): + r""" + Computes the reward score for a given input. The method has first to enable the adapter and then compute the + reward score. After that the model disables the reward modeling adapter and enables the default ppo adapter + again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + with torch.no_grad(): + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) + + self.pretrained_model.set_adapter(self.policy_adapter_name) + self.pretrained_model.eval() + + return scores + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + + # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + """ + An autoregressive model with a value head in addition to the language model head. This class inherits from + [`experimental.ppo.PreTrainedModelWrapper`] and wraps a [`~transformers.PreTrainedModel`] class. The wrapper class + supports classic functions such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped + model, simply manipulate the `pretrained_model` attribute of this class. + + Class attributes: + - **transformers_parent_class** ([`~transformers.PreTrainedModel`]) -- The parent class of the wrapped model. + This + should be set to `transformers.AutoModelForCausalLM` for this class. + - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported + by the [`ValueHead`] class. Currently, the supported args are: + - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the + [`ValueHead`] class. + - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the + [`ValueHead`] if a specific initialization strategy is selected. + - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the + [`ValueHead`]. Currently, the supported strategies are: + - **`None`** -- Initializes the weights of the [`ValueHead`] with a random distribution. This is the + default strategy. + - **"normal"** -- Initializes the weights of the [`ValueHead`] with a normal distribution. + """ + + transformers_parent_class = AutoModelForCausalLM + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + """ + Initializes the model. + + Args: + pretrained_model ([`~transformers.PreTrainedModel`]): + The model to wrap. It should be a causal language model such as GPT2. or any model mapped inside the + `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the [`ValueHead`] class. + """ + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. Users can pass a + different initialization strategy by passing the `v_head_init_strategy` argument when calling + `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the [`ValueHead`] class. These arguments can contain + the `v_head_init_strategy` argument as well as the `v_head_initializer_range` argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + return_past_key_values=False, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples + kwargs["past_key_values"] = past_key_values + + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + if return_past_key_values: + return (lm_logits, loss, value, base_model_output.past_key_values) + else: + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. Please refer to the + [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) method of the wrapped model + for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head to the state + dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + self.pretrained_model.v_head = self.v_head + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model by prepending the + key with `v_head.`. This function removes the `v_head.` prefix from the keys of the value head state + dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + if isinstance(first_device, int): + if is_torch_npu_available(): + first_device = f"npu:{first_device}" + elif is_torch_xpu_available(): + first_device = f"xpu:{first_device}" + else: + first_device = f"cuda:{first_device}" + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + """ + A seq2seq model with a value head in addition to the language model head. This class inherits from + [`experimental.ppo.PreTrainedModelWrapper`] and wraps a [`~transformers.PreTrainedModel`] class. The wrapper class + supports classic functions such as `from_pretrained` and `push_to_hub` and also provides some additional + functionalities such as `generate`. + + Args: + pretrained_model ([`~transformers.PreTrainedModel`]): + The model to wrap. It should be a causal language model such as GPT2. or any model mapped inside the + [`~transformers.AutoModelForSeq2SeqLM`] class. + kwargs: + Additional keyword arguments passed along to the [`ValueHead`] class. + """ + + transformers_parent_class = AutoModelForSeq2SeqLM + lm_head_namings = ["lm_head", "embed_out", "output_projection"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.is_encoder_decoder = True + + if not self._has_lm_head(): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _has_lm_head(self): + # check module names of all modules inside `pretrained_model` to find the language model head + for name, _module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + return True + return False + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model by prepending the + key with `v_head.`. This function removes the `v_head.` prefix from the keys of the value head state + dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + # get the lm_head device + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + lm_head_device = module.weight.device + break + + # put v_head on the same device as the lm_head to avoid issues + self.v_head = self.v_head.to(lm_head_device) + + def set_device_hook(module, input, outputs): + r""" + A hook that sets the device of the output of the model to the device of the first parameter of the + model. + + Args: + module (`nn.Module`): + The module to which the hook is attached. + input (`tuple`): + The input to the module. + outputs (`tuple`): + The output of the module. + """ + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(lm_head_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + self.is_sequential_parallel = True + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head to the state + dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + self.pretrained_model.v_head = self.v_head + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def _init_weights(self, **kwargs): + r""" + We initialize the weights of the value head. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + return_past_key_values=False, + **kwargs, + ): + kwargs["past_key_values"] = past_key_values + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We force the model to output hidden states + **kwargs, + ) + + last_hidden_state = base_model_output.decoder_hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + if return_past_key_values: + return (lm_logits, loss, value, base_model_output.past_key_values) + else: + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + We call `generate` on the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) diff --git a/ICL/RL/trl_source/trl/experimental/ppo/ppo_config.py b/ICL/RL/trl_source/trl/experimental/ppo/ppo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6128203f581b0510a68740e5c931f2601a51eeb0 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/ppo/ppo_config.py @@ -0,0 +1,304 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Literal + +from transformers import TrainingArguments + + +@dataclass +class PPOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.ppo.PPOTrainer`]. + + This class includes only the parameters that are specific to PPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + num_mini_batches (`int`, *optional*, defaults to `1`): + Number of minibatches to split a batch into. + total_episodes (`int`, *optional*): + Total number of episodes in the dataset. + local_rollout_forward_batch_size (`int`, *optional*, defaults to `64`): + Per rank no grad forward pass in the rollout phase. + num_sample_generations (`int`, *optional*, defaults to `10`): + Number of debugging samples generations (i.e., `generate_completions` calls) throughout training. + response_length (`int`, *optional*, defaults to `53`): + Length of the response. + stop_token (`str`, *optional*): + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. + + stop_token_id (`int`, *optional*): + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. + temperature (`float`, *optional*, defaults to `0.7`): + Sampling temperature. + missing_eos_penalty (`float`, *optional*): + Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to + generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive + value. + sft_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the SFT model. + world_size (`int`, *optional*): + Number of processes (GPUs) to use for the training. + num_total_batches (`int`, *optional*): + Number of total batches to train. + micro_batch_size (`int`, *optional*): + Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`). + local_batch_size (`int`, *optional*): + Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`). + batch_size (`int`, *optional*): + Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * + `gradient_accumulation_steps`). + local_mini_batch_size (`int`, *optional*): + Mini batch size per GPU. + mini_batch_size (`int`, *optional*): + Mini batch size across GPUs. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the model to the Hub after training. + reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the reward model. + model_adapter_name (`str`, *optional*): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, *optional*): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + num_ppo_epochs (`int`, *optional*, defaults to `4`): + Number of epochs to train. + whiten_rewards (`bool`, *optional*, defaults to `False`): + Whether to whiten the rewards. + kl_coef (`float`, *optional*, defaults to `0.05`): + KL coefficient. + kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`): + Which estimator for KL-Divergence to use from [Approximating KL + Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased + estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly + better estimator". Cannot be set to "k2", as it is used for logging purposes. + cliprange (`float`, *optional*, defaults to `0.2`): + Clip range. + vf_coef (`float`, *optional*, defaults to `0.1`): + Value function coefficient. + cliprange_value (`float`, *optional*, defaults to `0.2`): + Clip range for the value function. + gamma (`float`, *optional*, defaults to `1.0`): + Discount factor. + lam (`float`, *optional*, defaults to `0.95`): + Lambda value for GAE. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. + """ + + # Parameters whose default values are overridden from TrainingArguments + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + num_mini_batches: int = field( + default=1, + metadata={"help": "Number of minibatches to split a batch into."}, + ) + total_episodes: int | None = field( + default=None, + metadata={"help": "Total number of episodes in the dataset."}, + ) + local_rollout_forward_batch_size: int = field( + default=64, + metadata={"help": "Per rank no grad forward pass in the rollout phase."}, + ) + num_sample_generations: int = field( + default=10, + metadata={ + "help": "Number of debugging samples generations (i.e., `generate_completions` calls) throughout training." + }, + ) + response_length: int = field( + default=53, + metadata={"help": "Length of the response."}, + ) + stop_token: Literal["eos"] | None = field( + default=None, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." + }, + ) + stop_token_id: int | None = field( + default=None, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." + }, + ) + temperature: float = field( + default=0.7, + metadata={"help": "Sampling temperature."}, + ) + missing_eos_penalty: float | None = field( + default=None, + metadata={ + "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to " + "encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be " + "a positive value." + }, + ) + sft_model_path: str = field( + default="EleutherAI/pythia-160m", + metadata={"help": "Path to the SFT model."}, + ) + world_size: int | None = field( + default=None, + metadata={"help": "Number of processes (GPUs) to use for the training."}, + ) + num_total_batches: int | None = field( + default=None, + metadata={"help": "Number of total batches to train."}, + ) + micro_batch_size: int | None = field( + default=None, + metadata={"help": "Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)."}, + ) + local_batch_size: int | None = field( + default=None, + metadata={"help": "Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)."}, + ) + batch_size: int | None = field( + default=None, + metadata={ + "help": "Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * " + "`gradient_accumulation_steps`)." + }, + ) + local_mini_batch_size: int | None = field( + default=None, + metadata={"help": "Mini batch size per GPU."}, + ) + mini_batch_size: int | None = field( + default=None, + metadata={"help": "Mini batch size across GPUs."}, + ) + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the model to the Hub after training."}, + ) + reward_model_path: str = field( + default="EleutherAI/pythia-160m", + metadata={"help": "Path to the reward model."}, + ) + model_adapter_name: str | None = field( + default=None, + metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, + ) + ref_adapter_name: str | None = field( + default=None, + metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, + ) + num_ppo_epochs: int = field( + default=4, + metadata={"help": "Number of epochs to train."}, + ) + whiten_rewards: bool = field( + default=False, + metadata={"help": "Whether to whiten the rewards."}, + ) + kl_coef: float = field( + default=0.05, + metadata={"help": "KL coefficient."}, + ) + kl_estimator: Literal["k1", "k3"] = field( + default="k1", + metadata={ + "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence " + "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be " + "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better " + "estimator'. Cannot be set to 'k2', as it is used for logging purposes." + }, + ) + cliprange: float = field( + default=0.2, + metadata={"help": "Clip range."}, + ) + vf_coef: float = field( + default=0.1, + metadata={"help": "Value function coefficient."}, + ) + cliprange_value: float = field( + default=0.2, + metadata={"help": "Clip range for the value function."}, + ) + gamma: float = field( + default=1.0, + metadata={"help": "Discount factor."}, + ) + lam: float = field( + default=0.95, + metadata={"help": "Lambda value for GAE."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/experimental/ppo/ppo_trainer.py b/ICL/RL/trl_source/trl/experimental/ppo/ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..df2e193537ffdf676caa695781193d3ad33abc62 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/ppo/ppo_trainer.py @@ -0,0 +1,1019 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import math +import os +import textwrap +import time +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import transformers +from accelerate import Accelerator, logging +from accelerate.utils import gather_object +from datasets import Dataset +from packaging.version import Version +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from transformers.utils import ModelOutput, is_peft_available, is_rich_available + +from ...models.utils import create_reference_model, peft_module_casting_to_bf16, unwrap_model_for_generation +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + disable_dropout_in_model, + empty_cache, + log_table_to_comet_experiment, + pad, + prepare_deepspeed, + selective_log_softmax, +) +from ..utils import first_true_indices, get_reward +from .ppo_config import PPOConfig + + +if is_rich_available(): + from rich.console import Console + from rich.table import Table + + +logger = logging.get_logger(__name__) + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + + +INVALID_LOGPROB = 1.0 + + +def generate( + lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: GenerationConfig +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generates sequences from the language model backbone in a way that does not affect padding tokens. + + Args: + lm_backbone (`torch.nn.Module`): + The language model backbone used for generation. + queries (`torch.Tensor`): + The tensor containing the input queries. + pad_token_id (`int`): + The token ID representing the pad token. + generation_config ([`~transformers.GenerationConfig`]): + The configuration for the generation process. + + Returns: + tuple: + - `generated_sequences` (`torch.Tensor`): + The concatenated tensor of input queries and generated sequences. + - `logits` (`torch.Tensor`): + The logits output from the generation process. + """ + context_length = queries.shape[1] + attention_mask = queries != pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # not needed: already adjusted in generations + # https://github.com/huggingface/transformers/blob/ac33aeeeee2a7a89b89c93c2962e6feb90daef0a/src/transformers/models/gpt2/modeling_gpt2.py#L1227-L1250 + generation_config=generation_config, + return_dict_in_generate=True, + output_scores=True, + ) + logits = torch.stack(output.scores, 1) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits + + +@torch.no_grad() +def batch_generation( + model: torch.nn.Module, + queries: torch.Tensor, + local_rollout_forward_batch_size: int, + pad_token_id: int, + generation_config: GenerationConfig, +): + query_responses = [] + logitss = [] + batch_size = queries.shape[0] + for i in range(0, batch_size, local_rollout_forward_batch_size): + query = queries[i : i + local_rollout_forward_batch_size] + query_response, logits = generate( + model, + query, + pad_token_id, + generation_config, + ) + query_responses.append(query_response) + logitss.append(logits) + + # padding tensors + padded_query_responses = pad(query_responses, padding_value=pad_token_id, padding_side="right") + padded_logitss = pad(logitss, padding_value=0, padding_side="right") + + # reshaping + padded_query_responses = padded_query_responses.view(-1, padded_query_responses.shape[-1])[:batch_size] + padded_logitss = padded_logitss.view(-1, *padded_logitss.shape[2:])[:batch_size] + + return padded_query_responses, padded_logitss + + +def exact_div(a, b, custom_error_message=""): + q = a // b + if a != q * b: + raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}") + return q + + +def print_rich_table(df: pd.DataFrame) -> None: + if not is_rich_available(): + raise ImportError( + "The function `print_rich_table` requires the `rich` library. Please install it with `pip install rich`." + ) + console = Console() + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.print(table) + + +def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor) -> torch.Tensor: + """ + Truncates the responses at the first occurrence of the stop token, filling the rest with pad tokens. + + Args: + stop_token_id (`int`): + The token ID representing the stop token where truncation occurs. + pad_token_id (`int`): + The token ID representing the pad token used to fill the truncated responses. + responses (`torch.Tensor`): + The tensor containing the responses to be truncated. + + Returns: + `torch.Tensor`: + The truncated responses tensor with pad tokens filled after the stop token. + """ + trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [responses.shape[1]] + idxs = torch.arange(responses.shape[1], device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id) + return postprocessed_responses + + +def forward( + model: torch.nn.Module, + query_responses: torch.Tensor, + pad_token_id: int, +) -> ModelOutput: + """ + Performs a forward pass through the model with the given query responses and pad token ID. + + Args: + model (`torch.nn.Module`): + The model to perform the forward pass. + query_responses (`torch.Tensor`): + The tensor containing the query responses. + pad_token_id (`int`): + The token ID representing the pad token. + + Returns: + `ModelOutput`: + The output of the model, including hidden states. + """ + attention_mask = query_responses != pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +@dataclass +class OnlineTrainerState(TrainerState): + """ + Training state for online/on-policy trainers. + + Extends [`~transformers.TrainerState`] with an `episode` counter to track the current rollout/episode. + + Args: + episode (`int`, defaults to 0): Zero-based episode index. + """ + + episode: int = 0 + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, value_model) -> None: + super().__init__() + self.policy = policy + self.value_model = value_model + self.critic_backbone = getattr(value_model, value_model.base_model_prefix) + self.is_gradient_checkpointing = policy.is_gradient_checkpointing + + def forward(self, **kwargs): + output = self.critic_backbone(**kwargs) + logits = self.value_model.score(output.hidden_states[-1]) + return self.policy(**kwargs), logits + + +class PPOTrainer(BaseTrainer): + """Trainer for Proximal Policy Optimization (PPO). + + For details on PPO, see the paper: [Proximal Policy Optimization + Algorithms](https://huggingface.co/papers/1707.06347). + + Args: + args ([`experimental.ppo.PPOConfig`]): + Training arguments. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]): + Class to process the data. + model (`torch.nn.Module`): + Model to be trained. This is the policy model. + ref_model (`torch.nn.Module`, *optional*): + Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created. + reward_model (`torch.nn.Module`): + Reward model used to compute the rewards. + train_dataset ([`~datasets.Dataset`]): + Dataset for training. + value_model (`torch.nn.Module`): + Value model used to predict the value of a state. + data_collator ([`~transformers.DataCollatorWithPadding`], *optional*): + Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created + using the `processing_class`. + eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*): + Dataset for evaluation. + optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): + Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the + optimizer and the learning rate scheduler are created using the + [`~transformers.Trainer.create_optimizer_and_scheduler`] method. + callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): + Callbacks to use during training. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model` + will be wrapped with the specified PEFT adapter. + """ + + _tag_names = ["trl", "ppo"] + _name = "PPO" + _paper = { + "title": "Fine-Tuning Language Models from Human Preferences", + "id": "1909.08593", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{mziegler2019fine-tuning, + title = {{Fine-Tuning Language Models from Human Preferences}}, + author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, + year = 2019, + eprint = {arXiv:1909.08593} + }"""), + } + + def __init__( + self, + args: PPOConfig, + processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin, + model: nn.Module, + ref_model: nn.Module | None, + reward_model: nn.Module, + train_dataset: Dataset, + value_model: nn.Module, + data_collator: DataCollatorWithPadding | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + # less commonly used + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: list[TrainerCallback] | None = None, + peft_config: "PeftConfig | None" = None, + ) -> None: + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must make a copy of it, or `None` if you use peft." + ) + + self.args = args + self.processing_class = processing_class + self.policy_model = model + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + # Define the collator if not provided + if data_collator is None: + data_collator = DataCollatorWithPadding(self.processing_class) + + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." + ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int + + # Check that the kl estimator is valid + if self.args.kl_estimator not in {"k1", "k3"}: + raise ValueError( + "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, " + "appears to be a strictly better estimator). See " + "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." + ) + + # peft support + if not is_peft_available() and peft_config is not None: + raise ImportError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(self.policy_model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " + "merge and unload the existing adapter, save the resulting base model, and then pass that base " + "model along with the new `peft_config` to the trainer." + ) + + # get peft model with the given config + self.policy_model = get_peft_model(self.policy_model, peft_config) + if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(self.policy_model) + + self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel) + self.model_adapter_name = args.model_adapter_name + self.ref_adapter_name = args.ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model: + self.ref_model = None + else: + self.ref_model = create_reference_model(self.policy_model) + + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.value_model = value_model + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert args.local_mini_batch_size >= 8, ( + f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + ) + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size + + ######### + # setup model, optimizer, and others + ######### + for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: + if module is not None: + disable_dropout_in_model(module) + self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) + self.model.config = self.policy_model.config # needed for pushing to hub + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + # trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + # setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=self.data_collator, + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=self.data_collator, + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + + if self.ref_model is None: + if not self.is_peft_model: + raise ValueError("No reference model and model is not a Peft model.") + else: + self.ref_model = prepare_deepspeed( + self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + if self.ref_model is None: + if not self.is_peft_model: + raise ValueError("No reference model and model is not a Peft model.") + else: + self.ref_model = self.ref_model.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model.policy).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.policy.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.policy.set_adapter(self.model_adapter_name or "default") + + def save_model(self, output_dir: str | None = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.policy # save only the policy + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_model + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_kwargs = { + "max_new_tokens": args.response_length, + "temperature": (args.temperature + 1e-7), + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + } + generation_config = GenerationConfig(**generation_kwargs) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + values = [] + with ( + unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model + ): + query_responses, logitss = batch_generation( + unwrapped_model.policy, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) + + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + logprob = selective_log_softmax(logits, response) + del logits + empty_cache() + + if ref_policy is None: + with self.null_ref_context(): + ref_output = forward(model.policy, query_response, processing_class.pad_token_id) + else: + ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits + empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + self.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + unwrapped_value_model = accelerator.unwrap_model(model).value_model + full_value, _, _ = get_reward( + unwrapped_value_model, query_response, processing_class.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + del (logprob, ref_logprob, full_value, value, score, unwrapped_model) + empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + + # 4. compute rewards + # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators + logr = ref_logprobs - logprobs + kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3 + non_score_reward = -args.kl_coef * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[actual_start, actual_end] += scores + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_logprobs = selective_log_softmax(logits, mb_responses) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) + vf_clipfrac = masked_mean( + (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] + ) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + args.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + pg_clipfrac + ) + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + vf_clipfrac + ) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, + vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, + pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, + mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + empty_cache() + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + rlhf_reward = mean_non_score_reward + scores.mean() + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = ( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) + + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + logprobs, + ref_logprobs, + values, + sequence_lengths, + contain_eos_token, + sequence_lengths_p1, + response_idxs, + padding_mask, + padding_mask_p1, + rewards, + actual_start, + actual_end, + advantages, + returns, + ) + empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_kwargs = { + "max_new_tokens": args.response_length, + "temperature": (0.01 + 1e-7), + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + } + generation_config = GenerationConfig(**generation_kwargs) + + table = defaultdict(list) + with ( + unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model + ): + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model.policy, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + self.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + if is_rich_available(): + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/experimental/prm/__init__.py b/ICL/RL/trl_source/trl/experimental/prm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..210c474aa459261c71906d4e1b11994a08dddd4c --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/prm/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .prm_config import PRMConfig +from .prm_trainer import PRMTrainer + + +__all__ = ["PRMConfig", "PRMTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/prm/prm_config.py b/ICL/RL/trl_source/trl/experimental/prm/prm_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6c21c219669376b64fa7632e4ae581a9c03e200b --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/prm/prm_config.py @@ -0,0 +1,106 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from transformers import TrainingArguments + + +@dataclass +class PRMConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.prm.PRMTrainer`]. + + This class includes only the parameters that are specific to PRM training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) used for truncation. + max_completion_length (`int`, *optional*): + Maximum length of the completion used for truncation. The completion is the concatenation of the steps. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + step_separator (`str`, *optional*, defaults to `"\n"`): + Separator used to separate each step of the reasoning process. + train_on_last_step_only (`bool`, *optional*, defaults to `False`): + Whether to train only on the last step. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + """ + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-5, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."}, + ) + max_completion_length: int | None = field( + default=None, + metadata={ + "help": "Maximum length of the completion used for truncation. The completion is the concatenation of the " + "steps." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + step_separator: str = field( + default="\n", + metadata={"help": "Separator used to separate each step of the reasoning process."}, + ) + train_on_last_step_only: bool = field( + default=False, + metadata={"help": "Whether to train only on the last step."}, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/experimental/utils.py b/ICL/RL/trl_source/trl/experimental/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd0a0e2f0e3de3325087fcebf824583f9682fc9 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/utils.py @@ -0,0 +1,509 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file contains utility classes and functions that are used across more than one experimental trainer or feature. + +import inspect +from dataclasses import dataclass +from typing import Any + +import torch +from accelerate.utils import is_peft_model +from packaging.version import Version +from torch.nn.utils.rnn import pad_sequence +from transformers import PreTrainedModel, PreTrainedTokenizerBase, TrainingArguments +from transformers.utils import is_peft_available + +from ..models.utils import peft_module_casting_to_bf16 +from ..trainer.utils import pad + + +if is_peft_available(): + import peft + from peft import PeftConfig, PeftModel, get_peft_model + + +@dataclass +class DPODataCollatorWithPadding: + r""" + DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. + + Args: + pad_token_id (`int` defaults to 0): + The tokenizer's pad_token_id. + is_encoder_decoder (`bool` or `None`, `optional`, defaults to `None`): + Whether you model has an encoder_decoder architecture. + """ + + pad_token_id: int = 0 + is_encoder_decoder: bool | None = False + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + # first, pad everything to the same length + padded_batch = {} + for k in features[0].keys(): + if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")): + if self.is_encoder_decoder: + to_pad = [torch.LongTensor(ex[k]) for ex in features] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + if self.pad_token_id is None: + raise ValueError( + "Padding is enabled, but the tokenizer is not configured with a padding token." + " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" + " before calling the trainer." + ) + padding_value = self.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k): + padding_value = -100 + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + else: + # Set padding value based on the key + if k.endswith("_input_ids"): + if self.pad_token_id is None: + raise ValueError( + "Padding is enabled, but the tokenizer is not configured with a padding token." + " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" + " before calling the trainer." + ) + padding_value = self.pad_token_id + elif k.endswith("_labels"): + padding_value = -100 + elif k.endswith("_attention_mask"): + padding_value = 0 + elif k.endswith("_pixel_values"): + padding_value = 0 # TODO: check if this is correct + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + # Set padding side based on the key + if k in ["prompt_input_ids", "prompt_attention_mask"]: + padding_side = "left" + else: + padding_side = "right" + + # Set the dtype + if k.endswith("_pixel_values"): + dtype = torch.float32 # will be downcasted if necessary by the Trainer + else: + dtype = torch.int64 + + # Convert to tensor and pad + to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features] + padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side) + elif k.endswith("_logps"): + # the cached reference model logprobs + padded_batch[k] = torch.tensor([ex[k] for ex in features]) + else: + padded_batch[k] = [ex[k] for ex in features] + + return padded_batch + + +@dataclass +class DataCollatorForChatML: + """ + Data collator for ChatML format datasets. + """ + + tokenizer: PreTrainedTokenizerBase + ignore_index: int = -100 + max_length: int = None + prompt_key: str = "prompt" + messages_key: str = "messages" + + def __post_init__(self): + if self.tokenizer.pad_token_id is None: + raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.") + if self.max_length is None: + # set a sensible default + self.max_length = min(self.tokenizer.model_max_length, 1024) + + def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + input_ids = [] + attention_mask = [] + prompts_input_ids = [] + prompt_attention_mask = [] + labels = [] + + for example in examples: + formatted_prompt = example.get(self.prompt_key, None) + if formatted_prompt is None: + prompt = example[self.messages_key][:-1] + formatted_prompt = self.tokenizer.apply_chat_template( + prompt, add_generation_prompt=True, tokenize=False + ) + + if "input_ids" not in example: + message = example[self.messages_key] + formatted_message = self.tokenizer.apply_chat_template( + message, add_generation_prompt=False, tokenize=False + ) + + tokenized_message = self.tokenizer( + formatted_message, + truncation=False, + padding=False, + return_tensors=None, + add_special_tokens=False, + return_offsets_mapping=True, + ) + message_input_ids_full = tokenized_message["input_ids"] + offsets = tokenized_message.get("offset_mapping") + + if offsets is not None: + prompt_char_len = len(formatted_prompt) + completion_start_idx_full = next( + (idx for idx, (start, _) in enumerate(offsets) if start >= prompt_char_len), + len(message_input_ids_full), + ) + else: + tokenized_prompt_full = self.tokenizer( + formatted_prompt, + truncation=False, + padding=False, + return_tensors=None, + add_special_tokens=False, + ) + completion_start_idx_full = len(tokenized_prompt_full["input_ids"]) + + prompt_tokens_full = message_input_ids_full[:completion_start_idx_full] + completion_input_ids_full = message_input_ids_full[completion_start_idx_full:] + + if self.max_length is not None and len(message_input_ids_full) > self.max_length: + completion_ids = completion_input_ids_full + if len(completion_ids) >= self.max_length: + completion_ids = completion_ids[-self.max_length :] + prompt_ids = [] + else: + max_prompt_tokens = self.max_length - len(completion_ids) + prompt_ids = prompt_tokens_full[-max_prompt_tokens:] if max_prompt_tokens > 0 else [] + message_input_ids = prompt_ids + completion_ids + else: + message_input_ids = message_input_ids_full + prompt_ids = prompt_tokens_full + + input_ids.append(message_input_ids) + attention_mask.append([1] * len(message_input_ids)) + current_prompt_ids = prompt_ids + else: + message_input_ids = example["input_ids"] + input_ids.append(message_input_ids) + if "attention_mask" in example: + attention_mask.append(example["attention_mask"]) + else: + attention_mask.append([1] * len(message_input_ids)) + + tokenized_prompt = self.tokenizer( + formatted_prompt, + truncation=True, + max_length=len(message_input_ids), + padding=False, + return_tensors=None, + add_special_tokens=False, + ) + current_prompt_ids = tokenized_prompt["input_ids"] + + prompts_input_ids.append(current_prompt_ids) + prompt_attention_mask.append([1] * len(current_prompt_ids)) + + label = [self.ignore_index] * len(input_ids[-1]) + completion_start_idx = len(current_prompt_ids) + label[completion_start_idx:] = input_ids[-1][completion_start_idx:] + labels.append(label) + + # convert to list of tensors and pad + input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] + attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] + labels = [torch.tensor(label, dtype=torch.long) for label in labels] + input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) + attention_mask = pad(attention_mask, padding_side="left", padding_value=0) + labels = pad(labels, padding_side="left", padding_value=self.ignore_index) + + prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids] + prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask] + prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) + prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "prompts": prompts_input_ids, + "prompt_attention_mask": prompt_attention_mask, + } + + +def truncate_right( + input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Truncates the input tensor from the right side after the first occurrence of the stop token. + + Args: + input_ids (`torch.Tensor`): + The tensor containing the responses to be truncated + stop_token_id (`int`): + The token ID representing the stop token where truncation occurs + pad_token_id (`int`): + The token ID representing the pad token used to fill the truncated responses + + Returns: + tuple: + - `output_ids` (`torch.Tensor`): + The truncated responses tensor with pad tokens filled after the stop token + - `mask` (`torch.Tensor`): + The mask tensor to indicate the padding tokens + """ + trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1) + new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]] + idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size) + output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) + mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) + return output_ids, mask + + +SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + + +def add_bos_token_if_needed( + bos_token_id: int | None, + prompt_len_input_ids: int, + prompt_tokens: dict[str, list[int]], + chosen_prompt_len_input_ids: int, + chosen_tokens: dict[str, list[int]], + rejected_prompt_len_input_ids: int, + rejected_tokens: dict[str, list[int]], +): + if bos_token_id is not None: + if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]: + prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"] + prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] + if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]: + chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"] + chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] + if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]: + rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"] + rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] + return prompt_tokens, chosen_tokens, rejected_tokens + + +def add_eos_token_if_needed( + eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]] +): + if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]: + chosen_tokens["input_ids"].append(eos_token_id) + chosen_tokens["attention_mask"].append(1) + if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]: + rejected_tokens["input_ids"].append(eos_token_id) + rejected_tokens["attention_mask"].append(1) + return chosen_tokens, rejected_tokens + + +def first_true_indices(bools: torch.Tensor, dtype=torch.long) -> torch.Tensor: + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving the position of the + first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + + Args: + bools (`torch.Tensor`): + An N-dimensional boolean tensor. + dtype (`torch.dtype`, optional): + The desired data type of the output tensor. Defaults to `torch.long`. + + Returns: + `torch.Tensor`: + An (N-1)-dimensional tensor of integers indicating the position of the first True in each row. If no True + value is found in a row, returns the length of the row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def get_reward( + model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the reward logits and the rewards for a given model and query responses. + + Args: + model (`torch.nn.Module`): + The model used to compute the reward logits. + query_responses (`torch.Tensor`): + The tensor containing the query responses. + pad_token_id (`int`): + The token ID representing the pad token. + context_length (`int`): + The length of the context in the query responses. + + Returns: + tuple: + - `reward_logits` (`torch.Tensor`): + The logits for the reward model. + - `final_rewards` (`torch.Tensor`): + The final rewards for each query response. + - `sequence_lengths` (`torch.Tensor`): + The lengths of the sequences in the query responses. + """ + attention_mask = query_responses != pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + lm_backbone = getattr(model, model.base_model_prefix) + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + output = lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + use_cache=False, # otherwise mistral-based RM would error out + ) + reward_logits = model.score(output.hidden_states[-1]) + sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return ( + reward_logits, + reward_logits[ + torch.arange(reward_logits.size(0), device=reward_logits.device), + sequence_lengths, + ].squeeze(-1), + sequence_lengths, + ) + + +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): + r""" + Prepare a k-bit quantized transformers model for training (PEFT/QLoRA). + """ + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) + quant_methods = ["gptq", "aqlm", "eetq", "torchao", "hqq"] + is_quantized = getattr(model, "quantization_method", None) in quant_methods or getattr( + model, "hqq_quantized", False + ) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + for _, param in model.named_parameters(): + # freeze all parameters + param.requires_grad = False + + # Enable gradient checkpointing if needed + if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + # backward-compatible hook + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(model.gradient_checkpointing_enable).parameters + ) + gc_kwargs = {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} if supports_gc_kwargs else {} + model.gradient_checkpointing_enable(**gc_kwargs) + + return model + + +def enable_gradient_checkpointing( + model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None +) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + # Enable gradient checkpointing on the base model for PEFT + if is_peft_model(model): + model.base_model.gradient_checkpointing_enable() + # Enable gradient checkpointing for non-PEFT models + else: + model.gradient_checkpointing_enable() + + gradient_checkpointing_kwargs = gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ) + + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + return model + + +def prepare_peft_model( + model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments +) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + if not is_peft_available(): + raise ImportError("PEFT is required to use a peft model. Run `pip install peft`.") + + if isinstance(model, PeftModel) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge and " + "unload the existing adapter, save the resulting base model, and then pass that base model along with the " + "new `peft_config` to the trainer." + ) + + # Handle quantized models (QLoRA) + is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) + + is_sharded_qlora = False + if getattr(model, "is_loaded_in_4bit", False): + # Check if model is sharded (FSDP/DS-Zero3) + for _, param in model.named_parameters(): + if param.__class__.__name__ == "Params4bit": + is_sharded_qlora = param.data.device.type in {"cpu", "meta"} + break + + # Prepare model for kbit training if needed + if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel): + model = prepare_model_for_kbit_training( + model, + use_gradient_checkpointing=args.gradient_checkpointing, + gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {}, + ) + # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training + args.gradient_checkpointing = False + elif args.gradient_checkpointing: + model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs) + + # Create PEFT model + if peft_config is not None: + if ( + Version(peft.__version__) >= Version("0.12") # autocast_adapter_dtype introduced in 0.12 + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) + + # Handle bf16 casting for 4-bit models + if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: + peft_module_casting_to_bf16(model) + + return model diff --git a/ICL/RL/trl_source/trl/experimental/winrate_callback.py b/ICL/RL/trl_source/trl/experimental/winrate_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd0c44a05517fb07d3d2c4aef7b87ea378f4c27 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/winrate_callback.py @@ -0,0 +1,285 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import pandas as pd +from accelerate import Accelerator +from accelerate.utils import gather_object, is_wandb_available +from transformers import ( + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import log_table_to_comet_experiment + + +if is_wandb_available(): + import wandb + +# Logger for module-level logging +logger = logging.getLogger(__name__) + + +def _generate_completions( + prompts: list[str], + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + accelerator: Accelerator, + generation_config: GenerationConfig | None, + batch_size: int = 1, +) -> list[str]: + """ + Generates completions for a list of pre-formatted prompts from the given model. + + Args: + prompts (list[str]): A list of input prompts for which completions are to be generated. + model (PreTrainedModel): The pre-trained model to be used for generation. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding. + accelerator (Accelerator): The accelerator to be used for model execution. + generation_config (GenerationConfig): Configuration for text generation. + batch_size (int, optional): The number of prompts to process in each batch. Default is 1. + + Returns: + list[str]: A list of generated text completions corresponding to the input prompts. + """ + completions = [] + # TODO: Override model.generation_config with generation_kwargs + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + for idx in range(0, len(prompts), batch_size): + batch = prompts[idx : idx + batch_size] + tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) + generations = unwrapped_model.generate( + **tokenized_batch, + generation_config=generation_config, + ) + for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True): + # Remove prompt from generation + generation = generation[len(prompt) :] + completion = tokenizer.decode(generation, skip_special_tokens=True) + completions.append(completion) + return completions + + +def _win_rate_completions_df( + state: TrainerState, prompts: list[str], completions: list[str], winner_indices: list[str] +) -> pd.DataFrame: + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions, winner_indices, strict=True)) + # Split completions from reference model and policy + split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data] + return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]) + + +class WinRateCallback(TrainerCallback): + """ + A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference. + + It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against + a reference. The reference is either the initial version of the model (before training) or the reference model, if + available in the trainer. During each evaluation step, a judge determines how often the trained model's completions + win against the reference using a judge. The win rate is then logged in the trainer's logs under the key + `"eval_win_rate"`. + + Usage: + ```python + from trl import DPOTrainer + from trl.experimental.judges import PairRMJudge + from trl.experimental.winrate_callback import WinRateCallback + + trainer = DPOTrainer(...) + judge = PairRMJudge() + win_rate_callback = WinRateCallback(judge=judge, trainer=trainer) + trainer.add_callback(win_rate_callback) + ``` + + Args: + judge ([`experimental.judges.BasePairwiseJudge`]): + The judge to use for comparing completions. + trainer (`Trainer`): + Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` + column containing the prompts for generating completions. If the `Trainer` has a reference model (via the + `ref_model` attribute), it will use this reference model for generating the reference completions; + otherwise, it defaults to using the initial model. + generation_config ([`~transformers.GenerationConfig`], *optional*): + The generation config to use for generating completions. + num_prompts (`int`, *optional*): + The number of prompts to generate completions for. If not provided, defaults to the number of examples in + the evaluation dataset. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions before judging. + use_soft_judge (`bool`, *optional*, defaults to `False`): + Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the + second. + """ + + def __init__( + self, + judge, + trainer: Trainer, + generation_config: GenerationConfig | None = None, + num_prompts: int | None = None, + shuffle_order: bool = True, + use_soft_judge: bool = False, + ): + self.judge = judge + self.trainer = trainer + self.shuffle_order = shuffle_order + self.generation_config = generation_config + self.ref_completions = [] + self.use_soft_judge = use_soft_judge + + if self.trainer.eval_dataset is None: + raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.") + else: + self.eval_dataset = self.trainer.eval_dataset + + if num_prompts is not None: + self.eval_dataset = self.eval_dataset.select(range(num_prompts)) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # When the trainer is initialized, we generate completions for the reference model. + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + # Use the reference model if available, otherwise use the initial model + model = getattr(self.trainer, "ref_model", None) + # At this point, there are two cases where `ref_model` is None: + # 1. The method doesn't require a reference model. + # 2. The method uses a reference model, but `ref_model` is set to None. + # This occurs when using PEFT, where the reference model can be obtained by simply disabling the model's adapter. + # In theory, we should disable the adapter here, but since it's zero-initialized at the start of training, + # the model behaves identically with or without the adapter. + # Therefore, there's no need to explicitly disable it at this point. + if model is None: + model = self.trainer.model_wrapped + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + self.ref_completions = _generate_completions( + prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + # Compute initial win rate as a reference point + completions = list(zip(self.ref_completions, self.ref_completions, strict=True)) + if self.use_soft_judge: + ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True) + winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs] + ref_win_probs = gather_object(ref_win_probs) + else: + winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) + prompts = gather_object(prompts) + completions = gather_object(completions) + winner_indices = gather_object(winner_indices) + + # Logging + if self.trainer.accelerator.is_main_process: + win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) + if self.use_soft_judge: + avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs) + self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate}) + else: + self.trainer.log({"eval_win_rate": win_rate}) + + if "wandb" in args.report_to: + if wandb.run is not None: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + wandb.log({"win_rate_completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + log_table_to_comet_experiment( + name="win_rate_completions.csv", + table=df, + ) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # At every evaluation step, we generate completions for the model and compare them with the reference + # completions that have been generated at the beginning of training. We then compute the win rate and log it to + # the trainer. + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + model = self.trainer.model_wrapped + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + completions = _generate_completions( + prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + + completions = list(zip(self.ref_completions, completions, strict=True)) + + if self.use_soft_judge: + ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True) + winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs] + ref_win_probs = gather_object(ref_win_probs) + else: + winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) + prompts = gather_object(prompts) + completions = gather_object(completions) + winner_indices = gather_object(winner_indices) + + # Logging + if self.trainer.accelerator.is_main_process: + win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) + if self.use_soft_judge: + avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs) + self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate}) + else: + self.trainer.log({"eval_win_rate": win_rate}) + + if "wandb" in args.report_to: + if wandb.run is not None: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + wandb.log({"win_rate_completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + log_table_to_comet_experiment( + name="win_rate_completions.csv", + table=df, + ) diff --git a/ICL/RL/trl_source/trl/experimental/xpo/__init__.py b/ICL/RL/trl_source/trl/experimental/xpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc35384cd4246001ac14d1bd126430cce00b122 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/xpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .xpo_config import XPOConfig +from .xpo_trainer import XPOTrainer + + +__all__ = ["XPOConfig", "XPOTrainer"] diff --git a/ICL/RL/trl_source/trl/experimental/xpo/xpo_config.py b/ICL/RL/trl_source/trl/experimental/xpo/xpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..243545084617bdc8f384289cb76883d22fcfaf54 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/xpo/xpo_config.py @@ -0,0 +1,44 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from ..online_dpo import OnlineDPOConfig + + +@dataclass +class XPOConfig(OnlineDPOConfig): + r""" + Configuration class for the [`experimental.xpo.XPOTrainer`]. + + Subclass of [`experimental.online_dpo.OnlineDPOConfig`] we can use all its arguments and add the following: + + Parameters: + alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`): + Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch + and the last alpha is used for the rest of the epochs. + """ + + alpha: list[float] = field( + default_factory=lambda: [1e-5], + metadata={ + "help": "Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each " + "new epoch and the last alpha is used for the rest of the epochs." + }, + ) + + def __post_init__(self): + super().__post_init__() + if hasattr(self.alpha, "__len__") and len(self.alpha) == 1: + self.alpha = self.alpha[0] diff --git a/ICL/RL/trl_source/trl/experimental/xpo/xpo_trainer.py b/ICL/RL/trl_source/trl/experimental/xpo/xpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..04e184b804b9df29b8560d69fa16a172580f637a --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/xpo/xpo_trainer.py @@ -0,0 +1,545 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap +from collections.abc import Callable +from typing import Any + +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import EvalPrediction +from transformers.training_args import OptimizerNames +from transformers.utils import is_peft_available + +from ...data_utils import is_conversational, maybe_apply_chat_template +from ...models.utils import unwrap_model_for_generation +from ...trainer.utils import empty_cache, selective_log_softmax +from ..judges import BasePairwiseJudge +from ..online_dpo import OnlineDPOTrainer +from ..utils import SIMPLE_CHAT_TEMPLATE, get_reward, truncate_right +from .xpo_config import XPOConfig + + +if is_peft_available(): + from peft import PeftModel + + +class XPOTrainer(OnlineDPOTrainer): + """ + Trainer for Exploratory Preference Optimization (XPO). + + It is implemented as a subclass of [`experimental.online_dpo.OnlineDPOTrainer`]. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model ([`~transformers.PreTrainedModel`]): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation + and loss. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized. + reward_funcs ([`~transformers.PreTrainedModel`]): + The reward model to score completions with, preferably an + [`~transformers.AutoModelForSequenceClassification`]. + judge ([`experimental.judges.BasePairwiseJudge`]): + The judge to use for pairwise comparison of model completions. + args ([`experimental.xpo.XPOConfig`]): + The XPO config arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "xpo"] + _name = "XPO" + _paper = { + "title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF", + "id": "2405.21046", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{jung2024binary, + title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}}, + author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin}, + year = 2024, + eprint = {arXiv:2405.21046} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module = None, + ref_model: PreTrainedModel | nn.Module = None, + reward_funcs: nn.Module | None = None, + judge: BasePairwiseJudge | None = None, + args: XPOConfig | None = None, + data_collator: Callable | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + ) -> None: + super().__init__( + model=model, + ref_model=ref_model, + judge=judge, + reward_funcs=reward_funcs, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + peft_config=peft_config, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self._alpha = self.args.alpha + + # Overwrite the stats dictionary to include XPO specific statistics + self.stats = { + # Remove "non_score_reward", "rlhf_reward", "scores" + # Add "loss/dpo", "loss/xpo" + "loss/dpo": [], + "loss/xpo": [], + "objective/kl": [], + "objective/entropy": [], + "rewards/chosen": [], + "rewards/rejected": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token" + "val/model_contain_eos_token": [], + "val/ref_contain_eos_token": [], + "alpha": [], + "beta": [], + } + if self.reward_funcs is not None: + if len(self.reward_funcs) != 1: + raise ValueError("XPOTrainer only supports one reward function/model.") + self.reward_funcs = self.reward_funcs[0] + self.stats["objective/model_scores"] = [] + self.stats["objective/ref_scores"] = [] + self.stats["objective/scores_margin"] = [] + + @property + def alpha(self): + if isinstance(self._alpha, list): + epoch = self.state.epoch + return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1] + else: + return self._alpha + + def _generate_completions(self, prompts, model): + with ( + unwrap_model_for_generation( + model, + self.accelerator, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_policy_model_for_gen, + ): + model_output = unwrapped_policy_model_for_gen.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + actual_model_for_ref_generation: torch.nn.Module + if self.ref_model is None: + unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model) + + if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel): + actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model() + else: + actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic + else: + actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model) + + with ( + unwrap_model_for_generation( + actual_model_for_ref_generation, + self.accelerator, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as final_ref_model_for_gen, + ): + ref_output = final_ref_model_for_gen.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + return model_output, ref_output + + def _process_completions(self, model_output, ref_output, prompts): + context_length = prompts["input_ids"].shape[1] + + # Process model completions + model_completion_ids = model_output[:, context_length:] + model_completion_ids, model_completion_mask = truncate_right( + model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + model_data = { + "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), + "raw": prompts["raw"], + } + + # Process reference model completions + ref_completion_ids = ref_output[:, context_length:] + ref_completion_ids, ref_completion_mask = truncate_right( + ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + ref_data = { + "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1), + "raw": prompts["raw"], + } + + return model_data, ref_data + + def _compute_rewards(self, model_data, ref_data, context_length): + with torch.no_grad(): + _, model_scores, _ = get_reward( + self.reward_funcs, model_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + _, ref_scores, _ = get_reward( + self.reward_funcs, ref_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + + # Apply EOS penalty if needed + if self.args.missing_eos_penalty is not None: + model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + model_scores[~model_contain_eos] -= self.args.missing_eos_penalty + ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty + + return model_scores, ref_scores + + def _compute_judge(self, model_data, ref_data, context_length): + prompts = model_data["raw"] + model_data_completions = self.processing_class.batch_decode( + model_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + model_data_completions = [completion.strip() for completion in model_data_completions] + + ref_data_completions = self.processing_class.batch_decode( + ref_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + ref_data_completions = [completion.strip() for completion in ref_data_completions] + + if is_conversational({"prompt": prompts[0]}): + model_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in model_data_completions + ] + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=message) for message in prompts] + model_data_completions = [template.render(messages=completion) for completion in model_data_completions] + + ref_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in ref_data_completions + ] + ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions] + + ranks_of_first_completion = self.judge.judge( + prompts, + list(zip(model_data_completions, ref_data_completions, strict=True)), + ) + # convert ranks to a True/False mask: + # when rank == 0, it means the first completion is the best + # when rank == 1, it means the second completion is the best + return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device) + + def _compute_logprobs(self, model, model_data, ref_data, context_length): + def compute_logprobs_for_data(m, data): + output = m(data["input_ids"], attention_mask=data["attention_mask"]) + logits = output.logits[:, context_length - 1 : -1] + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) + return token_logprobs + + # Compute logprobs for model completions + model_logprobs_model_data = compute_logprobs_for_data(model, model_data) + # Compute logprobs for model on reference completions (for XPO loss) + model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) + + # Compute logprobs for reference model completions + with torch.no_grad(): + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data) + + # Mask padding tokens + model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 + ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0 + model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0) + ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0) + ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + + return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data + + def _compute_losses( + self, + model_logprobs_model_data, + model_logprobs_ref_data, + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + ): + # Compute log probs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1) + ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs + + rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs + + # Compute logits as the difference between chosen and rejected log ratios + logits = chosen_log_ratios - rejected_log_ratios + + if self.args.loss_type == "sigmoid": + dpo_losses = -F.logsigmoid(self.beta * logits) + elif self.args.loss_type == "ipo": + dpo_losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.args.loss_type}") + + # Compute XPO specific loss + xpo_losses = self.alpha * model_logprobs_ref_data_sum + + # Total loss + loss = (dpo_losses + xpo_losses).mean() + + return loss, dpo_losses, xpo_losses + + def _log_statistics( + self, + model_data, + ref_data, + model_logprobs_model_data, + model_logprobs_ref_data, + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + dpo_losses, + xpo_losses, + context_length, + model_scores=None, + ref_scores=None, + ): + # Helper function to gather and compute mean + def gather_mean(tensor): + return self.accelerator.gather_for_metrics(tensor).mean().item() + + # Log losses + self.stats["loss/dpo"].append(gather_mean(dpo_losses)) + self.stats["loss/xpo"].append(gather_mean(xpo_losses)) + + # Log scores + if self.reward_funcs is not None: + self.stats["objective/model_scores"].append(gather_mean(model_scores)) + self.stats["objective/ref_scores"].append(gather_mean(ref_scores)) + self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores)) + + # Log logprobs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1) + ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs + + rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs + + self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean())) + self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean())) + + # Log rewards + # Compute various statistics + chosen_rewards = chosen_log_ratios * self.beta + rejected_rewards = rejected_log_ratios * self.beta + self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean())) + self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean())) + + # Calculate KL divergence for model and ref data + kl_model_data = model_logprobs_model_data - ref_logprobs_model_data + kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data + mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2 + self.stats["objective/kl"].append(gather_mean(mean_kl)) + + # Calculate entropy for model and ref data + entropy_model_data = -model_logprobs_model_data.sum(1) + entropy_ref_data = -model_logprobs_ref_data.sum(1) + mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2 + self.stats["objective/entropy"].append(gather_mean(mean_entropy)) + + # Calculate margins + margin = chosen_rewards - rejected_rewards + self.stats["rewards/margins"].append(gather_mean(margin.mean())) + + # Calculate accuracy + accuracy = (margin > 0).float() + self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean())) + + # Log EOS token statistics + model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) + self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float())) + + # Log alpha and beta + self.stats["alpha"].append(self.alpha) + self.stats["beta"].append(self.beta) + + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + model.train() + + # Apply chat template and tokenize the input + batch_size = len(next(iter(inputs.values()))) + prompts = inputs["prompt"] + inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] + inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] + inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] + inputs = self.data_collator(inputs) + + # need the prompt_ only + inputs = self._prepare_inputs(inputs) + context_length = inputs["prompt_input_ids"].shape[1] + prompts = { + "input_ids": inputs["prompt_input_ids"], + "attention_mask": inputs["prompt_attention_mask"], + "raw": prompts, + } + del inputs + + # Sample completions from both the model and the reference model + model_output, ref_output = self._generate_completions(prompts, model) + + # Process model completions + model_data, ref_data = self._process_completions(model_output, ref_output, prompts) + + # Compute rewards + if self.reward_funcs is not None: + model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length) + chosen_mask = model_scores >= ref_scores + else: + model_scores, ref_scores = None, None + chosen_mask = self._compute_judge(model_data, ref_data, context_length) + + # Compute logprobs + model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = ( + self._compute_logprobs(model, model_data, ref_data, context_length) + ) + + # Compute loss + loss, dpo_losses, xpo_losses = self._compute_losses( + model_logprobs_model_data, + model_logprobs_ref_data, + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + ) + + # Log everything + self._log_statistics( + model_data, + ref_data, + model_logprobs_model_data.detach(), + model_logprobs_ref_data.detach(), + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + dpo_losses.detach(), + xpo_losses.detach(), + context_length, + model_scores, + ref_scores, + ) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps diff --git a/ICL/RL/trl_source/trl/extras/__init__.py b/ICL/RL/trl_source/trl/extras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2777dd0eb21a0ea67dce8775337a27ea8499da2 --- /dev/null +++ b/ICL/RL/trl_source/trl/extras/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/RL/trl_source/trl/extras/__pycache__/__init__.cpython-313.pyc b/ICL/RL/trl_source/trl/extras/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deaa9ab9d9f36f3a5cae6cdd99b033af05b4daaf Binary files /dev/null and b/ICL/RL/trl_source/trl/extras/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/extras/dataset_formatting.py b/ICL/RL/trl_source/trl/extras/dataset_formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd89493a33263bda1506f67f661947ac69a3690 --- /dev/null +++ b/ICL/RL/trl_source/trl/extras/dataset_formatting.py @@ -0,0 +1,32 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datasets +from datasets import Value +from packaging.version import Version + + +if Version(datasets.__version__) >= Version("4.0.0"): + from datasets import List + + FORMAT_MAPPING = { + "chatml": List({"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}), + "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, + } +else: + FORMAT_MAPPING = { + "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], + "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, + } diff --git a/ICL/RL/trl_source/trl/extras/profiling.py b/ICL/RL/trl_source/trl/extras/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..e051935c86da8babde4d3ed15c827ecc80665996 --- /dev/null +++ b/ICL/RL/trl_source/trl/extras/profiling.py @@ -0,0 +1,217 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import time +from collections.abc import Callable + +from transformers import Trainer +from transformers.integrations import is_mlflow_available, is_wandb_available + + +if is_wandb_available(): + import wandb + +if is_mlflow_available(): + import mlflow + + +class ProfilingContext: + """ + Context manager for profiling code blocks with configurable logging. + + This class handles timing of code execution and logging metrics to various backends (Weights & Biases, MLflow) + without being coupled to the Trainer class. + + Args: + name (`str`): + Name of the profiling context. Used in the metric name. + report_to (`list` of `str`): + List of integrations to report metrics to (e.g., ["wandb", "mlflow"]). + is_main_process (`bool`, *optional*, defaults to `True`): + Whether this is the main process in distributed training. Metrics are only logged from the main process. + step (`int` or `None`, *optional*): + Training step to associate with the logged metrics. + metric_prefix (`str`, *optional*, defaults to `"profiling/Time taken"`): + Prefix for the metric name in logs. + + Example: + ```python + # Direct usage + from trl.extras.profiling import ProfilingContext + + with ProfilingContext( + name="MyClass.expensive_operation", + report_to=["wandb"], + is_main_process=True, + step=100, + ): + # Code to profile + result = expensive_computation() + + # With Trainer (backwards compatible via profiling_context function) + from transformers import Trainer + from trl.extras.profiling import profiling_context + + + class MyTrainer(Trainer): + def some_method(self): + with profiling_context(self, "matrix_multiplication"): + result = matrix_multiply() + ``` + """ + + def __init__( + self, + name: str, + report_to: list[str], + is_main_process: bool = True, + step: int | None = None, + metric_prefix: str = "profiling/Time taken", + ): + self.name = name + self.report_to = report_to + self.is_main_process = is_main_process + self.step = step + self.metric_prefix = metric_prefix + self._start_time = None + + def __enter__(self): + """Start timing when entering the context.""" + self._start_time = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop timing and log metrics when exiting the context.""" + if self._start_time is not None: + duration = time.perf_counter() - self._start_time + self._log_metrics(duration) + return False + + def _log_metrics(self, duration: float) -> None: + """ + Log profiling metrics to configured backends. + + Args: + duration (`float`): + Execution time in seconds. + """ + if not self.is_main_process: + return + + metric_name = f"{self.metric_prefix}: {self.name}" + metrics = {metric_name: duration} + + # Log to Weights & Biases if configured + if "wandb" in self.report_to and is_wandb_available() and wandb.run is not None: + wandb.log(metrics) + + # Log to MLflow if configured + if "mlflow" in self.report_to and is_mlflow_available() and mlflow.active_run() is not None: + mlflow.log_metrics(metrics, step=self.step) + + +def profiling_context(trainer: Trainer, name: str) -> ProfilingContext: + """ + Factory function to create a ProfilingContext from a Trainer instance. + + This function maintains backwards compatibility with existing code while using the decoupled ProfilingContext class + internally. + + Args: + trainer (`~transformers.Trainer`): + Trainer object containing configuration for logging. + name (`str`): + Name of the block to be profiled. Will be prefixed with the trainer class name. + + Returns: + `ProfilingContext`: A configured profiling context manager. + + Example: + ```python + from transformers import Trainer + from trl.extras.profiling import profiling_context + + + class MyTrainer(Trainer): + def some_method(self): + A = np.random.rand(1000, 1000) + B = np.random.rand(1000, 1000) + with profiling_context(self, "matrix_multiplication"): + # Code to profile: simulate a computationally expensive operation + result = A @ B # Matrix multiplication + ``` + """ + context_name = f"{trainer.__class__.__name__}.{name}" + step = trainer.state.global_step + + return ProfilingContext( + name=context_name, + report_to=trainer.args.report_to, + is_main_process=trainer.accelerator.is_main_process, + step=step, + ) + + +def profiling_decorator(func: Callable) -> Callable: + """ + Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. + + This decorator works with methods that have access to a trainer instance (typically as `self`). For non-Trainer + objects that have an `accelerator` attribute, it will use that for logging configuration. + + Args: + func (`Callable`): + Function to be profiled. + + Returns: + `Callable`: Wrapped function that profiles execution time. + + Example: + ```python + from transformers import Trainer + from trl.extras.profiling import profiling_decorator + + + class MyTrainer(Trainer): + @profiling_decorator + def some_method(self): + A = np.random.rand(1000, 1000) + B = np.random.rand(1000, 1000) + # Code to profile: simulate a computationally expensive operation + result = A @ B + ``` + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + # Check if self is a Trainer-like object with required attributes + if hasattr(self, "state") and hasattr(self, "args"): + with profiling_context(self, func.__name__): + return func(self, *args, **kwargs) + # For non-Trainer objects (e.g., VLLMGeneration), use ProfilingContext directly + elif hasattr(self, "accelerator"): + context_name = f"{self.__class__.__name__}.{func.__name__}" + with ProfilingContext( + name=context_name, + report_to=[], # No reporting for non-Trainer objects without args + is_main_process=self.accelerator.is_main_process, + step=None, + ): + return func(self, *args, **kwargs) + else: + # No profiling available, just run the function + return func(self, *args, **kwargs) + + return wrapper diff --git a/ICL/RL/trl_source/trl/generation/__init__.py b/ICL/RL/trl_source/trl/generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22e7cf6d88463c5b6774725e2b66ee1496bddd22 --- /dev/null +++ b/ICL/RL/trl_source/trl/generation/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generation backends for TRL trainers.""" + +from ..import_utils import is_vllm_available + + +__all__ = [] + +if is_vllm_available(): + from .vllm_generation import VLLMGeneration + + __all__.append("VLLMGeneration") diff --git a/ICL/RL/trl_source/trl/generation/vllm_client.py b/ICL/RL/trl_source/trl/generation/vllm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..1919345eaf7e151af3f802326a79d7403a136100 --- /dev/null +++ b/ICL/RL/trl_source/trl/generation/vllm_client.py @@ -0,0 +1,628 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import base64 +import copy +import logging +import socket +import time +from io import BytesIO +from urllib.parse import urlparse + +import torch +import torch.distributed.distributed_c10d as c10d +from requests.adapters import HTTPAdapter +from torch import nn +from transformers import is_torch_xpu_available +from urllib3.util.retry import Retry + +from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available + + +if is_requests_available(): + import requests + from requests import ConnectionError + + +if is_vllm_available(): + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + if is_vllm_ascend_available(): + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator + + +logger = logging.getLogger(__name__) + + +def pil_to_base64(image): + buffer = BytesIO() + image.save(buffer, format="PNG") + img_bytes = buffer.getvalue() + return base64.b64encode(img_bytes).decode("utf-8") + + +class VLLMClient: + """ + A client class to interact with a vLLM server. + + This class provides methods to generate completions, initialize and manage weight update groups, and update model + weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`. + + Args: + base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are + ignored. + host (`str`, *optional*, defaults to `"0.0.0.0"`): + IP address of the vLLM server. Ignored if `base_url` is provided. + server_port (`int`, *optional*, defaults to `8000`): + Port number of the vLLM server. Ignored if `base_url` is provided. + group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. + connection_timeout (`float`, *optional*, defaults to `0.0`): + Total timeout duration in seconds to wait for the server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + + Examples: + Run the vLLM server with the model `Qwen/Qwen2.5-7B`: + + ``` + $ trl vllm-serve --model Qwen/Qwen2.5-7B + ... + INFO: Application startup complete. + INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) + ``` + + Use the client to generate completions and update model weights: + + ```python + >>> from trl.generation.vllm_client import VLLMClient + + >>> client = VLLMClient() + >>> client.generate(["Hello, AI!", "Tell me a joke"]) + {'prompt_ids': [[9707, 11, 15235, 0], + [40451, 752, 264, 21646]], + 'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733], + [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]], + 'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963], + [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]} + + >>> from transformers import AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda") + >>> client.init_communicator(device="cuda") + >>> client.update_model_params(model) + ``` + + There are several ways to initialize the client: + + ```python + VLLMClient(base_url="http://localhost:8000") + VLLMClient(base_url="http://192.168.1.100:8000") + VLLMClient(host="localhost", server_port=8000) + VLLMClient(host="192.168.1.100", server_port=8000) + ``` + """ + + def __init__( + self, + base_url: str | None = None, + host: str = "0.0.0.0", + server_port: int = 8000, + group_port: int = 51216, + connection_timeout: float = 0.0, + ): + if not is_requests_available(): + raise ImportError("requests is not installed. Please install it with `pip install requests`.") + if not is_vllm_available(): + raise ImportError("vLLM is not installed. Please install it with `pip install trl[vllm]`.") + + self.session = requests.Session() + + # Configure retries for HTTP requests made through this session. + # This is not strictly required for correctness, but it helps make training more robust to rare, transient + # failures (network hiccups, temporary 5xx errors, overloaded servers). Without this, such failures could cause + # an otherwise healthy training run to fail. + retry_strategy = Retry( + total=5, # global cap on the total number of retries across all failure types + connect=5, # retry connection-level failures (DNS issues, refused connections, etc) + read=5, # retry failures while reading the response after the connection was successfully established + status=3, # retry a limited number of times when we receive certain HTTP error responses from the server + status_forcelist=[500, 502, 503], # only retry on server-side errors that are usually temporary + backoff_factor=2, # exponential backoff between retries (2s, 4s, 8s, ...) + allowed_methods=["POST", "GET"], # allow POST as well, even though we're not sure it's safe here + ) + + adapter = HTTPAdapter(max_retries=retry_strategy) + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + + if base_url is not None: + # Parse the base_url to extract host and port + parsed_url = urlparse(base_url) + self.host = socket.gethostbyname(parsed_url.hostname) + scheme = parsed_url.scheme or "http" + self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}" + else: + self.host = host + self.server_port = server_port + self.base_url = f"http://{self.host}:{self.server_port}" + self.group_port = group_port + self.check_server(connection_timeout) # check server and fail after timeout + + def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): + """ + Check server availability with retries on failure, within a total timeout duration. If the server is not up + after the total timeout duration, raise a `ConnectionError`. + + Args: + retry_interval (`float`, *optional*, defaults to `2.0`): + Interval in seconds between retries. + total_timeout (`float`, *optional*, defaults to `0.0`): + Total timeout duration in seconds. + """ + url = f"{self.base_url}/health/" + start_time = time.time() # Record the start time + + while True: + try: + response = requests.get(url) + except requests.exceptions.RequestException as exc: + # Check if the total timeout duration has passed + elapsed_time = time.time() - start_time + if elapsed_time >= total_timeout: + raise ConnectionError( + f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make " + "sure the server is running by running `trl vllm-serve`." + ) from exc + else: + if response.status_code == 200: + if "X-Forwarded-For" in response.headers: + self.host = response.headers["X-Forwarded-For"] + logger.info("Server is up!") + return None + + # Retry logic: wait before trying again + logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...") + time.sleep(retry_interval) + + def generate( + self, + prompts: list[str], + images: list | None = None, + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + max_tokens: int = 16, + truncate_prompt_tokens: int | None = None, + structured_outputs_regex: str | None = None, + generation_kwargs: dict | None = None, + ) -> dict[str, list[list[int]]]: + """ + Generates model completions for the provided prompts. + + Args: + prompts (`list[str]`): + List of text prompts for which the model will generate completions. + images (`list[PIL.Image]`, *optional*): + List of PIL Images to send along with the prompts. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate for each prompt. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. 1.0 means no penalty. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature parameter for sampling. Higher values increase diversity. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter.`1.0` means no truncation. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. `0` means no truncation. + min_p (`float`, *optional*, defaults to `0.0`): + Minimum probability for sampling. + max_tokens (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each prompt. + truncate_prompt_tokens (`int`, *optional*): + If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use + only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is + disabled. + structured_outputs_regex (`str`, *optional*): + Regular expression to guide the decoding process. + generation_kwargs (`dict`, *optional*): + Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like + `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they + will override them. + + Returns: + `dict` with keys: + - `prompt_ids` (`list[list[int]]`): + List of lists of token IDs representing the tokenized input prompts. + - `completion_ids` (`list[list[int]]`): + List of lists of token IDs representing the model-generated completions for each prompt. + - `logprobs` (`list[list[float]]`): + List of lists of log probabilities for each generated token. + """ + url = f"{self.base_url}/generate/" + + # Convert PIL images to base64 strings + images = [pil_to_base64(img) for img in images] if images else None + + response = self.session.post( + url, + json={ + "prompts": prompts, + "images": images, + "n": n, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "max_tokens": max_tokens, + "truncate_prompt_tokens": truncate_prompt_tokens, + "structured_outputs_regex": structured_outputs_regex, + "generation_kwargs": generation_kwargs or {}, + }, + ) + if response.status_code == 200: + json_response = response.json() + return { + "prompt_ids": json_response["prompt_ids"], + "completion_ids": json_response["completion_ids"], + "logprobs": json_response["logprobs"], + } + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def chat( + self, + messages: list[list[dict]], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + max_tokens: int = 16, + truncate_prompt_tokens: int | None = None, + structured_outputs_regex: str | None = None, + generation_kwargs: dict | None = None, + chat_template_kwargs: dict | None = None, + tools: list | None = None, + chat_template: str | None = None, + ) -> dict[str, list[list[int]]]: + """ + Generates model completions for the provided chat messages. + + Args: + messages (`list[list[dict]]`): + List of message lists for which the model will generate completions. Each message is a dictionary with + keys like "role" and "content". + n (`int`, *optional*, defaults to `1`): + Number of completions to generate for each message list. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. 1.0 means no penalty. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature parameter for sampling. Higher values increase diversity. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter.`1.0` means no truncation. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. `0` means no truncation. + min_p (`float`, *optional*, defaults to `0.0`): + Minimum probability for sampling. + max_tokens (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each message list. + truncate_prompt_tokens (`int`, *optional*): + If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use + only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is + disabled. + structured_outputs_regex (`str`, *optional*): + Regular expression to guide the decoding process. + generation_kwargs (`dict`, *optional*): + Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like + `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they + will override them. + chat_template_kwargs (`dict`, *optional*): + Additional keyword arguments to customize the chat template used by the model. + tools (`list`, *optional*): + List of tool functions available for tool calling during chat generation. + chat_template (`str`, *optional*): + Template to use for structuring the chat. If not provided, the model's default chat template will be + used. + + Returns: + `dict` with keys: + - `prompt_ids` (`list[list[int]]`): + List of lists of token IDs representing the tokenized input messages. + - `completion_ids` (`list[list[int]]`): + List of lists of token IDs representing the model-generated completions for each message list. + - `logprobs` (`list[list[float]]`): + List of lists of log probabilities for each generated token. + """ + if tools is not None: + raise NotImplementedError("Tool calling is not yet implemented in VLLMClient.chat().") + if chat_template is not None: + raise NotImplementedError("Custom chat templates are not yet implemented in VLLMClient.chat().") + + url = f"{self.base_url}/chat/" + + # Convert PIL images to base64 strings + messages = copy.deepcopy(messages) # avoid modifying the original messages + for message_list in messages: + for message in message_list: + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image_pil": + part["image_pil"] = pil_to_base64(part["image_pil"]) + + response = self.session.post( + url, + json={ + "messages": messages, + "n": n, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "max_tokens": max_tokens, + "truncate_prompt_tokens": truncate_prompt_tokens, + "structured_outputs_regex": structured_outputs_regex, + "generation_kwargs": generation_kwargs or {}, + "chat_template_kwargs": chat_template_kwargs or {}, + }, + ) + if response.status_code == 200: + json_response = response.json() + return { + "prompt_ids": json_response["prompt_ids"], + "completion_ids": json_response["completion_ids"], + "logprobs": json_response["logprobs"], + } + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def init_communicator(self, device: torch.device | str | int = 0): + """ + Initializes the weight update group in a distributed setup for model synchronization. + + Args: + device (`torch.device`, `str`, or `int`, *optional*, defaults to `0`): + Device of trainer main process. It's the device that will be used for the weights synchronization. Can + be a `torch.device` object, a string like `'cuda:0'`, or an integer device index. + """ + # Get the world size from the server + url = f"{self.base_url}/get_world_size/" + response = requests.get(url) + if response.status_code == 200: + vllm_world_size = response.json()["world_size"] + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + world_size = vllm_world_size + 1 # add the client to the world + self.rank = vllm_world_size # the client's rank is the last process + + # Initialize weight update group + url = f"{self.base_url}/init_communicator/" + # Will simplify it after torch xpu 2.9 support get uuid. + if is_torch_xpu_available(): + if hasattr(torch.xpu.get_device_properties(device), "uuid"): + client_device_uuid = str(torch.xpu.get_device_properties(device).uuid) + else: + client_device_uuid = "42" + else: + client_device_uuid = str(torch.cuda.get_device_properties(device).uuid) + + # Set the weight update group's host to "0.0.0.0" so that + # clients from different IPs can send updated weights + response = self.session.post( + url, + json={ + "host": "0.0.0.0", + "port": self.group_port, + "world_size": world_size, + "client_device_uuid": client_device_uuid, + }, + ) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + # Brief delay to allow server initialization. While not strictly required (client socket will retry on + # connection failure), this prevents log warnings like: + # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 + time.sleep(0.1) + + # Set up the communication group for weight broadcasting + if is_torch_xpu_available(): + store = torch.distributed.TCPStore( + host_name=self.host, port=self.group_port, world_size=world_size, is_master=(self.rank == 0) + ) + prefixed_store = c10d.PrefixStore("client2server", store) + xccl_options = c10d.ProcessGroupXCCL.Options() + pg = c10d.ProcessGroupXCCL( + store=prefixed_store, + rank=self.rank, + size=world_size, + options=xccl_options, + ) + self.communicator = pg + else: + pg = StatelessProcessGroup.create( + host=self.host, port=self.group_port, rank=self.rank, world_size=world_size + ) + self.communicator = PyNcclCommunicator(pg, device=device) + + # When the client object is deleted, close the weight update group + atexit.register(self.close_communicator) + + def update_named_param(self, name: str, weights: torch.Tensor): + """ + Updates a specific named parameter in the model and broadcasts it to other processes. + + Args: + name (`str`): + Name of the layer whose weights are being updated. + weights (`torch.Tensor`): + Tensor containing the updated weights. + """ + dtype, shape = str(weights.dtype), tuple(weights.shape) + url = f"{self.base_url}/update_named_param/" + response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape}) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + if is_torch_xpu_available(): + # Use XCCL to broadcast the updated weights from the client (src) to all workers. + self.communicator.broadcast(weights, root=self.rank) + self.communicator.barrier() + else: + # Use NCCL to broadcast the updated weights from the client (src) to all workers. + self.communicator.broadcast(weights, src=self.rank) + self.communicator.group.barrier() + + def update_model_params(self, model: nn.Module): + """ + Updates all parameters of the given model by calling `update_named_param` for each parameter in the model. + + Args: + model (`nn.Module`): + Model whose parameters (weights/biases) are to be updated. + """ + for name, param in model.named_parameters(): + # Update each parameter individually + self.update_named_param(name, param.data) + + def reset_prefix_cache(self): + """ + Resets the prefix cache for the model. + """ + url = f"{self.base_url}/reset_prefix_cache/" + response = self.session.post(url) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def chat_completions( + self, + messages: list[dict], + model: str | None = None, + temperature: float = 1.0, + top_p: float = 1.0, + max_tokens: int | None = None, + n: int = 1, + tools: list[dict] | None = None, + **kwargs, + ) -> dict: + """ + OpenAI-compatible chat completions endpoint. + + Args: + messages (`list[dict]`): + List of messages in OpenAI format with "role" and "content" keys. + model (`str`, *optional*): + Model name to use. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for sampling. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. + max_tokens (`int`, *optional*): + Maximum number of tokens to generate. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate. + tools (`list[dict]`, *optional*): + List of tool definitions for tool calling. + **kwargs: + Additional parameters to pass to the endpoint. + + Returns: + `dict`: + OpenAI-compatible response with "choices", "usage", etc. + """ + url = f"{self.base_url}/v1/chat/completions" + response = self.session.post( + url, + json={ + "messages": messages, + "model": model, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "n": n, + "tools": tools, + **kwargs, + }, + ) + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def tokenize(self, messages: list[dict], tools: list[dict] | None = None) -> dict: + """ + Tokenize messages to get token IDs. + + Args: + messages (`list[dict]`): + List of messages to tokenize. + tools (`list[dict]`, *optional*): + List of tool definitions. + + Returns: + `dict`: + Dictionary with "tokens" (list of token IDs) and "model" keys. + """ + url = f"{self.base_url}/tokenize" + response = self.session.post(url, json={"messages": messages, "tools": tools}) + if response.status_code == 200: + return response.json() + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def close_communicator(self): + """ + Closes the weight update group and cleans up the communication group. + """ + url = f"{self.base_url}/close_communicator/" + + try: + response = self.session.post(url) + except ConnectionError: + # The server might be already down, so we don't need to close the communicator + pass + else: + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + if self.communicator is not None: + self.communicator = None + + +# Example usage +if __name__ == "__main__": + from vllm import SamplingParams + + device = "xpu" if is_torch_xpu_available() else "cuda" + client = VLLMClient() + client.init_communicator(device=device) + + # Generate completions + responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams()) + print("Responses:", responses) # noqa + + # Update model weights + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to(device) + client.update_model_params(model) diff --git a/ICL/RL/trl_source/trl/generation/vllm_generation.py b/ICL/RL/trl_source/trl/generation/vllm_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..3956a70aebfdf14a187e3268a5bee90cb34234ed --- /dev/null +++ b/ICL/RL/trl_source/trl/generation/vllm_generation.py @@ -0,0 +1,706 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""vLLM-based generation backend for TRL trainers.""" + +import json +import logging +import math +import os +from collections.abc import Callable +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from accelerate.utils import broadcast_object_list, gather_object, is_peft_model +from packaging.version import Version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, is_bitsandbytes_available + +from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_vllm +from ..extras.profiling import ProfilingContext +from ..import_utils import is_vllm_available +from ..trainer.utils import ensure_master_addr_port +from .vllm_client import VLLMClient + + +logger = logging.getLogger(__name__) + + +def sanitize_logprob(logprob): + value = logprob.logprob + if math.isnan(value): + logger.warning(f"Generated NaN logprob, token logprob '{logprob}' will be ignored") + return None + + return value + + +if TYPE_CHECKING: + from accelerate import Accelerator + from peft import PeftModel + + +if is_vllm_available(): + import vllm + from vllm import LLM, SamplingParams + + if Version(vllm.__version__) <= Version("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams + else: + from vllm.sampling_params import StructuredOutputsParams + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +class VLLMGeneration: + """Handles vLLM-based generation for trainers. + + Extracts all vLLM-specific logic (initialization, generation, weight sync) from trainers into a separate, testable + class. + + Args: + model ([`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to use for generation. + accelerator ([`~accelerate.Accelerator`]): + Accelerator for distributed training. + is_fsdp_enabled (`bool`): + Whether FSDP is enabled. + processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`]): + Tokenizer or processor for the model. + + > Parameters for vLLM: + + mode (`str`, *optional*, defaults to `"server"`): vLLM mode. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. + + > Parameters for "server" vLLM mode: + + server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `server_host` and + `server_port` are ignored. + server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `server_base_url` is provided. + server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `server_base_url` is provided. + server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port + is occupied, there is no need to change it. + + > Parameters for "colocate" vLLM mode: + + tensor_parallel_size (`int`, *optional*, defaults to `1`): + The number of GPUs to use for distributed execution with tensor parallelism. This setting only applies when + `mode` is set to `"colocate"`. If you are using `mode="server"`, this parameter must be passed separately + when launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's throughput. However, if the value is + too high, it may cause out-of- memory (OOM) errors. This setting only applies when `mode` is set to + `"colocate"`. If you are using `mode="server"`, this parameter must be passed separately when launching the + vLLM server via the `--vllm_gpu_memory_utilization` flag. + max_model_length (`int`, *optional*): + Model context length (prompt and completion). Set it to at least the maximum prompt length in the dataset + plus `max_completion_length`; if omitted, it is inferred from the model config. + max_num_seqs (`int`, *optional*): + Maximum number of sequences to process in parallel, effectively capping the batch size. + enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Whether to enable sleep mode for the engine to offload weights/cache during the optimizer step. Keeps GPU + memory usage low, but waking the engine adds hostโ€“device transfer latency. + model_impl (`str`, *optional*, defaults to `"auto"`): + Model implementation to use for vLLM. + - "auto" will try to use the vLLM implementation, if it exists, and fall back to the Transformers + implementation if no vLLM implementation is available. + - "vllm" will use the vLLM model implementation. + - "transformers" will use the Transformers model implementation. + - "terratorch" will use the TerraTorch model implementation. + + > Parameters for generation: + + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. It penalizes new tokens based on whether they appear in the prompt and + the generated text so far. Values > 1 encourage the model to use new tokens, while values < 1 encourage the + model to repeat tokens. Default `1.0` means no penalty. + temperature(`float`, *optional*, defaults to `1.0`): + Sampling temperature. It controls the randomness of the sampling. Lower values make the model more + deterministic, while higher values make the model more random and increase diversity. + top_p: (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter. It controls the cumulative probability of the top tokens to consider. Defaults to + `1.0` to consider all tokens. + top_k (`int`, *optional*, defaults to `0`): + Top-k sampling parameter. It controls the number of top tokens to consider. Defaults to `0` to consider all + tokens. + min_p (`float`, *optional*, defaults to `0.0`): + Min-p sampling parameter. It represents the minimum probability for a token to be considered, relative to + the probability of the most likely token. Default `0.0` means min-p is disabled. + max_completion_length (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each prompt. + generation_kwargs (`dict`, *optional*): + Additional generation parameters to pass to the vLLM `SamplingParams`. This can include parameters like + `seed`, `frequency_penalty`, etc. If it contains keys that conflict with the other parameters, they will + override them. + + > Parameters for chat/tools: + + chat_template (`str`, *optional*): + Template to use for structuring the chat. If not provided, the model's default chat template will be used. + chat_template_kwargs (`dict`, *optional*): + Additional keyword arguments to customize the chat template used by the model. + tools (`list`, *optional*): + Tools available for tool calling during chat generation. + rollout_func (`Callable`, *optional*): Optional custom rollout function that accepts prompts and returns + a dict with 'prompt_ids', 'completion_ids', 'logprobs', and optional extra fields. Should be a + single-argument callable: rollout_func(prompts) -> dict. To pass additional context (e.g., trainer), use a + closure or functools.partial: + rollout_func = lambda prompts: my_custom_rollout(prompts, trainer) + The closure will hold a reference to trainer and see its state updates. + """ + + def __init__( + self, + model: "PreTrainedModel | PeftModel", + accelerator: "Accelerator", + is_fsdp_enabled: bool, + processing_class: PreTrainedTokenizerBase | ProcessorMixin, + # vLLM configuration + mode: str = "server", + structured_outputs_regex: str | None = None, + # Server mode configuration + server_base_url: str | None = None, + server_host: str = "0.0.0.0", + server_port: int = 8000, + server_timeout: float = 240.0, + group_port: int = 51216, + # Colocate mode configuration + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, + max_model_length: int | None = None, + max_num_seqs: int | None = None, + enable_sleep_mode: bool = False, + model_impl: str = "auto", + # Generation configuration + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + max_completion_length: int = 16, + generation_kwargs: dict | None = None, + # Chat/tool configuration + chat_template: str | None = None, + chat_template_kwargs: dict | None = None, + tools: list | None = None, + rollout_func: Callable | None = None, + ): + self.model = model + self.accelerator = accelerator + self.is_fsdp_enabled = is_fsdp_enabled + self.processing_class = processing_class + + # vLLM configuration + self.mode = mode + self.structured_outputs_regex = structured_outputs_regex + + # Server mode configuration + self.server_base_url = server_base_url + self.server_host = server_host + self.server_port = server_port + self.group_port = group_port + self.server_timeout = server_timeout + + # Colocate mode configuration + self.tensor_parallel_size = tensor_parallel_size + self.gpu_memory_utilization = gpu_memory_utilization + self.max_model_length = max_model_length + self.max_num_seqs = max_num_seqs + self.enable_sleep_mode = enable_sleep_mode + self.model_impl = model_impl + + # Generation configuration + self.repetition_penalty = repetition_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.max_completion_length = max_completion_length + self.generation_kwargs = generation_kwargs or {} + + # Chat/tool configuration + self.chat_template = chat_template + self.chat_template_kwargs = chat_template_kwargs or {} + self.tools = tools + self.rollout_func = rollout_func + + self._init_vllm() + + def _init_vllm(self): + """Initialize vLLM in server or colocate mode.""" + model = self.model + accelerator = self.accelerator + + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install trl[vllm]` to use it." + ) + + if self.mode == "server": + if accelerator.is_main_process: + if self.server_base_url is not None: + base_url = self.server_base_url + else: + base_url = f"http://{self.server_host}:{self.server_port}" + self.vllm_client = VLLMClient( + base_url=base_url, group_port=self.group_port, connection_timeout=self.server_timeout + ) + self.vllm_client.init_communicator(device=torch.cuda.current_device()) + + elif self.mode == "colocate": + # Make sure tensor_parallel_size group size evenly divides the world size - each group should have + # the same number of ranks + if not accelerator.num_processes % self.tensor_parallel_size == 0: + raise ValueError( + f"tensor_parallel_size ({self.tensor_parallel_size}) must divide world size " + f"({accelerator.num_processes}) evenly." + ) + + if self.tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `tensor_parallel_size` ranks. + # For example, if world_size=8 and tensor_parallel_size=2 โ†’ groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( + [ + list(range(i * self.tensor_parallel_size, (i + 1) * self.tensor_parallel_size)) + for i in range(accelerator.num_processes // self.tensor_parallel_size) + ] + ) + + # vLLM requires the environment variables to be set for distributed training. + os.environ["RANK"] = str(accelerator.process_index) + os.environ["LOCAL_RANK"] = str(accelerator.local_process_index) + os.environ["WORLD_SIZE"] = str(accelerator.num_processes) + # Ensure distributed rendezvous variables are set without colliding across concurrent runs + ensure_master_addr_port() + + quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") + + # Build LLM initialization kwargs + self.llm = LLM( + model=model.name_or_path, + tensor_parallel_size=self.tensor_parallel_size, + gpu_memory_utilization=self.gpu_memory_utilization, + max_model_len=self.max_model_length, + max_num_seqs=self.max_num_seqs, + enable_sleep_mode=self.enable_sleep_mode, + model_impl=self.model_impl, + distributed_executor_backend="external_launcher", + # Feed identical seed for tp groups to ensure sampling results are the same across workers + seed=accelerator.process_index // self.tensor_parallel_size, + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory + max_num_batched_tokens=4096, + # Important so temperature scaling/logit tweaking affects the TIS log probs + logprobs_mode="processed_logprobs", + quantization=quantization, + ) + if self.enable_sleep_mode: + self.llm.sleep(level=2) + else: + raise ValueError(f"vllm_mode must be either 'server' or 'colocate', got '{self.mode}'.") + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + accelerator.wait_for_everyone() + + def _fix_param_name_to_vllm(self, name: str, extra_prefixes: list[str] | None = None) -> str: + """Fix parameter name for vLLM compatibility.""" + extra_prefixes = extra_prefixes or [] + prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes + for prefix in prefixes: + name = name.replace(prefix, "") + return name + + def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited: set[str] | None = None): + """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" + # For FSDP1, we need to recurse into children and also use summon_full_params + accelerator = self.accelerator + + if visited is None: + visited = set() + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + self._sync_fsdp1_params_to_vllm( + child_module, prefix=child_prefix, visited=visited + ) # recurse into the child + + if isinstance(module, FSDP): + with FSDP.summon_full_params(module, recurse=False, writeback=False): + for param_name, param in module.named_parameters(): + full_name = f"{prefix}.{param_name}" if prefix else param_name + full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."]) + + if full_name in visited: + continue # skip FSDP subtrees already traversed + visited.add(full_name) + + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(full_name, param.data) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(full_name, param.data)]) + + def _sync_fsdp2_params_to_vllm(self, module: nn.Module): + """FSDP2-specific parameter synchronization.""" + accelerator = self.accelerator + + # For FSDP2, module.state_dict() already covers all parameters, so no need for recursion + for name, param in module.state_dict().items(): + # When using PEFT, we need to recover the original parameter name + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + # Skip PEFT layers: they don't exist in vLLM, and they are merged already. + if is_peft_model(module) and module.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if param.is_cpu: + param = param.to(torch.device("cuda")) + param = param.full_tensor() + + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(name, param) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param)]) + + def sync_weights(self): + """Synchronize model weights to vLLM. + + Handles FSDP, DeepSpeed, PEFT weight synchronization. + """ + model = self.model + accelerator = self.accelerator + is_fsdp_enabled = self.is_fsdp_enabled + + # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations + deepspeed_plugin = accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + + if is_peft_model(model): + # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as + # merging adapters in a sharded manner is not supported. + # TODO: does this work with FSDP? + with gather_if_zero3(list(model.parameters())): + model.merge_adapter() + + # Update vLLM weights while parameters are gathered + if is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext + # Update vLLM weights while parameters are gathered + # For PEFT with FSDP we need to use the memory efficient post-order traversal + fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm(model) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(model) + else: + # DeepSpeed ZeRO-3 with PEFT + for name, param in model.named_parameters(): + # When using PEFT, we need to recover the original parameter name + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + # Skip PEFT layers: they don't exist in vLLM, and they are merged already. + if model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = self._fix_param_name_to_vllm(name, extra_prefixes=["modules_to_save.default."]) + + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + model.unmerge_adapter() + # Parameters will automatically be repartitioned when exiting the context + else: + # For non-PEFT models, simply gather (if needed) and update each parameter individually. + if is_fsdp_enabled: + fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None) + fsdp_version = getattr(fsdp_plugin, "fsdp_version", 1) if fsdp_plugin else 1 + if fsdp_version == 1: + self._sync_fsdp1_params_to_vllm(model) # use memory-efficient post-order traversal for FSDP + elif fsdp_version == 2: + self._sync_fsdp2_params_to_vllm(model) + else: + for name, param in model.named_parameters(): + name = self._fix_param_name_to_vllm(name) + with gather_if_zero3([param]): + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.mode == "server" and accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.mode == "colocate": + self.llm.reset_prefix_cache() + + def generate(self, prompts: list, num_generations: int, profiler: ProfilingContext | None = None) -> tuple: + """Generate completions using vLLM. + + Args: + prompts: List of prompts (strings or chat conversations) + num_generations: Number of generations per prompt + profiler: Optional profiler for performance tracking + + Returns: + Tuple of (prompt_ids, completion_ids, logprobs, extra_fields) + """ + profiler = profiler or nullcontext() + accelerator = self.accelerator + rollout_func = self.rollout_func + temperature = self.temperature + top_p = self.top_p + top_k = self.top_k + min_p = self.min_p + repetition_penalty = self.repetition_penalty + max_completion_length = self.max_completion_length + processing_class = self.processing_class + chat_template_kwargs = self.chat_template_kwargs + tools = self.tools + chat_template = self.chat_template + + # Wake up colocated vLLM instances if needed + if self.mode == "colocate" and self.enable_sleep_mode: + torch.cuda.empty_cache() # required to avoid OOM in some cases + self.llm.wake_up(tags=["weights"]) + # Work around for https://github.com/vllm-project/vllm/issues/29341 + self.llm.collective_rpc("reload_weights") + + if is_conversational({"prompt": prompts[0]}): + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + + # In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820 + for prompt in prompts: # iterate over each conversation + if is_conversational({"prompt": prompt}): + for message in prompt: # iterate over each message + if "tool_calls" in message: # check if message has tool calls + for call in message["tool_calls"]: + args_value = call["function"]["arguments"] + if isinstance(args_value, dict): # only convert dict โ†’ JSON string + call["function"]["arguments"] = json.dumps(args_value) + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + if self.mode == "server": + all_prompts = gather_object(prompts) + + if accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts[::num_generations] + + sampling_params = { + "n": num_generations, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": 0.0 if min_p is None else min_p, + "max_tokens": max_completion_length, + "structured_outputs_regex": self.structured_outputs_regex, + "generation_kwargs": self.generation_kwargs, + } + with profiler: # TODO: profiling_context(trainer, "vLLM.generate"): + if rollout_func is not None: + # Pass all prompts (with duplicates) to rollout_func for consistency with colocate mode + rollout_prompts = all_prompts + if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): + rollout_prompts = [ + apply_chat_template({"prompt": p}, processing_class, **chat_template_kwargs)["prompt"] + for p in rollout_prompts + ] + output = rollout_func(rollout_prompts) + else: + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat( + messages=ordered_set_of_prompts, + **sampling_params, + chat_template_kwargs=chat_template_kwargs, + tools=tools, + chat_template=chat_template, + ) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) + # Extract required fields and collect any extra fields for reward functions + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"], extra_fields) + else: + payload = None + + # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. + obj_list = [payload] + broadcast_object_list(obj_list, from_process=0) + all_prompt_ids, all_completion_ids, all_logprobs, all_extra_fields = obj_list[0] + + # When using rollout_func, it handles its own generation logic and returns one result per prompt. + # When NOT using rollout_func, vllm_client.generate(n=num_generations) returns num_generations + # completions per prompt, so we need to duplicate prompt_ids to match. + if self.rollout_func is None: + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)] + + process_slice = slice( + accelerator.process_index * len(prompts), + (accelerator.process_index + 1) * len(prompts), + ) + prompt_ids = all_prompt_ids[process_slice] + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] + + # Slice extra fields dict-of-lists per process (extra fields are per-completion, like completion_ids) + extra_fields = {} + for key, values in all_extra_fields.items(): + if isinstance(values, list): + extra_fields[key] = values[process_slice] + else: + extra_fields[key] = values + + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts + elif self.mode == "colocate": + if rollout_func is not None: + rollout_prompts = prompts + if rollout_prompts and is_conversational({"prompt": rollout_prompts[0]}): + rollout_prompts = [ + apply_chat_template({"prompt": prompt}, processing_class, **chat_template_kwargs)["prompt"] + for prompt in rollout_prompts + ] + output = rollout_func(rollout_prompts) + required_keys = {"prompt_ids", "completion_ids", "logprobs"} + extra_fields = {k: v for k, v in output.items() if k not in required_keys} + prompt_ids = output["prompt_ids"] + completion_ids = output["completion_ids"] + logprobs = output["logprobs"] + else: + if Version(vllm.__version__) <= Version("0.10.2"): + structured_outputs_key = "guided_decoding" + if self.structured_outputs_regex: + structured_outputs = GuidedDecodingParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None + else: + structured_outputs_key = "structured_outputs" + if self.structured_outputs_regex: + structured_outputs = StructuredOutputsParams(regex=self.structured_outputs_regex) + else: + structured_outputs = None + + generation_kwargs = { + "n": 1, # vLLM on each GPU generates only 1 in colocate mode + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": 0.0 if min_p is None else min_p, + "max_tokens": max_completion_length, + "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only + } + generation_kwargs[structured_outputs_key] = structured_outputs + generation_kwargs.update(self.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + + if self.tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts) + gathered_prompts = [None for _ in range(self.tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts = prompts + + if self.enable_sleep_mode: + self.llm.wake_up(tags=["kv_cache"]) + + with profiler: # TODO: profiling_context(trainer, "vLLM.generate"): + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat( + all_prompts, + sampling_params=sampling_params, + use_tqdm=False, + chat_template_kwargs=chat_template_kwargs, + tools=tools, + chat_template=chat_template, + ) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) + + all_prompt_ids = [output.prompt_token_ids for output in all_outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_logprobs = [ + [sanitize_logprob(next(iter(lp.values()))) for lp in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] + + if self.tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs โ€” we keep only our share. + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + prompt_ids = all_prompt_ids[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] + else: + prompt_ids = all_prompt_ids + completion_ids = all_completion_ids + logprobs = all_logprobs + + extra_fields = {} # No extra fields for colocate mode + + if self.enable_sleep_mode: + self.llm.sleep(level=2) + + return prompt_ids, completion_ids, logprobs, extra_fields diff --git a/ICL/RL/trl_source/trl/models/__init__.py b/ICL/RL/trl_source/trl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2018bd0d9760c94f87467beebf361f56104593 --- /dev/null +++ b/ICL/RL/trl_source/trl/models/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "activation_offloading": ["get_act_offloading_ctx_manager"], + "utils": ["create_reference_model", "prepare_deepspeed", "prepare_fsdp", "unwrap_model_for_generation"], +} + + +if TYPE_CHECKING: + from .activation_offloading import get_act_offloading_ctx_manager + from .utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/ICL/RL/trl_source/trl/models/__pycache__/__init__.cpython-313.pyc b/ICL/RL/trl_source/trl/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6027664391d48ff726193f2439eaa0f4e97e456 Binary files /dev/null and b/ICL/RL/trl_source/trl/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/models/utils.py b/ICL/RL/trl_source/trl/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1fbe0ad0d4b9d4131c0f5ac064dd1c433500edb5 --- /dev/null +++ b/ICL/RL/trl_source/trl/models/utils.py @@ -0,0 +1,446 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections.abc import Callable +from contextlib import contextmanager +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +import torch +import torch.nn as nn +import transformers +from packaging.version import Version +from transformers import GenerationConfig, PreTrainedModel +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + + +if TYPE_CHECKING: + from accelerate import Accelerator + from deepspeed.runtime.engine import DeepSpeedEngine + from torch.nn import Module + from torch.nn.parallel.distributed import DistributedDataParallel + + +def remove_hooks(model: "DeepSpeedEngine") -> None: + """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + else: + raise RuntimeError("The model optimizer is None, which is not yet supported.") + + for param in iter_params(optimizer_offload.module, recurse=True): + param.ds_active_sub_modules.clear() + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) + + +def iter_params(module, recurse=False): + return [param for _, param in get_all_parameters(module, recurse)] + + +def add_hooks(model: "DeepSpeedEngine") -> None: + """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + import deepspeed + + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + else: + raise RuntimeError("The model optimizer is None, which is not yet supported.") + if Version(deepspeed.__version__) >= Version("0.16.4"): + # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847 + optimizer_offload._register_deepspeed_module(optimizer_offload.module) + else: + optimizer_offload._register_hooks_recursively(optimizer_offload.module) + + +@contextmanager +def _unwrap_model_for_generation( + model: "DistributedDataParallel | DeepSpeedEngine", + accelerator: "Accelerator", + gather_deepspeed3_params: bool = True, +): + """ + Context manager to unwrap distributed or accelerated models for generation tasks. + + Args: + model (`DistributedDataParallel | DeepSpeedEngine`): + Model to be unwrapped. + accelerator ([`~accelerate.Accelerator`]): + Accelerator instance managing the model. + gather_deepspeed3_params (`bool`, *optional*, defaults to `True`): + Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which + can be more memory-efficient but may lead to slower generation times. + + Yields: + Unwrapped model. + + Example: + ```python + with _unwrap_model_for_generation(model, accelerator) as unwrapped_model: + generated_outputs = unwrapped_model.generate(input_ids) + ``` + """ + unwrapped_model = accelerator.unwrap_model(model) + is_gradient_checkpointing = unwrapped_model.is_gradient_checkpointing + if is_gradient_checkpointing: + unwrapped_model.gradient_checkpointing_disable() + if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: + if not gather_deepspeed3_params: + yield accelerator.unwrap_model(model) + else: + import deepspeed + + with deepspeed.zero.GatheredParameters(model.parameters()): + remove_hooks(model) + yield accelerator.unwrap_model(model) + add_hooks(model) + else: + yield unwrapped_model + if is_gradient_checkpointing: + unwrapped_model.gradient_checkpointing_enable() + + +@contextmanager +def _override_model_generation_config(model, generation_kwargs=None): + """ + Context manager to temporarily override a model's generation_config with training config. + + This works around transformers' config merging logic that would otherwise overwrite values matching global defaults + with model-specific values (see upstream issue transformers#42762; fixed in transformers v5 by PR + `transformers#42702`). + + By temporarily setting the model's generation_config to match the passed generation_config, we avoid the conflict. + + The model's original generation_config is preserved outside this context, ensuring that saved/pushed models retain + their intended inference behavior. + + Args: + model: The model (typically unwrapped_model) whose generation_config to temporarily override. + generation_kwargs (dict): Generation kwargs to be used to override model's generation config. + """ + if ( + # Issue fixed in transformers v5 by PR transformers#42702 + Version(transformers.__version__) >= Version("5.0.0") + or generation_kwargs is None + or not hasattr(model, "generation_config") + ): + yield model + return + # If it is a PEFT model, override the underlying base model + if hasattr(model, "get_base_model"): + model = model.get_base_model() + # Keep original model generation_config + original_config = model.generation_config + # Create training-specific generation config from the model's original generation config + # Then overwrite it with the training-specific generation kwargs + generation_config = GenerationConfig.from_dict(model.generation_config.to_dict()) + generation_config.update(**generation_kwargs) + model.generation_config = generation_config + try: + yield + finally: + model.generation_config = original_config + + +@contextmanager +def unwrap_model_for_generation( + model: "DistributedDataParallel | DeepSpeedEngine", + accelerator: "Accelerator", + gather_deepspeed3_params: bool = True, + generation_kwargs: dict | None = None, +): + """ + Context manager to unwrap distributed or accelerated models for generation tasks. + + This function unwraps distributed models (FSDP, DeepSpeed) and optionally overrides the model's generation_config + temporarily during generation. This is useful for applying training-specific generation parameters without + permanently modifying the model's original generation_config. + + Args: + model (`DistributedDataParallel | DeepSpeedEngine`): + Model to be unwrapped. + accelerator ([`~accelerate.Accelerator`]): + Accelerator instance managing the model. + gather_deepspeed3_params (`bool`, *optional*, defaults to `True`): + Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which + can be more memory-efficient but may lead to slower generation times. + generation_kwargs (dict, *optional*): + If provided, temporarily overrides the model's generation_config during generation. The original config is + automatically restored when exiting the context. This is useful for using different generation parameters + during training vs. inference. + + Yields: + Unwrapped model with optionally overridden generation_config. + """ + with ( + _unwrap_model_for_generation( + model, accelerator, gather_deepspeed3_params=gather_deepspeed3_params + ) as unwrapped_model, + _override_model_generation_config(unwrapped_model, generation_kwargs=generation_kwargs), + ): + yield unwrapped_model + + +def prepare_deepspeed(model: "Module", accelerator: "Accelerator"): + """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration. + + Adapted from accelerate: + https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + """ + import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252 + + deepspeed_plugin = accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + stage = config_kwargs["zero_optimization"]["stage"] + + if model is not None: + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and stage == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache + # @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO + # disabled (stage 0) + if stage != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + +def prepare_fsdp(model, accelerator): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1421 + from torch.distributed.fsdp import FSDPModule + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, + # don't wrap it again + if not (isinstance(model, FSDP) or isinstance(model, FSDPModule)): + accelerator.state.fsdp_plugin.set_auto_wrap_policy(model) + fsdp_plugin = accelerator.state.fsdp_plugin + kwargs = { + "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward, + "cpu_offload": fsdp_plugin.cpu_offload, + "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, + "mixed_precision": fsdp_plugin.mixed_precision_policy, + "sync_module_states": fsdp_plugin.sync_module_states, + "backward_prefetch": fsdp_plugin.backward_prefetch, + "forward_prefetch": fsdp_plugin.forward_prefetch, + "use_orig_params": fsdp_plugin.use_orig_params, + "param_init_fn": fsdp_plugin.param_init_fn, + "ignored_modules": fsdp_plugin.ignored_modules, + "limit_all_gathers": fsdp_plugin.limit_all_gathers, + "device_id": accelerator.device, + } + model = FSDP(model, **kwargs) + model.eval() + return model + + +class _ForwardRedirection: + """Implements the `forward-redirection`. + + Taken from Pytorch-lightning: + https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602 + + A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead. + + """ + + def __call__( + self, wrapper_module: nn.Module, original_module: nn.Module, method: Callable, *args: Any, **kwargs: Any + ): + """Reroutes a method call through the `wrapper_module`'s `forward` method. + + Args: + wrapper_module: The module that has `original_module` wrapped. + original_module: The module that was wrapped inside `wrapper_module`. + method: The method that should be called on the `original_module` after inputs get + redirected through the `wrapper_module`'s `forward` method. + *args: The positional arguments to the `method`. They will get passed to a patched + `forward` method instead. + **kwargs: The keyword arguments to the `method`. They will get passed to a patched + `forward` method instead. + + """ + original_forward = original_module.forward + + def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any: + # Unpatch ourselves immediately before calling the method `method_name` + # because itself may want to call the real `forward` + original_module.forward = original_forward # type: ignore[method-assign] + # Call the actual method e.g. `.training_step(...)` + out = method(*_args, **_kwargs) + self.on_after_inner_forward(wrapper_module, original_module) + return out + + # Patch the original_module's forward so we can redirect the arguments back to the real method + original_module.forward = wrapped_forward # type: ignore[method-assign] + + wrapper_output = wrapper_module(*args, **kwargs) + self.on_after_outer_forward(wrapper_module, original_module) + return wrapper_output + + def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None: + pass + + def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None: + pass + + +def peft_module_casting_to_bf16(model): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.LayerNorm) or "norm" in name: + module = module.to(torch.float32) + elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): + if hasattr(module, "weight"): + if module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + +@contextmanager +def disable_gradient_checkpointing(model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None = None): + """ + Temporarily disable gradient checkpointing, restoring the previous state afterward. + + Args: + model (`PreTrainedModel`): + Model for which to temporarily disable gradient checkpointing. + gradient_checkpointing_kwargs (`dict` or `None`, *optional*): + Additional kwargs for gradient checkpointing enabling. + """ + was_enabled = model.is_gradient_checkpointing + if was_enabled: + model.gradient_checkpointing_disable() + try: + yield + finally: + if was_enabled: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +def create_reference_model( + model: nn.Module, num_shared_layers: int | None = None, pattern: str | None = None +) -> nn.Module: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model ([`nn.Module`]): The model to be copied. + num_shared_layers (`int`, *optional*): + The number of initial layers that are shared between both models and kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns: + [`nn.Module`] + """ + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoModelForCausalLM.from_pretrained()`." + ) + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any(pattern_candidate in name for name in parameter_names): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, _param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + _ref_param = ref_model.get_parameter(param_name) + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") + + return ref_model.eval() diff --git a/ICL/RL/trl_source/trl/rewards/__init__.py b/ICL/RL/trl_source/trl/rewards/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0651d6141b4bdc8764a7b7a7945514a4f3d5fb --- /dev/null +++ b/ICL/RL/trl_source/trl/rewards/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "accuracy_rewards": ["accuracy_reward", "reasoning_accuracy_reward"], + "format_rewards": ["think_format_reward"], + "other_rewards": ["get_soft_overlong_punishment"], +} + + +if TYPE_CHECKING: + from .accuracy_rewards import accuracy_reward, reasoning_accuracy_reward + from .format_rewards import think_format_reward + from .other_rewards import get_soft_overlong_punishment + + +else: + sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/ICL/RL/trl_source/trl/rewards/format_rewards.py b/ICL/RL/trl_source/trl/rewards/format_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..c737d0036b165ebf4a15e2c2558f82187e87038d --- /dev/null +++ b/ICL/RL/trl_source/trl/rewards/format_rewards.py @@ -0,0 +1,50 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: + r""" + Reward function that checks if the reasoning process is enclosed within `""` and `""` tags. The + function returns a reward of 1.0 if the format is correct, otherwise 0.0. + + Args: + completions (`list[list[dict[str, str]]]`): + List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary + containing the key `"content"` with the value being the text of the completion. + **kwargs: + Additional keyword arguments. This function does not use them, but they are required in the function + signature to ensure compatibility with trainers like [`GRPOTrainer`]. + + Returns: + `list[float]`: + A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0. + + Example: + ```python + >>> from trl.rewards import think_format_reward + + >>> completions = [ + ... [{"content": "\nThis is my reasoning.\n\nThis is my answer."}], + ... [{"content": "\nThis is my reasoning.\nThis is my answer."}], + ... ] + >>> think_format_reward(completions) + [1.0, 0.0] + ``` + """ + pattern = r"^(?!.*)(.*?).*$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] diff --git a/ICL/RL/trl_source/trl/rewards/other_rewards.py b/ICL/RL/trl_source/trl/rewards/other_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..4d26ec2d52045e88fbfe4250636eb0de32d9afe3 --- /dev/null +++ b/ICL/RL/trl_source/trl/rewards/other_rewards.py @@ -0,0 +1,62 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable + + +def get_soft_overlong_punishment(max_completion_len: int, soft_punish_cache: int) -> Callable: + # docstyle-ignore + r""" + Reward function that penalizes overlong completions. It is used to penalize overlong completions, but not to reward + shorter completions. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476) + + $$ + R_{\text{length}}(y) = \begin{cases} + 0, & |y| \le L_{\max} - L_{\text{cache}} \\ + \dfrac{(L_{\max} - L_{\text{cache}}) - |y|}{L_{\text{cache}}}, & L_{\max} - L_{\text{cache}} < |y| \le L_{\max} \\ + -1, & L_{\max} < |y| + \end{cases} + $$ + + Args: + max_completion_len (`int`): + Maximum length of the completion, \( L_{\max} \). + soft_punish_cache (`int`): + Minimum length of the completion, \( L_{\text{cache}} \). If set to `0`, no minimum length is applied. + + Example: + ```python + from trl.rewards import get_soft_overlong_punishment + + soft_overlong_punishment = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) + completion_ids = [[1] * 90] # simulating a completion with 90 tokens. 90 is between 80 and 100. + rewards = soft_overlong_punishment(completion_ids) + print(rewards) # [-0.5] + ``` + """ + + def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]: + """Reward function that penalizes overlong completions.""" + rewards = [] + for ids in completion_ids: + completion_length = len(ids) + if completion_length <= max_completion_len - soft_punish_cache: + rewards.append(0.0) + elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len: + rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache) + else: + rewards.append(-1.0) + return rewards + + return soft_overlong_punishment_reward diff --git a/ICL/RL/trl_source/trl/scripts/__init__.py b/ICL/RL/trl_source/trl/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..010ee40482d4ff9b3a82499b2f7eeb2bd4595f47 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "utils": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"], +} + +if TYPE_CHECKING: + from .utils import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/ICL/RL/trl_source/trl/scripts/dpo.py b/ICL/RL/trl_source/trl/scripts/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..1e97122f1a44c791eec1add93e644db566e8b069 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/dpo.py @@ -0,0 +1,184 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +# Full training +```bash +python trl/scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-7 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --max_steps 1000 \ + --gradient_accumulation_steps 8 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns +``` + +# LoRA: +```bash +python trl/scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-6 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --max_steps 1000 \ + --gradient_accumulation_steps 8 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +``` +""" + +import argparse +import os + +import torch +from accelerate import logging +from datasets import load_dataset +from transformers import AutoModelForCausalLM + +from trl import ( + DatasetMixtureConfig, + DPOConfig, + DPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_dataset, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args, dataset_args): + ################ + # Model + ################### + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + else: + ref_model = None + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the DPO trainer + trainer = DPOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=peft_config, + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("โœ… Training completed.") + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"๐Ÿ’พ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"๐Ÿค— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + dataclass_types = (ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/ICL/RL/trl_source/trl/scripts/env.py b/ICL/RL/trl_source/trl/scripts/env.py new file mode 100644 index 0000000000000000000000000000000000000000..87e050f93eac7376ec5e9c11a6cefe5cbf3baac7 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/env.py @@ -0,0 +1,88 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# ] +# /// + +import os +import platform +from importlib.metadata import version + +import torch +from accelerate.commands.config import default_config_file, load_config_from_file +from transformers import is_bitsandbytes_available +from transformers.utils import is_openai_available, is_peft_available + +from trl import __version__ +from trl.import_utils import ( + is_deepspeed_available, + is_liger_kernel_available, + is_llm_blender_available, + is_vllm_available, +) +from trl.scripts.utils import get_git_commit_hash + + +def print_env(): + devices = None + if torch.cuda.is_available(): + devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + elif torch.backends.mps.is_available(): + devices = ["MPS"] + elif torch.xpu.is_available(): + devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())] + + accelerate_config = accelerate_config_str = "not found" + + # Get the default from the config file. + if os.path.isfile(default_config_file): + accelerate_config = load_config_from_file(default_config_file).to_dict() + + accelerate_config_str = ( + "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()]) + if isinstance(accelerate_config, dict) + else accelerate_config + ) + + commit_hash = get_git_commit_hash("trl") + + info = { + "Platform": platform.platform(), + "Python version": platform.python_version(), + "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__, + "PyTorch version": version("torch"), + "accelerator(s)": ", ".join(devices) if devices is not None else "cpu", + "Transformers version": version("transformers"), + "Accelerate version": version("accelerate"), + "Accelerate config": accelerate_config_str, + "Datasets version": version("datasets"), + "HF Hub version": version("huggingface_hub"), + "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed", + "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed", + "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed", + "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed", + "OpenAI version": version("openai") if is_openai_available() else "not installed", + "PEFT version": version("peft") if is_peft_available() else "not installed", + "vLLM version": version("vllm") if is_vllm_available() else "not installed", + } + + info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa + + +if __name__ == "__main__": + print_env() diff --git a/ICL/RL/trl_source/trl/scripts/grpo.py b/ICL/RL/trl_source/trl/scripts/grpo.py new file mode 100644 index 0000000000000000000000000000000000000000..71e867d8c37e0909ff8c0e6103d89d4164920130 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/grpo.py @@ -0,0 +1,193 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +import argparse +import importlib +import os +import sys +from dataclasses import dataclass, field + +import torch +from accelerate import logging +from datasets import load_dataset + +from trl import ( + DatasetMixtureConfig, + GRPOConfig, + GRPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_dataset, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +reward_funcs_registry = { + "accuracy_reward": accuracy_reward, + "reasoning_accuracy_reward": reasoning_accuracy_reward, + "think_format_reward": think_format_reward, + "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256), +} + + +@dataclass +class GRPOScriptArguments(ScriptArguments): + """ + Script arguments for the GRPO training script. + + Args: + reward_model_name_or_path (`str`, *optional*): + Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a + directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`]. + reward_funcs (`list[str]`, *optional*): + Reward functions to use. Supported values are: + - `"accuracy_reward"` + - `"reasoning_accuracy_reward"` + - `"think_format_reward"` + - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`) + - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`). + """ + + reward_model_name_or_path: str | None = field( + default=None, + metadata={ + "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or " + "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`." + }, + ) + reward_funcs: list[str] | None = field( + default=None, + metadata={ + "help": "Reward functions to use. Supported values are: `accuracy_reward`, `reasoning_accuracy_reward`, `think_format_reward`, " + "`get_soft_overlong_punishment` (used values are `max_completion_len=1280`, `soft_punish_cache=256`), or " + "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)." + }, + ) + + +def main(script_args, training_args, model_args, dataset_args): + # Get the reward models and functions + reward_funcs = [] + if script_args.reward_model_name_or_path: + reward_funcs.append(script_args.reward_model_name_or_path) + + if script_args.reward_funcs: + for func_name in script_args.reward_funcs: + if func_name in reward_funcs_registry: + reward_funcs.append(reward_funcs_registry[func_name]) + elif "." in func_name: + module_path, func_name = func_name.rsplit(".", 1) + sys.path.insert(0, os.getcwd()) + module = importlib.import_module(module_path) + reward_func = getattr(module, func_name) + reward_funcs.append(reward_func) + else: + raise ValueError( + f"Could not load reward function '{func_name}'. Expected one of " + f"{list(reward_funcs_registry.keys())} or a valid import path." + ) + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + training_args.model_init_kwargs = model_kwargs + + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the GRPO trainer + trainer = GRPOTrainer( + model=model_args.model_name_or_path, + reward_funcs=reward_funcs, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("โœ… Training completed.") + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"๐Ÿ’พ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"๐Ÿค— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/ICL/RL/trl_source/trl/scripts/kto.py b/ICL/RL/trl_source/trl/scripts/kto.py new file mode 100644 index 0000000000000000000000000000000000000000..9acdfd49fbdde2e0facd4fbc0c917c5dd8a9445f --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/kto.py @@ -0,0 +1,165 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to +that of DPO. + +# Full training: +```bash +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 16 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model \ + --warmup_steps 0.1 \ + --logging_first_step +``` + +# QLoRA: +```bash +# QLoRA: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model-lora \ + --warmup_steps 0.1 \ + --logging_first_step \ + --use_peft \ + --load_in_4bit \ + --lora_target_modules=all-linear \ + --lora_r=16 \ + --lora_alpha=16 +``` +""" + +import argparse +import os + +from accelerate import logging +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + ScriptArguments, + TrlParser, + get_dataset, + get_peft_config, +) +from trl.experimental.kto import KTOConfig, KTOTrainer + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args, dataset_args): + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the KTO trainer + trainer = KTOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("โœ… Training completed.") + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"๐Ÿ’พ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"๐Ÿค— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + dataclass_types = (ScriptArguments, KTOConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/ICL/RL/trl_source/trl/scripts/reward.py b/ICL/RL/trl_source/trl/scripts/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..75b6556572c70bce2713477b4e27044c7a58e611 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/reward.py @@ -0,0 +1,108 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +import argparse +import os + +from accelerate import logging +from datasets import load_dataset + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + RewardConfig, + RewardTrainer, + ScriptArguments, + TrlParser, + get_dataset, + get_peft_config, +) + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args, dataset_args): + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the RewardTrainer + trainer = RewardTrainer( + model=model_args.model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("โœ… Training completed.") + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"๐Ÿ’พ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"๐Ÿค— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser( + "reward", help="Run the reward training script", dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/ICL/RL/trl_source/trl/scripts/rloo.py b/ICL/RL/trl_source/trl/scripts/rloo.py new file mode 100644 index 0000000000000000000000000000000000000000..060c04be92e7d842e3dfba8a7866d801e8667fd8 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/rloo.py @@ -0,0 +1,175 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +import argparse +import importlib +import os +import sys +from dataclasses import dataclass, field + +from accelerate import logging +from datasets import load_dataset + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + RLOOConfig, + RLOOTrainer, + ScriptArguments, + TrlParser, + get_dataset, + get_peft_config, +) +from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +reward_funcs_registry = { + "accuracy_reward": accuracy_reward, + "reasoning_accuracy_reward": reasoning_accuracy_reward, + "think_format_reward": think_format_reward, + "get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256), +} + + +@dataclass +class RLOOScriptArguments(ScriptArguments): + """ + Script arguments for the RLOO training script. + + Args: + reward_model_name_or_path (`str`, *optional*): + Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a + directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`]. + reward_funcs (`list[str]`, *optional*): + Reward functions to use. Supported values are: + - `"accuracy_reward"` + - `"reasoning_accuracy_reward"` + - `"think_format_reward"` + - `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`) + - any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`). + """ + + reward_model_name_or_path: str | None = field( + default=None, + metadata={ + "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or " + "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`." + }, + ) + reward_funcs: list[str] | None = field( + default=None, + metadata={ + "help": "Reward functions to use. Supported values are: `accuracy_reward`, `reasoning_accuracy_reward`, `think_format_reward`, " + "`get_soft_overlong_punishment` (used values are `max_completion_len=1280`, `soft_punish_cache=256`), or " + "any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)." + }, + ) + + +def main(script_args, training_args, model_args, dataset_args): + # Get the reward models and functions + reward_funcs = [] + if script_args.reward_model_name_or_path: + reward_funcs.append(script_args.reward_model_name_or_path) + + if script_args.reward_funcs: + for func_name in script_args.reward_funcs: + if func_name in reward_funcs_registry: + reward_funcs.append(reward_funcs_registry[func_name]) + elif "." in func_name: + module_path, func_name = func_name.rsplit(".", 1) + sys.path.insert(0, os.getcwd()) + module = importlib.import_module(module_path) + reward_func = getattr(module, func_name) + reward_funcs.append(reward_func) + else: + raise ValueError( + f"Could not load reward function '{func_name}'. Expected one of " + f"{list(reward_funcs_registry.keys())} or a valid import path." + ) + + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the RLOO trainer + trainer = RLOOTrainer( + model=model_args.model_name_or_path, + reward_funcs=reward_funcs, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("โœ… Training completed.") + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"๐Ÿ’พ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"๐Ÿค— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + dataclass_types = (RLOOScriptArguments, RLOOConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser("rloo", help="Run the RLOO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/ICL/RL/trl_source/trl/scripts/sft.py b/ICL/RL/trl_source/trl/scripts/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..a900a3e63cbf854c92da758cf00e4ad271327a1e --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/sft.py @@ -0,0 +1,175 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +# Full training +``` +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --eos_token '<|im_end|>' \ + --eval_strategy steps \ + --eval_steps 100 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub +``` + +# LoRA +``` +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-4 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --eos_token '<|im_end|>' \ + --eval_strategy steps \ + --eval_steps 100 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub +``` +""" + +import argparse +import os + +from accelerate import logging +from datasets import load_dataset +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + +from trl import ( + DatasetMixtureConfig, + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_dataset, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +logger = logging.get_logger(__name__) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + + +def main(script_args, training_args, model_args, dataset_args): + ################ + # Model init kwargs + ################ + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + dtype=model_args.dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + # Passing None would not be treated the same as omitting the argument, so we include it only when valid. + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + # Create model + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() + + if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures): + from transformers import AutoModelForImageTextToText + + model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs) + else: + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) + + # Load the dataset + if dataset_args.datasets and script_args.dataset_name: + logger.warning( + "Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the " + "dataset and `dataset_name` will be ignored." + ) + dataset = get_dataset(dataset_args) + elif dataset_args.datasets and not script_args.dataset_name: + dataset = get_dataset(dataset_args) + elif not dataset_args.datasets and script_args.dataset_name: + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming + ) + else: + raise ValueError("Either `datasets` or `dataset_name` must be provided.") + + # Initialize the SFT trainer + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + + # Train the model + trainer.train() + + # Log training complete + trainer.accelerator.print("โœ… Training completed.") + + # Save and push to Hub + trainer.save_model(training_args.output_dir) + trainer.accelerator.print(f"๐Ÿ’พ Model saved to {training_args.output_dir}.") + + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + trainer.accelerator.print(f"๐Ÿค— Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.") + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig) + if subparsers is not None: + parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config( + return_remaining_strings=True + ) + main(script_args, training_args, model_args, dataset_args) diff --git a/ICL/RL/trl_source/trl/scripts/utils.py b/ICL/RL/trl_source/trl/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6271b6ea662430e18d15b720612b2b081841187 --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/utils.py @@ -0,0 +1,474 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import inspect +import logging +import os +import subprocess +import sys +from collections.abc import Iterable +from dataclasses import dataclass, field + +import datasets +import yaml +from datasets import DatasetDict, concatenate_datasets +from transformers import HfArgumentParser +from transformers.hf_argparser import DataClass, DataClassType +from transformers.utils import is_rich_available + + +def _ensure_transformers_parallelism_config() -> None: + """ + Ensure that ``transformers.training_args`` always defines the symbol `ParallelismConfig` so that Python's + `typing.get_type_hints` can resolve annotations on `transformers.TrainingArguments` without raising a `NameError`. + + This is needed when running with ``accelerate<1.10.1``, where the module ``accelerate.parallelism_config`` did not + exist and therefore the type alias is not imported by Transformers. + + See upstream fix PR in transformers#40818. + """ + from typing import Any + + import transformers.training_args + + if not hasattr(transformers.training_args, "ParallelismConfig"): + transformers.training_args.ParallelismConfig = Any + + +_ensure_transformers_parallelism_config() # before creating HfArgumentParser + +logger = logging.getLogger(__name__) + + +@dataclass +class DatasetConfig: + """ + Configuration for a dataset. + + This class matches the signature of [`~datasets.load_dataset`] and the arguments are used directly in the + [`~datasets.load_dataset`] function. You can refer to the [`~datasets.load_dataset`] documentation for more + details. + + Parameters: + path (`str`): + Path or name of the dataset. + name (`str`, *optional*): + Defining the name of the dataset configuration. + data_dir (`str`, *optional*): + Defining the `data_dir` of the dataset configuration. If specified for the generic builders(csv, text etc.) + or the Hub datasets and `data_files` is `None`, the behavior is equal to passing `os.path.join(data_dir, + **)` as `data_files` to reference all the files in a directory. + data_files (`str` or `Sequence` or `Mapping`, *optional*): + Path(s) to source data file(s). + split (`str`, *optional*, defaults to `"train"`): + Which split of the data to load. + columns (`list[str]`, *optional*): + List of column names to select from the dataset. If `None`, all columns are selected. + """ + + path: str + name: str | None = None + data_dir: str | None = None + data_files: str | list[str] | dict[str, str] | None = None + split: str = "train" + columns: list[str] | None = None + + +@dataclass +class DatasetMixtureConfig: + """ + Configuration class for a mixture of datasets. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + datasets (`list[DatasetConfig]`): + List of dataset configurations to include in the mixture. + streaming (`bool`, *optional*, defaults to `False`): + Whether to stream the datasets. If `True`, the datasets will be loaded in streaming mode. + test_split_size (`float`, *optional*): + Size of the test split. Refer to the `test_size` parameter in the [`~datasets.train_test_split`] function + for more details. If `None`, the dataset will not be split into train and test sets. + + Usage: + When using the CLI, you can add the following section to your YAML config file: + + ```yaml + datasets: + - path: ... + name: ... + data_dir: ... + data_files: ... + split: ... + columns: ... + - path: ... + name: ... + data_dir: ... + data_files: ... + split: ... + columns: ... + streaming: ... + test_split_size: ... + ``` + """ + + datasets: list[DatasetConfig] = field( + default_factory=list, + metadata={"help": "List of dataset configurations to include in the mixture."}, + ) + streaming: bool = field( + default=False, + metadata={"help": "Whether to stream the datasets. If True, the datasets will be loaded in streaming mode."}, + ) + test_split_size: float | None = field( + default=None, + metadata={ + "help": "Size of the test split. Refer to the `test_size` parameter in the `datasets.train_test_split` " + "function for more details. If None, the dataset will not be split into train and test sets." + }, + ) + + def __post_init__(self): + # Convert any dataset dicts (from CLI/config parsing) into DatasetConfig objects + for idx, dataset in enumerate(self.datasets): + if isinstance(dataset, dict): + # If it's a dict, convert it to DatasetConfig + self.datasets[idx] = DatasetConfig(**dataset) + + +@dataclass +class ScriptArguments: + """ + Arguments common to all scripts. + + Args: + dataset_name (`str`,, *optional*): + Path or name of the dataset to load. If `datasets` is provided, this will be ignored. + dataset_config (`str`, *optional*): + Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. + If `datasets` is provided, this will be ignored. + dataset_train_split (`str`, *optional*, defaults to `"train"`): + Dataset split to use for training. If `datasets` is provided, this will be ignored. + dataset_test_split (`str`, *optional*, defaults to `"test"`): + Dataset split to use for evaluation. If `datasets` is provided, this will be ignored. + dataset_streaming (`bool`, *optional*, defaults to `False`): + Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If `datasets` is + provided, this will be ignored. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar + type, inplace operation. See + https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. + """ + + dataset_name: str | None = field( + default=None, + metadata={"help": "Path or name of the dataset to load. If `datasets` is provided, this will be ignored."}, + ) + dataset_config: str | None = field( + default=None, + metadata={ + "help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` " + "function. If `datasets` is provided, this will be ignored." + }, + ) + dataset_train_split: str = field( + default="train", + metadata={"help": "Dataset split to use for training. If `datasets` is provided, this will be ignored."}, + ) + dataset_test_split: str = field( + default="test", + metadata={"help": "Dataset split to use for evaluation. If `datasets` is provided, this will be ignored."}, + ) + dataset_streaming: bool = field( + default=False, + metadata={ + "help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. If " + "`datasets` is provided, this will be ignored." + }, + ) + ignore_bias_buffers: bool = field( + default=False, + metadata={ + "help": "Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid " + "scalar type, inplace operation. See " + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992." + }, + ) + + +def init_zero_verbose(): + """ + Perform zero verbose init - use this method on top of the CLI modules to make logging and warning output cleaner. + Uses Rich if available, falls back otherwise. + """ + import logging + import warnings + + FORMAT = "%(message)s" + + if is_rich_available(): + from rich.logging import RichHandler + + handler = RichHandler() + else: + handler = logging.StreamHandler() + + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[handler], level=logging.ERROR) + + # Custom warning handler to redirect warnings to the logging system + def warning_handler(message, category, filename, lineno, file=None, line=None): + logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}") + + # Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well + warnings.showwarning = warning_handler + + +class TrlParser(HfArgumentParser): + """ + A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed + configurations, while also supporting configuration file loading and environment variable management. + + Args: + dataclass_types (`DataClassType | Iterable[DataClassType]`, *optional*): + Dataclass types to use for argument parsing. + **kwargs: + Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor. + + Examples: + + ```yaml + # config.yaml + env: + VAR1: value1 + arg1: 23 + ``` + + ```python + # main.py + import os + from dataclasses import dataclass + from trl import TrlParser + + + @dataclass + class MyArguments: + arg1: int + arg2: str = "alpha" + + + parser = TrlParser(dataclass_types=[MyArguments]) + training_args = parser.parse_args_and_config() + + print(training_args, os.environ.get("VAR1")) + ``` + + ```bash + $ python main.py --config config.yaml + (MyArguments(arg1=23, arg2='alpha'),) value1 + + $ python main.py --arg1 5 --arg2 beta + (MyArguments(arg1=5, arg2='beta'),) None + ``` + """ + + def __init__( + self, + dataclass_types: DataClassType | Iterable[DataClassType] | None = None, + **kwargs, + ): + # Make sure dataclass_types is an iterable + if dataclass_types is None: + dataclass_types = [] + elif not isinstance(dataclass_types, Iterable): + dataclass_types = [dataclass_types] + + # Check that none of the dataclasses have the "config" field + for dataclass_type in dataclass_types: + if "config" in dataclass_type.__dataclass_fields__: + raise ValueError( + f"Dataclass {dataclass_type.__name__} has a field named 'config'. This field is reserved for the " + f"config file path and should not be used in the dataclass." + ) + + super().__init__(dataclass_types=dataclass_types, **kwargs) + + def parse_args_and_config( + self, + args: Iterable[str] | None = None, + return_remaining_strings: bool = False, + fail_with_unknown_args: bool = True, + ) -> tuple[DataClass, ...]: + """ + Parse command-line args and config file into instances of the specified dataclass types. + + This method wraps [`transformers.HfArgumentParser.parse_args_into_dataclasses`] and also parses the config file + specified with the `--config` flag. The config file (in YAML format) provides argument values that replace the + default values in the dataclasses. Command line arguments can override values set by the config file. The + method also sets any environment variables specified in the `env` field of the config file. + """ + args = list(args) if args is not None else sys.argv[1:] + if "--config" in args: + # Get the config file path from + config_index = args.index("--config") + args.pop(config_index) # remove the --config flag + config_path = args.pop(config_index) # get the path to the config file + with open(config_path) as yaml_file: + config = yaml.safe_load(yaml_file) + + # Set the environment variables specified in the config file + if "env" in config: + env_vars = config.pop("env", {}) + if not isinstance(env_vars, dict): + raise ValueError("`env` field should be a dict in the YAML file.") + for key, value in env_vars.items(): + os.environ[key] = str(value) + + # Set the defaults from the config values + config_remaining_strings = self.set_defaults_with_config(**config) + else: + config_remaining_strings = [] + + # Parse the arguments from the command line + output = self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings) + + # Merge remaining strings from the config file with the remaining strings from the command line + if return_remaining_strings: + args_remaining_strings = output[-1] + return output[:-1] + (config_remaining_strings + args_remaining_strings,) + elif fail_with_unknown_args and config_remaining_strings: + raise ValueError( + f"Unknown arguments from config file: {config_remaining_strings}. Please remove them, add them to the " + "dataclass, or set `fail_with_unknown_args=False`." + ) + else: + return output + + def set_defaults_with_config(self, **kwargs) -> list[str]: + """ + Overrides the parser's default values with those provided via keyword arguments, including for subparsers. + + Any argument with an updated default will also be marked as not required if it was previously required. + + Returns a list of strings that were not consumed by the parser. + """ + + def apply_defaults(parser, kw): + used_keys = set() + for action in parser._actions: + # Handle subparsers recursively + if isinstance(action, argparse._SubParsersAction): + for subparser in action.choices.values(): + used_keys.update(apply_defaults(subparser, kw)) + elif action.dest in kw: + action.default = kw[action.dest] + action.required = False + used_keys.add(action.dest) + return used_keys + + used_keys = apply_defaults(self, kwargs) + # Remaining args not consumed by the parser + remaining = [ + item for key, value in kwargs.items() if key not in used_keys for item in (f"--{key}", str(value)) + ] + return remaining + + +def get_git_commit_hash(package_name): + try: + # Import the package to locate its path + package = importlib.import_module(package_name) + # Get the path to the package using inspect + package_path = os.path.dirname(inspect.getfile(package)) + + # Navigate up to the Git repository root if the package is inside a subdirectory + git_repo_path = os.path.abspath(os.path.join(package_path, "..")) + git_dir = os.path.join(git_repo_path, ".git") + + if os.path.isdir(git_dir): + # Run the git command to get the current commit hash + commit_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=git_repo_path).strip().decode("utf-8") + ) + return commit_hash + else: + return None + except Exception as e: + return f"Error: {str(e)}" + + +def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict: + """ + Load a mixture of datasets based on the configuration. + + Args: + mixture_config ([`DatasetMixtureConfig`]): + Script arguments containing dataset configuration. + + Returns: + [`~datasets.DatasetDict`]: + Combined dataset(s) from the mixture configuration, with optional train/test split if `test_split_size` is + set. + + Example: + ```python + from trl import DatasetMixtureConfig, get_dataset + from trl.scripts.utils import DatasetConfig + + mixture_config = DatasetMixtureConfig(datasets=[DatasetConfig(path="trl-lib/tldr")]) + dataset = get_dataset(mixture_config) + print(dataset) + ``` + + ``` + DatasetDict({ + train: Dataset({ + features: ['prompt', 'completion'], + num_rows: 116722 + }) + }) + ``` + """ + logger.info(f"Creating dataset mixture with {len(mixture_config.datasets)} datasets") + datasets_list = [] + for dataset_config in mixture_config.datasets: + logger.info(f"Loading dataset for mixture: {dataset_config.path} (config name: {dataset_config.name})") + dataset = datasets.load_dataset( + path=dataset_config.path, + name=dataset_config.name, + data_dir=dataset_config.data_dir, + data_files=dataset_config.data_files, + split=dataset_config.split, + streaming=mixture_config.streaming, + ) + if dataset_config.columns is not None: + dataset = dataset.select_columns(dataset_config.columns) + datasets_list.append(dataset) + + if datasets_list: + combined_dataset = concatenate_datasets(datasets_list) + if isinstance(combined_dataset, datasets.Dataset): # IterableDataset does not have a length + logger.info(f"Created dataset mixture with {len(combined_dataset)} examples") + + if mixture_config.test_split_size is not None: + logger.info(f"Splitting dataset into train and test sets with test size: {mixture_config.test_split_size}") + combined_dataset = combined_dataset.train_test_split(test_size=mixture_config.test_split_size) + return combined_dataset + else: + return DatasetDict({"train": combined_dataset}) + else: + raise ValueError("No datasets were loaded from the mixture configuration") diff --git a/ICL/RL/trl_source/trl/scripts/vllm_serve.py b/ICL/RL/trl_source/trl/scripts/vllm_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..dc874df437056fd55a281146c83f4952aadb1ffa --- /dev/null +++ b/ICL/RL/trl_source/trl/scripts/vllm_serve.py @@ -0,0 +1,1285 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import base64 +import json +import logging +import os +import re +import time +import uuid +from collections.abc import Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from io import BytesIO +from itertools import chain +from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection + +import torch +import torch.distributed.distributed_c10d as c10d +from packaging.version import Version +from transformers import AutoTokenizer, is_torch_xpu_available, is_vision_available + +from trl import TrlParser +from trl.generation.vllm_generation import sanitize_logprob +from trl.import_utils import ( + is_fastapi_available, + is_pydantic_available, + is_uvicorn_available, + is_vllm_ascend_available, + is_vllm_available, +) + + +if is_fastapi_available(): + from fastapi import FastAPI + + +if is_pydantic_available(): + from pydantic import BaseModel + + +if is_uvicorn_available(): + import uvicorn + + +if is_vision_available(): + from PIL import Image + + +if is_vllm_available(): + import vllm + from vllm import LLM, SamplingParams + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.parallel_state import get_world_group + from vllm.distributed.utils import StatelessProcessGroup + + if Version(vllm.__version__) <= Version("0.11.0"): + from vllm.utils import get_open_port + else: + from vllm.utils.network_utils import get_open_port + + if Version(vllm.__version__) <= Version("0.10.2"): + from vllm.sampling_params import GuidedDecodingParams + else: + from vllm.sampling_params import StructuredOutputsParams + + if is_vllm_ascend_available(): + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator + + +logger = logging.getLogger(__name__) + +# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following +# error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use +# the 'spawn' start method +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +class WeightSyncWorkerExtension: + """ + A vLLM worker extension that enables weight synchronization between a client and multiple server workers. + + This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` or + `ProcessGroupXCCL` to handle efficient GPU-based communication using NCCL. The primary purpose of this class is to + receive updated model weights from a client process and distribute them to all worker processes participating in + model inference. + """ + + # The following attributes are initialized when `init_communicator` method is called. + communicator = None # Communicator for weight updates + client_rank = None # Source rank for broadcasting updated weights + + def init_communicator(self, host: str, port: int, world_size: int, client_device_uuid: str) -> None: + """ + Initializes the weight update communicator using a stateless process group. + + This method creates a `StatelessProcessGroup` that allows external training processes to communicate with vLLM + workers without interfering with the global torch distributed group. + + Args: + host (`str`): + Hostname or IP address of the master node. + port (`int`): + Port number to be used for communication. + world_size (`int`): + Total number of participating processes in the update group. + client_device_uuid (`str`): + UUID of the device of client main process. Used to assert that devices are different from vllm workers + devices. + """ + if self.communicator is not None: + raise RuntimeError("Weight update group already initialized. Call close_communicator first.") + + # TODO: will remove after torch xpu 2.9 support uuid in get_device_properties + if torch.cuda.is_available() or ( + is_torch_xpu_available() and hasattr(torch.xpu.get_device_properties(self.device), "uuid") + ): + accelerator_module = torch.xpu if is_torch_xpu_available() else torch.cuda + if client_device_uuid == str(accelerator_module.get_device_properties(self.device).uuid): + raise RuntimeError( + f"Attempting to use the same CUDA device (UUID: {client_device_uuid}) for multiple distinct " + "roles/ranks within the same communicator. This setup is unsupported and will likely lead to program " + "hangs or incorrect behavior. Ensure that trainer is using different devices than vLLM server." + ) + # Get the rank of the current worker in the global world group. + rank = get_world_group().rank + + if is_torch_xpu_available(): + store = torch.distributed.TCPStore(host_name=host, port=port, world_size=world_size, is_master=(rank == 0)) + prefixed_store = c10d.PrefixStore("client2server", store) + xccl_options = c10d.ProcessGroupXCCL.Options() + pg = c10d.ProcessGroupXCCL( + store=prefixed_store, + rank=rank, + size=world_size, + options=xccl_options, + ) + self.communicator = pg + else: + # Create a stateless process group to manage communication between training processes and vLLM workers. + # Initialize the NCCL-based communicator for weight synchronization. + pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size) + self.communicator = PyNcclCommunicator(pg, device=self.device) + + # The client process that sends updated weights has the highest rank (world_size - 1). + self.client_rank = world_size - 1 + + def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None: + """ + Receives updated weights from the client process and updates the named parameter in the model. + + Args: + name (`str`): + Name of the weight tensor being updated. + dtype (`str`): + Data type of the weight tensor as a string (e.g., `"torch.float32"`). + shape (`Sequence[int]`): + Shape of the weight tensor. + """ + if self.communicator is None: + raise RuntimeError("Communicator not initialized. Call `init_communicator` first.") + + dtype = getattr(torch, dtype.split(".")[-1]) + # Allocate memory for the incoming weight tensor on the correct device. + weight = torch.empty(shape, dtype=dtype, device=self.device) + + if is_torch_xpu_available(): + # Use XCCL to broadcast the updated weights from the client (src) to all workers. + self.communicator.broadcast(weight, root=self.client_rank) + self.communicator.barrier() + else: + # Use NCCL to broadcast the updated weights from the client (src) to all workers. + self.communicator.broadcast(weight, src=self.client_rank) + self.communicator.group.barrier() + + # Load the received weights into the model. + self.model_runner.model.load_weights(weights=[(name, weight)]) + + def close_communicator(self) -> None: + """ + Closes the communicator when weight synchronization is no longer needed. + + This method deletes the NCCL communicator to release associated resources. + """ + + if self.communicator is not None: + del self.communicator + self.communicator = None # Ensure attribute is reset to None + self.client_rank = None # Ensure attribute is reset to None + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + model (`str`): + Model name or path to load the model from. + revision (`str`, *optional*): + Revision to use for the model. If not specified, the default branch will be used. + tensor_parallel_size (`int`, *optional*, defaults to `1`): + Number of tensor parallel workers to use. + data_parallel_size (`int`, *optional*, defaults to `1`): + Number of data parallel workers to use. + host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host address to run the server on. + port (`int`, *optional*, defaults to `8000`): + Port to run the server on. + gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the + device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus + improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors + during initialization. + dtype (`str`, *optional*, defaults to `"auto"`): + Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined + based on the model configuration. Find the supported values in the vLLM documentation. + max_model_len (`int`, *optional*): + If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced + `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model + context size, which might be much larger than the KV cache, leading to inefficiencies. + enable_prefix_caching (`bool`, *optional*): + Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support + this feature. + enforce_eager (`bool`, *optional*, defaults to `False`): + Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the + model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use + the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model + implementation. + kv_cache_dtype (`str`, *optional*, defaults to `"auto"`): + Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type. + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether to trust remote code when loading models. Set to `True` to allow executing code from model + repositories. This is required for some custom models but introduces security risks. + log_level (`str`, *optional*, defaults to `"info"`): + Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`, + `"trace"`. + """ + + model: str = field( + metadata={"help": "Model name or path to load the model from."}, + ) + revision: str | None = field( + default=None, + metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."}, + ) + tensor_parallel_size: int = field( + default=1, + metadata={"help": "Number of tensor parallel workers to use."}, + ) + data_parallel_size: int = field( + default=1, + metadata={"help": "Number of data parallel workers to use."}, + ) + host: str = field( + default="0.0.0.0", + metadata={"help": "Host address to run the server on."}, + ) + port: int = field( + default=8000, + metadata={"help": "Port to run the server on."}, + ) + gpu_memory_utilization: float = field( + default=0.9, + metadata={ + "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " + "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " + "size and thus improve the model's throughput. However, if the value is too high, it may cause " + "out-of-memory (OOM) errors during initialization." + }, + ) + dtype: str = field( + default="auto", + metadata={ + "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically " + "determined based on the model configuration. Find the supported values in the vLLM documentation." + }, + ) + max_model_len: int | None = field( + default=None, + metadata={ + "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced " + "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model " + "context size, which might be much larger than the KV cache, leading to inefficiencies." + }, + ) + enable_prefix_caching: bool | None = field( + default=None, + metadata={ + "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the " + "hardware support this feature." + }, + ) + enforce_eager: bool | None = field( + default=False, + metadata={ + "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always " + "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager " + "execution in hybrid." + }, + ) + kv_cache_dtype: str = field( + default="auto", + metadata={ + "help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type." + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether to trust remote code when loading models. Set to True to allow executing code from model " + "repositories. This is required for some custom models but introduces security risks." + }, + ) + log_level: str = field( + default="info", + metadata={ + "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', " + "'trace'." + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={ + "help": "Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: " + "Use the `transformers` backend for model implementation. `vllm`: Use the `vllm` library for " + "model implementation." + }, + ) + + +def llm_worker( + script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection +) -> None: + # Set required environment variables for DP to work with vLLM + os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) + os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) + os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) + + llm = LLM( + model=script_args.model, + revision=script_args.revision, + tensor_parallel_size=script_args.tensor_parallel_size, + gpu_memory_utilization=script_args.gpu_memory_utilization, + enforce_eager=script_args.enforce_eager, + dtype=script_args.dtype, + # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can + # directly reuse the KV cache if it shares the same prefix with one of the existing queries. + # This is particularly useful here because we generate completions from the same prompts. + enable_prefix_caching=script_args.enable_prefix_caching, + kv_cache_dtype=script_args.kv_cache_dtype, + max_model_len=script_args.max_model_len, + worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension", + trust_remote_code=script_args.trust_remote_code, + model_impl=script_args.vllm_model_impl, + # Important so temperature scaling/logit tweaking affects the TIS log probs + logprobs_mode="processed_logprobs", + ) + + # Send ready signal to parent process + connection.send({"status": "ready"}) + + while True: + # Wait for commands from the parent process + try: + command = connection.recv() + except KeyboardInterrupt: + llm.collective_rpc(method="close_communicator") + break + + # Handle commands + if command["type"] in ["call", "fire_and_forget"]: + method_name = command["method"] + args, kwargs = command.get("args", ()), command.get("kwargs", {}) + method = getattr(llm, method_name) + + try: + result = method(*args, **kwargs) + except ValueError as e: + error_msg = str(e) + if "longer than the maximum model length" in error_msg or "context length" in error_msg: + logger.error(f"[Worker] Context length exceeded: {error_msg}") + if method_name in ["generate", "chat"]: + result = [] + else: + raise + else: + raise + except Exception as e: + logger.error(f"[Worker] Unexpected error in {method_name}: {e}") + raise + + if command["type"] == "call": + connection.send(result) + elif command["type"] == "shutdown": + break + + +def chunk_list(lst: list, n: int) -> list[list]: + """ + Split list `lst` into `n` evenly distributed sublists. + + Example: + ```python + >>> chunk_list([1, 2, 3, 4, 5, 6], 2) + [[1, 2, 3], [4, 5, 6]] + + >>> chunk_list([1, 2, 3, 4, 5, 6], 4) + [[1, 2], [3, 4], [5], [6]] + + >>> chunk_list([1, 2, 3, 4, 5, 6], 8) + [[1], [2], [3], [4], [5], [6], [], []] + ``` + """ + k, r = divmod(len(lst), n) + return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)] + + +def _replace_prefix_tokens( + tokenizer, + model_prefix_token_ids: list[int], + template_prefix_token_ids: list[int], + template_token_ids: list[int], +) -> list[int]: + """ + This function is for fixing up the chat template-tokenized messages history to match the model output tokenization + up to the last assistant turn, in order to preserve the monotonic tokens property for optimized multi-turn + training. + + RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is + basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single + trajectory, model generations in previous model calls may be re-tokenized to something that is different than what + was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model + produces are different enough for the differently re-tokenized generation result that it causes the training to be + off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause + unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, + which is strange during training. + + There are real cases where the model output string _does not match_ the chat template tokenization of the parsed + model output. A concrete example is inconsistent whitespace tokens around tool call special tokens. + + Based on NeMo RL's _replace_prefix_tokens: + https://github.com/NVIDIA-NeMo/RL/blob/748b9caff4e6d672b8a98a10b6e612d028cfc96b/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + """ + if not model_prefix_token_ids: + return template_token_ids + + eos_token_id = tokenizer.eos_token_id + if eos_token_id is None: + logger.warning("Tokenizer has no EOS token ID, cannot apply _replace_prefix_tokens") + return template_token_ids + + model_cut_end = len(model_prefix_token_ids) + if model_prefix_token_ids and model_prefix_token_ids[-1] == eos_token_id: + model_cut_end -= 1 + + # We take everything starting with the EOS token ID. + template_cut_start = -1 + for pos in reversed(range(len(template_prefix_token_ids))): + if template_token_ids[pos] == eos_token_id: + template_cut_start = pos + break + + # This should never be the case, but + if template_cut_start < 0: + logger.warning("No EOS token found in template prefix, cannot apply _replace_prefix_tokens") + return template_token_ids + + result = model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:] + + return result + + +def main(script_args: ScriptArguments): + if not is_fastapi_available(): + raise ImportError( + "FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`." + ) + + if not is_pydantic_available(): + raise ImportError( + "Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`." + ) + + if not is_uvicorn_available(): + raise ImportError( + "Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`." + ) + + if not is_vllm_available(): + raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.") + + # Spawn dp workers, and setup pipes for communication + master_port = get_open_port() + connections = [] + processes = [] + for data_parallel_rank in range(script_args.data_parallel_size): + parent_connection, child_connection = Pipe() + process = Process(target=llm_worker, args=(script_args, data_parallel_rank, master_port, child_connection)) + process.start() + connections.append(parent_connection) + processes.append(process) + + @asynccontextmanager + async def lifespan(app: FastAPI): + logger.info(f"Loading tokenizer for {script_args.model}...") + app.state.tokenizer = AutoTokenizer.from_pretrained( + script_args.model, trust_remote_code=script_args.trust_remote_code + ) + + # Wait for all workers to send "ready" + ready_connections = set() + while len(ready_connections) < script_args.data_parallel_size: + for connection in connections: + msg = connection.recv() + if isinstance(msg, dict) and msg.get("status") == "ready": + ready_connections.add(connection) + + yield + + # Wait for processes to terminate + for process in processes: + process.join(timeout=10) # Wait for 10 seconds for the process to terminate + if process.is_alive(): + logger.warning(f"Process {process} is still alive after 10 seconds, attempting to terminate...") + process.terminate() + process.join() # ensure process termination after calling terminate() + + app = FastAPI(lifespan=lifespan) + + # Define the endpoints for the model server + @app.get("/health/") + async def health(): + """ + Health check endpoint to verify that the server is running. + """ + return {"status": "ok"} + + @app.get("/get_world_size/") + async def get_world_size(): + """ + Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`. + + Returns: + `dict`: + A dictionary containing the world size. + + Example response: + ```json + {"world_size": 8} + ``` + """ + return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size} + + class GenerateRequest(BaseModel): + prompts: list[str] + images: list[str] | None = None + n: int = 1 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + max_tokens: int = 16 + truncate_prompt_tokens: int | None = None + structured_outputs_regex: str | None = None + generation_kwargs: dict = field(default_factory=dict) + + class GenerateResponse(BaseModel): + prompt_ids: list[list[int]] + completion_ids: list[list[int]] + logprobs: list[list[float]] + + @app.post("/generate/", response_model=GenerateResponse) + async def generate(request: GenerateRequest): + """ + Generates completions for the provided prompts. + + Args: + request (`GenerateRequest`): + - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions. + - `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process + along with prompts. + - `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. + - `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during + generation. + - `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead + to more random outputs. + - `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the + diversity of the generated text. + - `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables + top-k sampling. + - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling. + - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each + completion. + - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported + by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left + truncation). If set to `None`, truncation is disabled. + - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, + the model will only generate tokens that match this regex pattern. + - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM + `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains + keys that conflict with the other parameters, they will override them. + + Returns: + `GenerateResponse`: + - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt. + - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion. + - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the + generated completions. + + Example request: + ```json + {"prompts": ["Hello world", "What is AI?"]} + ``` + + Example response: + ```json + { + "prompt_ids": [[101, 102], [201, 202]], + "completion_ids": [[103, 104, 105], [203, 204, 205]], + "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]] + } + ``` + """ + request.images = request.images or [None] * len(request.prompts) + + prompts = [] + for prompt, image in zip(request.prompts, request.images, strict=True): + row = {"prompt": prompt} + if image is not None: + row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} + prompts.append(row) + + # Structured outputs, if enabled + if Version(vllm.__version__) <= Version("0.10.2"): + structured_outputs_key = "guided_decoding" + if request.structured_outputs_regex is not None: + structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None + else: + structured_outputs_key = "structured_outputs" + if request.structured_outputs_regex is not None: + structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None + + generation_kwargs = { + "n": request.n, + "repetition_penalty": request.repetition_penalty, + "temperature": request.temperature, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "max_tokens": request.max_tokens, + "truncate_prompt_tokens": request.truncate_prompt_tokens, + "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only + } + generation_kwargs[structured_outputs_key] = structured_outputs + generation_kwargs.update(request.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + + # Evenly distribute prompts across DP ranks + chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) + + # Send the prompts to each worker + for connection, prompts in zip(connections, chunked_prompts, strict=True): + # When the number of prompts is less than data_parallel_size, some workers will receive empty prompts. + # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply + # with vLLM's requirement, and we later ignore the result. + if not prompts: + prompts = [""] + kwargs = {"prompts": prompts, "sampling_params": sampling_params} + connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) + + # Receive results + all_outputs = [connection.recv() for connection in connections] + + # Handle empty prompts (see above) + all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts, strict=True) if prompts] + + # Flatten and combine all results + all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list + prompt_ids = [output.prompt_token_ids for output in all_outputs] + completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs] + logprobs: list[list[float]] = [ + [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] + return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs} + + class ChatRequest(BaseModel): + messages: list[list[dict]] + n: int = 1 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + max_tokens: int = 16 + truncate_prompt_tokens: int | None = None + structured_outputs_regex: str | None = None + generation_kwargs: dict = field(default_factory=dict) + chat_template_kwargs: dict = field(default_factory=dict) + tools: list[dict] | None = None + + class ChatResponse(BaseModel): + prompt_ids: list[list[int]] + completion_ids: list[list[int]] + logprobs: list[list[float]] + + @app.post("/chat/", response_model=ChatResponse) + async def chat(request: ChatRequest): + """ + Generates completions for the provided chat messages. + + Args: + request (`ChatRequest`): + - `messages` (list of `dict`): A list of messages (dicts with "role" and "content" keys) for the model + to generate completions. + - `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. + - `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during + generation. + - `temperature` (`float`, *optional*, defaults to `1.0`): Temperature for sampling. Higher values lead + to more random outputs. + - `top_p` (`float`, *optional*, defaults to `1.0`): Top-p (nucleus) sampling parameter. It controls the + diversity of the generated text. + - `top_k` (`int`, *optional*, defaults to `-1`): Top-k sampling parameter. If set to `-1`, it disables + top-k sampling. + - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling. + - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each + completion. + - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported + by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left + truncation). If set to `None`, truncation is disabled. + - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, + the model will only generate tokens that match this regex pattern. + - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM + `SamplingParams`. This can include parameters like `seed`, `frequency_penalty`, etc. If it contains + keys that conflict with the other parameters, they will override them. + - `chat_template_kwargs` (`dict`, *optional*): Additional keyword arguments to pass to the chat + template. + + Returns: + `ChatResponse`: + - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt. + - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion. + - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the + generated completions. + + Example request: + ```bash + curl -X POST 'http://0.0.0.0:8000/chat/' \ + -H 'Content-Type: application/json' \ + -d '{"messages": [[{ "role": "user", "content": "Hello!" }]]}' + ``` + + Example response: + ```json + { + "prompt_ids": [[151644, 872, 198, 9707, 0, 151645, 198, 151644, 77091, 198]], + "completion_ids":[[151667, 198, 32313, 11, 279, 1196, 1101, 1053, 330, 9707, 8958, 773, 358, 1184, 311, 5889]], + "logprobs": [[-0.00029404606902971864, -3.576278118089249e-07, -0.09024181962013245, -6.389413465512916e-05, -0.038671817630529404, -0.00013314791431184858, -0.5868351459503174, -0.09682723134756088, -0.06609706580638885, -0.00023803261865396053, -0.02242819033563137, -0.8185162544250488, -0.04954879730939865, -0.3169460594654083, -4.887569048150908e-06, -0.006023705471307039]] + } + ``` + """ + # Convert PIL images to base64 strings + for message_list in request.messages: + for message in message_list: + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image_pil": + part["image_pil"] = Image.open(BytesIO(base64.b64decode(part["image_pil"]))) + + # Structured outputs, if enabled + if Version(vllm.__version__) <= Version("0.10.2"): + structured_outputs_key = "guided_decoding" + if request.structured_outputs_regex is not None: + structured_outputs = GuidedDecodingParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None + else: + structured_outputs_key = "structured_outputs" + if request.structured_outputs_regex is not None: + structured_outputs = StructuredOutputsParams(regex=request.structured_outputs_regex) + else: + structured_outputs = None + + generation_kwargs = { + "n": request.n, + "repetition_penalty": request.repetition_penalty, + "temperature": request.temperature, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "max_tokens": request.max_tokens, + "truncate_prompt_tokens": request.truncate_prompt_tokens, + "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only + } + generation_kwargs[structured_outputs_key] = structured_outputs + generation_kwargs.update(request.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + + # Evenly distribute prompts across DP ranks + chunked_messages = chunk_list(request.messages, script_args.data_parallel_size) + + # Send the messages to each worker + for connection, messages in zip(connections, chunked_messages, strict=True): + # When the number of messages is less than data_parallel_size, some workers will receive empty messages. + # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply + # with vLLM's requirement, and we later ignore the result. + if not messages: + messages = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": messages, + "sampling_params": sampling_params, + "chat_template_kwargs": request.chat_template_kwargs, + "tools": request.tools if request.tools else None, + } + + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) + + # Receive results + all_outputs = [connection.recv() for connection in connections] + + # Handle empty prompts (see above) + all_outputs = [output for output, prompts in zip(all_outputs, chunked_messages, strict=True) if prompts] + + # Flatten and combine all results + all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list + prompt_ids = [output.prompt_token_ids for output in all_outputs] + completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs] + logprobs: list[list[float]] = [ + [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] + return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs} + + class InitCommunicatorRequest(BaseModel): + host: str + port: int + world_size: int + client_device_uuid: str + + @app.post("/init_communicator/") + async def init_communicator(request: InitCommunicatorRequest): + """ + Initializes the communicator for synchronizing model weights between a client and multiple server workers. + + Args: + request (`InitCommunicatorRequest`): + - `host` (`str`): Hostname or IP address of the master node. + - `port` (`int`): Port number to be used for communication. + - `world_size` (`int`): Total number of participating processes in the group. + - `client_device_uuid` (`str`): UUID of the device of client main process. Used to assert that devices + are different from vLLM workers devices. + """ + world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1 + + # The function init_communicator is called this way: init_communicator(host, port, world_size) + # So with collective_rpc we need to call it this way: + # llm.collective_rpc(method="init_communicator", args=(host, port, world_size)) + kwargs = { + "method": "init_communicator", + "args": (request.host, request.port, world_size, request.client_device_uuid), + } + for connection in connections: + connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) + + return {"message": "Request received, initializing communicator"} + + class UpdateWeightsRequest(BaseModel): + name: str + dtype: str + shape: list[int] + + @app.post("/update_named_param/") + async def update_named_param(request: UpdateWeightsRequest): + """ + Updates the model weights with the provided tensor. + + Once this endpoint is called, the client process should broadcast the updated weights to all server workers. + + Args: + request (`UpdateWeightsRequest`): + - `name` (`str`): Name of the weight tensor being updated. + - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`). + - `shape` (list of `int`): Shape of the weight + + """ + # The function update_named_param is called this way: update_named_param("name", "torch.float32", (10, 10)) + # So with collective_rpc we need to call it this way: + # llm.collective_rpc("update_named_param", args=("name", "torch.float32", (10, 10))) + kwargs = {"method": "update_named_param", "args": (request.name, request.dtype, tuple(request.shape))} + for connection in connections: + connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) + + return {"message": "Request received, updating named parameter"} + + @app.post("/reset_prefix_cache/") + async def reset_prefix_cache(): + """ + Resets the prefix cache for the model. + """ + for connection in connections: + connection.send({"type": "call", "method": "reset_prefix_cache"}) + # Wait for and collect all results + all_outputs = [connection.recv() for connection in connections] + success = all(output for output in all_outputs) + return {"message": "Request received, resetting prefix cache status: " + str(success)} + + @app.post("/close_communicator/") + async def close_communicator(): + """ + Closes the weight update group and cleans up associated resources. + """ + kwargs = {"method": "close_communicator"} + for connection in connections: + connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) + return {"message": "Request received, closing communicator"} + + class ChatCompletionRequest(BaseModel): + messages: list[dict] + model: str | None = None + temperature: float = 1.0 + top_p: float = 1.0 + max_completion_tokens: int | None = None + max_tokens: int | None = None + n: int = 1 + stop: str | list[str] | None = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + logprobs: bool = False + top_logprobs: int | None = None + tools: list[dict] | None = None + tool_choice: str | dict = "auto" + parallel_tool_calls: bool = True + + @app.post("/v1/chat/completions") + async def chat_completions(request: ChatCompletionRequest): + completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created_at = int(time.time()) + + messages = [] + for msg in request.messages: + role = msg.get("role", "") + if role not in ["system", "user", "assistant", "tool"]: + logger.warning(f"Unknown message role: {role}") + messages.append(msg) + + max_tokens = request.max_completion_tokens or request.max_tokens or 512 + + sampling_kwargs = { + "n": request.n, + "temperature": request.temperature, + "top_p": request.top_p, + "max_tokens": max_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "stop": request.stop, + } + + if request.logprobs or request.top_logprobs: + sampling_kwargs["logprobs"] = request.top_logprobs if request.top_logprobs else 1 + + sampling_params = SamplingParams(**sampling_kwargs) + + chat_template_kwargs = {} + if request.tool_choice and request.tool_choice != "auto": + chat_template_kwargs["tool_choice"] = request.tool_choice + + has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) + + if has_prefix_token_ids: + # do on policy token id correction and call generate instead of chat + # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html + # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 + tokenizer = app.state.tokenizer + + # preprocess full conversation + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages], + "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, + "add_generation_prompt": True, + }, + } + ) + template_prompts = connections[0].recv() + template_prompt = template_prompts[0] + + # extract model prefix tokens from last assistant message + model_prefix_tokens = None + last_assistant_idx = None + for i in reversed(range(len(messages))): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + if "prompt_token_ids" in messages[i]: + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( + "generation_token_ids", [] + ) + break + + if model_prefix_tokens and last_assistant_idx is not None: + messages_to_last_assistant = messages[: last_assistant_idx + 1] + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages_to_last_assistant], + "chat_template_kwargs": chat_template_kwargs, + "tools": request.tools, + "add_generation_prompt": False, + }, + } + ) + template_prefix_prompts = connections[0].recv() + template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] + + corrected_token_ids = _replace_prefix_tokens( + tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] + ) + + else: + corrected_token_ids = template_prompt["prompt_token_ids"] + + corrected_prompt = {"prompt_token_ids": corrected_token_ids} + chunked_prompts = chunk_list([corrected_prompt], script_args.data_parallel_size) + + for connection, prompts in zip(connections, chunked_prompts, strict=True): + if not prompts: + prompts = [{"prompt_token_ids": [tokenizer.eos_token_id]}] + connection.send( + { + "type": "call", + "method": "generate", + "kwargs": {"prompts": prompts, "sampling_params": sampling_params}, + } + ) + else: + # no prefix token IDs, use chat() + chunked_messages = chunk_list([messages], script_args.data_parallel_size) + + for connection, message_chunk in zip(connections, chunked_messages, strict=True): + if not message_chunk: + message_chunk = [[{"role": "user", "content": ""}]] + kwargs = { + "messages": message_chunk, + "sampling_params": sampling_params, + "tools": request.tools, + "chat_template_kwargs": chat_template_kwargs, + } + connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) + + all_outputs = [connection.recv() for connection in connections] + if has_prefix_token_ids: + all_outputs = [ + output for output, prompt_chunk in zip(all_outputs, chunked_prompts, strict=True) if prompt_chunk + ] + else: + all_outputs = [ + output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk + ] + all_outputs = list(chain.from_iterable(all_outputs)) + + if not all_outputs: + return { + "id": completion_id, + "object": "chat.completion", + "created": created_at, + "model": request.model or script_args.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": ""}, + "finish_reason": "length", + "logprobs": None, + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + choices = [] + total_input_tokens = 0 + total_output_tokens = 0 + + idx = 0 + for output in all_outputs: + total_input_tokens += len(output.prompt_token_ids) + + for gen_output in output.outputs: + total_output_tokens += len(gen_output.token_ids) + text = gen_output.text if hasattr(gen_output, "text") else "" + + tool_calls = None + finish_reason = gen_output.finish_reason if hasattr(gen_output, "finish_reason") else "stop" + + # Manual XML-json tool call parsing + if request.tools and text: + pattern = r"(.*?)" + matches = re.findall(pattern, text, re.DOTALL) + if matches: + tool_calls = [] + for match in matches: + try: + data = json.loads(match.strip()) + tool_calls.append( + { + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": data.get("name", ""), + "arguments": json.dumps(data.get("arguments", {})), + }, + } + ) + except json.JSONDecodeError: + continue + if tool_calls: + finish_reason = "tool_calls" + text = re.sub(pattern, "", text, flags=re.DOTALL).strip() + + if not request.parallel_tool_calls and tool_calls and len(tool_calls) > 1: + tool_calls = [tool_calls[0]] + + logprobs_data = None + if request.logprobs and hasattr(gen_output, "logprobs") and gen_output.logprobs: + logprobs_data = { + "content": [ + { + "token": str(token_id), + "logprob": float(list(logprob_dict.values())[0].logprob) if logprob_dict else 0.0, + "bytes": None, + "top_logprobs": [], + } + for token_id, logprob_dict in zip(gen_output.token_ids, gen_output.logprobs, strict=False) + ] + } + + choices.append( + { + "index": idx, + "message": { + "role": "assistant", + "content": text if not tool_calls else None, + "tool_calls": tool_calls, + }, + "logprobs": logprobs_data, + "finish_reason": finish_reason, + } + ) + idx += 1 + + return { + "id": completion_id, + "object": "chat.completion", + "created": created_at, + "model": request.model or script_args.model, + "choices": choices, + "usage": { + "prompt_tokens": total_input_tokens, + "completion_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens, + }, + } + + class TokenizeRequest(BaseModel): + model: str | None = None + messages: list[dict] + tools: list[dict] | None = None + + @app.post("/tokenize") + async def tokenize(request: TokenizeRequest): + messages = request.messages + + has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) + + kwargs = { + "messages": [messages], + "tools": request.tools, + "add_generation_prompt": True, + "chat_template_kwargs": {}, + } + + connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": kwargs}) + preprocessed_prompts = connections[0].recv() + + if preprocessed_prompts and len(preprocessed_prompts) > 1: + logger.warning( + "More than one tokenized message returned from preprocess_chat inside tokenize, double check results!" + ) + + if not preprocessed_prompts or len(preprocessed_prompts) == 0: + return {"tokens": [], "model": request.model or script_args.model} + + template_prompt = preprocessed_prompts[0] + result_tokens = template_prompt["prompt_token_ids"] + + if has_prefix_token_ids: + tokenizer = app.state.tokenizer + + # Extract model prefix tokens from last assistant message + model_prefix_tokens = None + last_assistant_idx = None + for i in reversed(range(len(messages))): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + if "prompt_token_ids" in messages[i]: + model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( + "generation_token_ids", [] + ) + break + + if model_prefix_tokens and last_assistant_idx is not None: + # Preprocess up to last assistant + messages_to_last_assistant = messages[: last_assistant_idx + 1] + connections[0].send( + { + "type": "call", + "method": "preprocess_chat", + "kwargs": { + "messages": [messages_to_last_assistant], + "tools": request.tools, + "add_generation_prompt": False, + "chat_template_kwargs": {}, + }, + } + ) + template_prefix_prompts = connections[0].recv() + template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] + + result_tokens = _replace_prefix_tokens( + tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] + ) + + return {"tokens": result_tokens, "model": request.model or script_args.model} + + # Start the server + uvicorn.run( + app, + host=script_args.host, + port=script_args.port, + log_level=script_args.log_level, + limit_concurrency=256, + backlog=4096, + timeout_keep_alive=600, + ) + + +def make_parser(subparsers: argparse._SubParsersAction | None = None): + if subparsers is not None: + parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments) + else: + parser = TrlParser(ScriptArguments) + return parser + + +if __name__ == "__main__": + parser = make_parser() + (script_args,) = parser.parse_args_and_config() + main(script_args) diff --git a/ICL/RL/trl_source/trl/templates/completions_dataset_card.md b/ICL/RL/trl_source/trl/templates/completions_dataset_card.md new file mode 100644 index 0000000000000000000000000000000000000000..352fdd78c7e959301ea0a04412924ee5685781e8 --- /dev/null +++ b/ICL/RL/trl_source/trl/templates/completions_dataset_card.md @@ -0,0 +1,40 @@ +--- +{{ card_data }} +--- + +# TRL Completion logs + +This dataset contains the completions generated during training using `trl`. + +{% if hub_model_id %} +Find the trained model at https://huggingface.co/{{ hub_model_id }}. + +{% endif %} +The completions are stored in parquet files, and each file contains the completions for a single step of training (depending on the `logging_steps` argument). + +Each file contains the following columns: + +- `step`: the step of training +- `prompt`: the prompt used to generate the completion +- `completion`: the completion generated by the model +- ``: the reward(s) assigned to the completion by the reward function(s) used during training +- `advantage`: the computed advantage for the completion + +Having this data stored as a simple parquet file makes it easy to load and analyze using the Datasets Viewer, Polars, Pandas, etc. + +You can load the dataset using the `datasets` library: + +```python +import datasets + +dataset = datasets.load_dataset("{{ repo_id }}") +``` + +You can also load the dataset using Polars: + +```python +import polars as pl + +# Login using e.g. `huggingface-cli login` to access this dataset if it's private +df = pl.read_parquet(f"hf://datasets/{{ repo_id }}/*.parquet") +``` diff --git a/ICL/RL/trl_source/trl/templates/lm_model_card.md b/ICL/RL/trl_source/trl/templates/lm_model_card.md new file mode 100644 index 0000000000000000000000000000000000000000..e0a21fa8eac25bde0f9606e636617331dbee4af6 --- /dev/null +++ b/ICL/RL/trl_source/trl/templates/lm_model_card.md @@ -0,0 +1,54 @@ +--- +{{ card_data }} +--- + +# Model Card for {{ model_name }} + +This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" +generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda") +output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] +print(output["generated_text"]) +``` + +## Training procedure + +{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} +{% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} + +This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. + +### Framework versions + +- TRL: {{ trl_version }} +- Transformers: {{ transformers_version }} +- Pytorch: {{ pytorch_version }} +- Datasets: {{ datasets_version }} +- Tokenizers: {{ tokenizers_version }} + +## Citations + +{% if trainer_citation %}Cite {{ trainer_name }} as: + +```bibtex +{{ trainer_citation }} +```{% endif %} + +Cite TRL as: + +```bibtex +{% raw %}@software{vonwerra2020trl, + title = {{TRL: Transformers Reinforcement Learning}}, + author = {von Werra, Leandro and Belkada, Younes and Tunstall, Lewis and Beeching, Edward and Thrush, Tristan and Lambert, Nathan and Huang, Shengyi and Rasul, Kashif and Gallouรฉdec, Quentin}, + license = {Apache-2.0}, + url = {https://github.com/huggingface/trl}, + year = {2020} +}{% endraw %} +``` diff --git a/ICL/RL/trl_source/trl/trainer/__init__.py b/ICL/RL/trl_source/trl/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67ccf600ee6b53c69a73859d34f0569c425e0719 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/__init__.py @@ -0,0 +1,93 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "callbacks": [ + "BEMACallback", + "LogCompletionsCallback", + "RichProgressCallback", + "SyncRefModelCallback", + "WeaveCallback", + ], + "dpo_config": [ + "DPOConfig", + "FDivergenceConstants", # deprecated import + "FDivergenceType", # deprecated import + ], + "dpo_trainer": ["DPOTrainer"], + "grpo_config": ["GRPOConfig"], + "grpo_trainer": ["GRPOTrainer"], + "kto_config": ["KTOConfig"], + "kto_trainer": ["KTOTrainer"], + "model_config": ["ModelConfig"], + "reward_config": ["RewardConfig"], + "reward_trainer": ["RewardTrainer"], + "rloo_config": ["RLOOConfig"], + "rloo_trainer": ["RLOOTrainer"], + "sft_config": ["SFTConfig"], + "sft_trainer": ["SFTTrainer"], + "utils": [ + "RunningMoments", + "disable_dropout_in_model", + "empty_cache", + "ensure_master_addr_port", + "get_kbit_device_map", + "get_peft_config", + "get_quantization_config", + ], +} + +if TYPE_CHECKING: + from .callbacks import ( + BEMACallback, + LogCompletionsCallback, + RichProgressCallback, + SyncRefModelCallback, + WeaveCallback, + ) + from .dpo_config import ( + DPOConfig, + FDivergenceConstants, # deprecated import + FDivergenceType, # deprecated import + ) + from .dpo_trainer import DPOTrainer + from .grpo_config import GRPOConfig + from .grpo_trainer import GRPOTrainer + from .kto_config import KTOConfig + from .kto_trainer import KTOTrainer + from .model_config import ModelConfig + from .reward_config import RewardConfig + from .reward_trainer import RewardTrainer + from .rloo_config import RLOOConfig + from .rloo_trainer import RLOOTrainer + from .sft_config import SFTConfig + from .sft_trainer import SFTTrainer + from .utils import ( + RunningMoments, + disable_dropout_in_model, + empty_cache, + ensure_master_addr_port, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/ICL/RL/trl_source/trl/trainer/__pycache__/__init__.cpython-313.pyc b/ICL/RL/trl_source/trl/trainer/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c113b74df8d38703e961db03eb153a9c598385a7 Binary files /dev/null and b/ICL/RL/trl_source/trl/trainer/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/trainer/__pycache__/base_trainer.cpython-313.pyc b/ICL/RL/trl_source/trl/trainer/__pycache__/base_trainer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e84fc01c2234dde47af6bfc73ab31e1b7d97f92 Binary files /dev/null and b/ICL/RL/trl_source/trl/trainer/__pycache__/base_trainer.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/trainer/__pycache__/callbacks.cpython-313.pyc b/ICL/RL/trl_source/trl/trainer/__pycache__/callbacks.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b336a8da4b645537f6ff6fbbb496e441db92cbb4 Binary files /dev/null and b/ICL/RL/trl_source/trl/trainer/__pycache__/callbacks.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/trainer/__pycache__/grpo_config.cpython-313.pyc b/ICL/RL/trl_source/trl/trainer/__pycache__/grpo_config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0e466fdf341151f1985f4737d64eb22b25c0d51 Binary files /dev/null and b/ICL/RL/trl_source/trl/trainer/__pycache__/grpo_config.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/trainer/__pycache__/model_config.cpython-313.pyc b/ICL/RL/trl_source/trl/trainer/__pycache__/model_config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..864e2c8995aaab2cd94dcaab30a9270e98bf26f4 Binary files /dev/null and b/ICL/RL/trl_source/trl/trainer/__pycache__/model_config.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/trainer/__pycache__/utils.cpython-313.pyc b/ICL/RL/trl_source/trl/trainer/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bb06d6e66370a7df09777da5622ef53b6b99e19 Binary files /dev/null and b/ICL/RL/trl_source/trl/trainer/__pycache__/utils.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/trainer/base_trainer.py b/ICL/RL/trl_source/trl/trainer/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..35180a30f18e7bb0dffdc2388e7aa3e2302ead79 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/base_trainer.py @@ -0,0 +1,86 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from transformers import Trainer, is_wandb_available + +from .utils import generate_model_card, get_comet_experiment_url, get_config_model_id + + +if is_wandb_available(): + import wandb + + +class BaseTrainer(Trainer): + _tag_names = [] + _name = "Base" + _paper = {} + _template_file = None + + def create_model_card( + self, + model_name: str | None = None, + dataset_name: str | None = None, + tags: str | list[str] | None = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*): + Name of the model. + dataset_name (`str`, *optional*): + Name of the dataset used for training. + tags (`str`, `list[str]`, *optional*): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + model_name_or_path = get_config_model_id(self.model.config) + if model_name_or_path and not os.path.isdir(model_name_or_path): + base_model = model_name_or_path + else: + base_model = None + + # Normalize tags + if tags is None: + tags = set() + elif isinstance(tags, str): + tags = {tags} + else: + tags = set(tags) + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + if "JOB_ID" in os.environ: + tags.add("hf_jobs") + tags.update(self._tag_names) + tags = list(tags) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name=self._name, + trainer_citation=self._paper.get("citation"), + template_file=self._template_file, + paper_title=self._paper.get("title"), + paper_id=self._paper.get("id"), + ) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/ICL/RL/trl_source/trl/trainer/callbacks.py b/ICL/RL/trl_source/trl/trainer/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..a530e38b0c4f39f63ca16c7c8a52fa6078316b76 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/callbacks.py @@ -0,0 +1,758 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import pandas as pd +import torch +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import gather_object, is_wandb_available +from transformers import ( + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from transformers.trainer_utils import has_length +from transformers.utils import is_rich_available + +from ..data_utils import maybe_apply_chat_template +from ..import_utils import is_weave_available +from ..models.utils import unwrap_model_for_generation +from .utils import log_table_to_comet_experiment + + +if is_rich_available(): + from rich.columns import Columns + from rich.console import Console, Group + from rich.live import Live + from rich.panel import Panel + from rich.progress import Progress + from rich.table import Table + +if is_wandb_available(): + import wandb + +if is_weave_available(): + import weave + from weave import EvaluationLogger + from weave.trace.context import weave_client_context + + +# Logger for module-level logging +logger = logging.getLogger(__name__) + + +def _generate_completions( + prompts: list[str], + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + accelerator: Accelerator, + generation_config: GenerationConfig | None, + batch_size: int = 1, +) -> list[str]: + """ + Generates completions for a list of pre-formatted prompts from the given model. + + Args: + prompts (list[str]): A list of input prompts for which completions are to be generated. + model (PreTrainedModel): The pre-trained model to be used for generation. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding. + accelerator (Accelerator): The accelerator to be used for model execution. + generation_config (GenerationConfig): Configuration for text generation. + batch_size (int, optional): The number of prompts to process in each batch. Default is 1. + + Returns: + list[str]: A list of generated text completions corresponding to the input prompts. + """ + completions = [] + # TODO: Override model.generation_config with generation_kwargs + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + for idx in range(0, len(prompts), batch_size): + batch = prompts[idx : idx + batch_size] + tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) + generations = unwrapped_model.generate( + **tokenized_batch, + generation_config=generation_config, + ) + for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True): + # Remove prompt from generation + generation = generation[len(prompt) :] + completion = tokenizer.decode(generation, skip_special_tokens=True) + completions.append(completion) + return completions + + +class SyncRefModelCallback(TrainerCallback): + """ + Callback to synchronize the model with a reference model. + """ + + def __init__( + self, + ref_model: PreTrainedModel | torch.nn.Module, + accelerator: Accelerator | None, + ): + self.accelerator = accelerator + self.ref_model = ref_model + + @staticmethod + def _sync_target_model(model, target_model, alpha): + for target_param, copy_param in zip(target_model.parameters(), model.parameters(), strict=True): + target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha) + + @staticmethod + def sync_target_model(model, target_model, alpha): + deepspeed_plugin = AcceleratorState().deepspeed_plugin + if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: + import deepspeed + + with deepspeed.zero.GatheredParameters( + list(model.parameters()) + list(target_model.parameters()), modifier_rank=0 + ): + if deepspeed.comm.get_rank() == 0: + SyncRefModelCallback._sync_target_model(model, target_model, alpha) + else: + SyncRefModelCallback._sync_target_model(model, target_model, alpha) + + def on_step_end(self, args, state, control, **kwargs): + model: PreTrainedModel = kwargs["model"] + + if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha) + + +class RichProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation using Rich. + """ + + def __init__(self): + if not is_rich_available(): + raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.") + + self.training_bar = None + self.evaluation_bar = None + self.training_task = None + self.evaluation_task = None + self.rich_group = None + self.rich_console = None + self.training_status = None + self.current_step = None + + def on_train_begin(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + + self.training_bar = Progress() + self.evaluation_bar = Progress() + self.rich_console = Console() + self.training_status = self.rich_console.status("Nothing to log yet ...") + self.rich_group = Live(Panel(Group(self.training_bar, self.evaluation_bar, self.training_status))) + self.rich_group.start() + self.training_task = self.training_bar.add_task("[blue]Training ", total=state.max_steps) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + + self.training_bar.update(self.training_task, advance=state.global_step - self.current_step, update=True) + self.current_step = state.global_step + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if not state.is_world_process_zero: + return + + if has_length(eval_dataloader): + if self.evaluation_task is None: + self.evaluation_task = self.evaluation_bar.add_task("[blue]Evaluation", total=len(eval_dataloader)) + self.evaluation_bar.update(self.evaluation_task, advance=1, update=True) + + def on_evaluate(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + + if self.evaluation_task is not None: + self.evaluation_bar.remove_task(self.evaluation_task) + self.evaluation_task = None + + def on_predict(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + + if self.evaluation_task is not None: + self.evaluation_bar.remove_task(self.evaluation_task) + self.evaluation_task = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if not (state.is_world_process_zero and self.training_bar): + return + + # Group keys by top-level prefix + grouped_logs = {} + for key, value in logs.items(): + parts = key.split("/") + group = parts[0] if len(parts) > 1 else None + subkey = "/".join(parts[1:]) if len(parts) > 1 else key + grouped_logs.setdefault(group, {})[subkey] = value + + # Create a table per group + tables = [] + for group_name, metrics in grouped_logs.items(): + table = Table( + title=f"[bold blue]{group_name}[/]" if group_name else None, header_style="bold magenta", box=None + ) + table.add_column("Metric", justify="left", no_wrap=True) + table.add_column("Value", justify="right") + + for metric, val in metrics.items(): + formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val) + table.add_row(metric, formatted) + + tables.append(Panel(table, border_style="cyan", padding=(0, 1))) + + # Arrange tables in columns using Columns + column_layout = Columns(tables, equal=False, expand=True) + self.training_status.update( + Panel(column_layout, title=f"[bold green]Step {state.global_step}[/bold green]", border_style="green") + ) + + def on_train_end(self, args, state, control, **kwargs): + if not state.is_world_process_zero: + return + + self.rich_group.stop() + self.training_bar = None + self.evaluation_bar = None + self.training_task = None + self.evaluation_task = None + self.rich_group = None + self.rich_console = None + self.training_status = None + self.current_step = None + + +class LogCompletionsCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases and/or Comet. + + Usage: + ```python + trainer = DPOTrainer(...) + completions_callback = LogCompletionsCallback(trainer=trainer) + trainer.add_callback(completions_callback) + ``` + + Args: + trainer (`Trainer`): + Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` + column containing the prompts for generating completions. + generation_config ([`~transformers.GenerationConfig`], *optional*): + The generation config to use for generating completions. + num_prompts (`int`, *optional*): + The number of prompts to generate completions for. If not provided, defaults to the number of examples in + the evaluation dataset. + freq (`int`, *optional*): + The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. + """ + + def __init__( + self, + trainer: Trainer, + generation_config: GenerationConfig | None = None, + num_prompts: int | None = None, + freq: int | None = None, + ): + self.trainer = trainer + self.generation_config = generation_config + self.freq = freq + self.table = [] + self._last_logged_step = -1 + + if self.trainer.eval_dataset is None: + raise ValueError("Trainer must have an evaluation dataset to use the LogCompletionsCallback.") + else: + self.eval_dataset = self.trainer.eval_dataset + + if num_prompts is not None: + self.eval_dataset = self.eval_dataset.select(range(num_prompts)) + + def on_step_end(self, args, state, control, **kwargs): + # Only log once per step (this method may be called multiple times) + if state.global_step == self._last_logged_step: + return + + # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps) + freq = self.freq or state.eval_steps + if state.global_step % freq != 0: + return + + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + model = self.trainer.model_wrapped + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts] + completions = _generate_completions( + prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + completions = gather_object(completions) + prompts = gather_object(prompts) + + # Build the data to log + if self.trainer.accelerator.is_main_process: + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions, strict=True)) + self.table.extend(data) + table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table) + + if "wandb" in args.report_to: + wandb.log({"completions": table}) + + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=table, + ) + + # Save the last logged step, so we don't log the same completions multiple times + self._last_logged_step = state.global_step + + +class WeaveCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. The callback uses + https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/ to log traces and evaluations at each evaluation + step. + + Supports two modes based on the `scorers` parameter: + - **Tracing Mode** (when scorers=None): Logs predictions for data exploration and analysis + - **Evaluation Mode** (when scorers provided): Logs predictions with scoring and summary metrics + + Both modes use Weave's EvaluationLogger for structured, consistent data logging. + + The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, making it more + efficient and semantically correct. It gracefully handles missing weave installation by logging warnings and + skipping weave-specific functionality. It also checks for existing weave clients before initializing new ones. + + Usage: + ```python + # Tracing mode (just log predictions) + trainer = DPOTrainer(...) + weave_callback = WeaveTraceCallback(trainer=trainer) # project_name optional + trainer.add_callback(weave_callback) + + # Or specify a project name + weave_callback = WeaveTraceCallback(trainer=trainer, project_name="my-llm-training") + trainer.add_callback(weave_callback) + + + # Evaluation mode (log predictions + scores + summary) + def accuracy_scorer(prompt: str, completion: str) -> float: + # Your scoring logic here (metadata available via eval_attributes) + return score + + + weave_callback = WeaveTraceCallback( + trainer=trainer, + project_name="my-llm-training", # optional and needed only if weave client is not initialized + scorers={"accuracy": accuracy_scorer}, + ) + trainer.add_callback(weave_callback) + ``` + + Args: + trainer (`Trainer`): + Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` + column containing the prompts for generating completions. + project_name (`str`, *optional*): + Name of the Weave project where data will be logged. If not provided, will try to use existing weave client + or fall back to the active wandb run's project name. Raises an error if none of these are available. + scorers (`dict[str, Callable]`, *optional*): + Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions + only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should + have signature: `scorer(prompt: str, completion: str) -> float | int` + generation_config ([`~transformers.GenerationConfig`], *optional*): + Generation config to use for generating completions. + num_prompts (`int` or `None`, *optional*): + Number of prompts to generate completions for. If not provided, defaults to the number of examples in the + evaluation dataset. + dataset_name (`str`, *optional*, defaults to `"eval_dataset"`): + Name for the dataset metadata in Weave. + model_name (`str`, *optional*): + Name for the model metadata in Weave. If not provided, attempts to extract from model config. + """ + + def __init__( + self, + trainer: Trainer, + project_name: str | None = None, + scorers: dict[str, callable] | None = None, + generation_config: GenerationConfig | None = None, + num_prompts: int | None = None, + dataset_name: str = "eval_dataset", + model_name: str | None = None, + ): + self.trainer = trainer + self.project_name = project_name + self.scorers = scorers or {} + self.generation_config = generation_config + self.dataset_name = dataset_name + self.model_name = model_name + self._last_logged_step = -1 + self._weave_initialized = False + self._eval_logger = None + + if self.trainer.eval_dataset is None: + raise ValueError("Trainer must have an evaluation dataset to use the WeaveCallback.") + else: + self.eval_dataset = self.trainer.eval_dataset + + if num_prompts is not None: + self.eval_dataset = self.eval_dataset.select(range(num_prompts)) + + def _initialize_weave(self): + """Initialize Weave and EvaluationLogger if not already initialized.""" + if not self._weave_initialized: + if not is_weave_available(): + logger.warning("Weave is not available. Please install weave to enable logging: `pip install weave`") + return + + if wc := weave_client_context.get_weave_client(): + self._weave_client = wc + else: + if self.project_name is None: + if is_wandb_available(): + if wandb.run is not None: + self.project_name = wandb.run.entity + "/" + wandb.run.project + logger.info(f"Using project name from active wandb run: {self.project_name}") + + if self.project_name is None: + raise ValueError( + "No existing Weave client found and no project_name provided. " + "Please either initialize weave with `weave.init('project-name')`, " + "provide a project_name to the `WeaveTraceCallback`, " + "or ensure an active wandb run exists." + ) + + self._weave_client = weave.init(self.project_name) + logger.info(f"Initialized Weave with project: {self.project_name}") + + if self.model_name is None: + self.model_name = getattr(self.trainer.model_wrapped.config, "_name_or_path", "unknown_model") + + self._EvaluationLogger = EvaluationLogger + + self._weave_initialized = True + + @property + def is_evaluation_mode(self) -> bool: + """True if scorers are provided (evaluation mode), False for tracing mode.""" + return bool(self.scorers) + + def on_train_begin(self, args, state, control, **kwargs): + """Initialize Weave when training begins.""" + self._initialize_weave() + + def on_evaluate(self, args, state, control, **kwargs): + if state.global_step == self._last_logged_step: + return + + self._initialize_weave() + + if not self._weave_initialized: + logger.debug("Weave not initialized, skipping logging") + return + + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + model = self.trainer.model_wrapped + + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts] + + completions = _generate_completions( + prompts=prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + + all_prompts = gather_object(prompts) + all_completions = gather_object(completions) + + if self.trainer.accelerator.is_main_process: + eval_attributes = { + "training_step": state.global_step, + "model_name": self.model_name, + "generation_config": (self.generation_config.to_dict() if self.generation_config else None), + } + + eval_logger = self._EvaluationLogger( + model=self.model_name, + dataset=self.dataset_name, + eval_attributes=eval_attributes, + ) + + successful_predictions = 0 + total_score_values = {} # For summary statistics + + for prompt, completion in zip(all_prompts, all_completions, strict=True): + try: + pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion) + + if self.is_evaluation_mode: + for scorer_name, scorer_func in self.scorers.items(): + try: + score = scorer_func(prompt, completion) + pred_logger.log_score(scorer=scorer_name, score=score) + + if scorer_name not in total_score_values: + total_score_values[scorer_name] = [] + total_score_values[scorer_name].append(score) + + except Exception as scorer_e: + logger.warning(f"Failed to apply scorer '{scorer_name}': {scorer_e}") + + pred_logger.finish() + successful_predictions += 1 + + except Exception as pred_e: + logger.warning(f"Failed to log prediction for prompt: {pred_e}") + # Continue with other predictions even if one fails + + if self.is_evaluation_mode and total_score_values: + try: + summary_stats = { + "total_predictions": len(all_prompts), + "successful_predictions": successful_predictions, + } + + for scorer_name, scores in total_score_values.items(): + if scores: # Only if we have valid scores + summary_stats[f"avg_{scorer_name}"] = sum(scores) / len(scores) + + eval_logger.log_summary(summary_stats) + + except Exception as summary_e: + logger.warning(f"Failed to log summary: {summary_e}") + else: + try: + eval_logger.finish() + except Exception as finish_e: + logger.warning(f"Failed to finish evaluation logger: {finish_e}") + + self._last_logged_step = state.global_step + + +class BEMACallback(TrainerCallback): + # docstyle-ignore + r""" + A [`~transformers.TrainerCallback`] that implements [BEMA](https://huggingface.co/papers/2508.00180) + (Bias-Corrected Exponential Moving Average) by [Adam Block](https://huggingface.co/abblock) and [Cyril + Zhang](https://huggingface.co/cyrilzhang). Code from https://github.com/abblock/bema under MIT license. + + BEMA computes model weights that scale like: + + $$ + \theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t + $$ + + where \\( \theta_t \\) is the current model weights, \\( \theta_0 \\) is a snapshot of the model weights at the + first `update_after` step, \\( \text{EMA}_t \\) is the exponential moving average of the model weights, and + \\( \alpha_t \\) is a scaling factor that decays with the number of steps \\( t \\) as + + $$ + \alpha_t = (\rho + \gamma \cdot t)^{-\eta}. + $$ + + The EMA is computed as: + + $$ + \text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t + $$ + + where \\( \beta_t \\) is a decay factor that decays with the number of steps \\( t \\) as + + $$ + \beta_t = (\rho + \gamma \cdot t)^{-\kappa}. + $$ + + Args: + update_freq (`int`, *optional*, defaults to `400`): + Update the BEMA weights every X steps. Denoted this as \\( \phi \\) in the paper. + ema_power (`float`, *optional*, defaults to `0.5`): + Power for the EMA decay factor. Denoted \\( \kappa \\) in the paper. To disable EMA, set this to `0.0`. + bias_power (`float`, *optional*, defaults to `0.2`): + Power for the BEMA scaling factor. Denoted \\( \eta \\) in the paper. To disable BEMA, set this to `0.0`. + lag (`int`, *optional*, defaults to `10`): + Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual + starting age for the updates. Denoted as \\( \rho \\) in the paper. + update_after (`int`, *optional*, defaults to `0`): + Burn-in time before starting to update the BEMA weights. Denoted \\( \tau \\) in the paper. + multiplier (`float`, *optional*, defaults to `1.0`): + Initial value for the EMA decay factor. Denoted as \\( \gamma \\) in the paper. + min_ema_multiplier (`float`, *optional*, defaults to `0.0`): + Minimum value for the EMA decay factor. + device (`str`, *optional*, defaults to `"cpu"`): + Device to use for the BEMA buffers, e.g. `"cpu"` or `"cuda"`. Note that in most cases, this device SHOULD + BE DIFFERENT from the device used for training in order to avoid OOM. + + Example: + + ```python + from trl import BEMACallback + + trainer = Trainer(..., callbacks=[BEMACallback()]) + ``` + """ + + def __init__( + self, + update_freq: int = 400, + ema_power: float = 0.5, + bias_power: float = 0.2, + lag: int = 10, + update_after: int = 0, + multiplier: float = 1.0, + min_ema_multiplier: float = 0.0, + device: str = "cpu", + ): + # User-provided hyperparams + self.update_freq = update_freq + self.ema_power = ema_power + self.bias_power = bias_power + self.lag = lag + self.update_after = update_after + self.multiplier = multiplier + self.min_ema_multiplier = min_ema_multiplier + self.device = device + + # Internal state + self.param_names = [] # references to training model param names + self.thetat_params = [] # references to training model params + self.theta0_params = [] # ฮธโ‚€ buffers (on self.device) + self.ema_params = [] # EMA buffers (on self.device) + self.running_model = None # a copy of the model to run BEMA on + + @staticmethod + def _unwrap_model(model): + """ + Helper function to unwrap model from various wrappers including DataParallel, DistributedDataParallel, + DeepSpeed, and FSDP. + """ + # Handle DeepSpeed + if hasattr(model, "module") and hasattr(model, "engine"): + # DeepSpeed engine + return model.module + + # Handle FSDP + if hasattr(model, "_fsdp_wrapped_module"): + # FSDP wrapped model + return model._fsdp_wrapped_module + + # Handle DataParallel/DistributedDataParallel + if hasattr(model, "module"): + return model.module + + return model + + @torch.no_grad() + def on_train_begin( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs + ): + model = self._unwrap_model(model) + + # Create a new instance and load state_dict + self.running_model = type(model)(model.config).to(self.device) + self.running_model.load_state_dict(model.state_dict()) + + # Cache trainable parameters once in a fixed order + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + self.param_names.append(name) + self.thetat_params.append(param) + + # Clone ฮธโ‚€ and EMA on the same device as model + theta0 = param.detach().clone().to(self.device) + self.theta0_params.append(theta0) + self.ema_params.append(theta0.clone()) # initialize EMA with ฮธโ‚€ + + def _ema_beta(self, step: int) -> float: + """Compute the EMA decay factor ฮฒโ‚œ = (ฯ + ฮณยทt)โปแตแตƒแต–แต–แตƒ.""" + beta = (self.lag + self.multiplier * step) ** (-self.ema_power) + return max(beta, self.min_ema_multiplier) + + def _bema_alpha(self, step: int) -> float: + """Compute the BEMA scaling factor ฮฑโ‚œ = (ฯ + ฮณยทt)โปแต‰แต—แตƒ.""" + return (self.lag + self.multiplier * step) ** (-self.bias_power) + + def _update_bema_weights(self, step: int): + beta = self._ema_beta(step) + alpha = self._bema_alpha(step) + + # Compute EMA + BEMA in-place and write directly to running_model + for thetat, theta0, ema, run_param in zip( + self.thetat_params, self.theta0_params, self.ema_params, self.running_model.parameters(), strict=True + ): + thetat = thetat.detach().to(self.device) + ema.mul_(1 - beta).add_(thetat, alpha=beta) # EMA update: ema = (1 - beta) * ema + beta * ฮธโ‚œ + run_param.copy_(ema + alpha * (thetat - theta0)) # BEMA update: run_param = ema + alpha * (ฮธโ‚œ - ฮธโ‚€) + + @torch.no_grad() + def on_step_end( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs + ): + step = state.global_step + + # If we haven't reached the update_after step, skip the BEMA update + if step < self.update_after: + return + + # Snapshot ฮธโ‚€ and EMA at first update + if step == self.update_after: + for thetat_param, theta0_param, ema_param in zip( + self.thetat_params, self.theta0_params, self.ema_params, strict=True + ): + theta0_param.copy_(thetat_param) + ema_param.copy_(thetat_param) + + # Update BEMA weights every `update_freq` steps + elif (step - self.update_after) % self.update_freq == 0: + self._update_bema_weights(step) + logger.info(f"Updated BEMA weights at step {step}") + + @torch.no_grad() + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if state.is_world_process_zero: + save_directory = f"{args.output_dir}/bema" + self.running_model.save_pretrained(save_directory) + logger.info(f"Saved BEMA model to {save_directory}") diff --git a/ICL/RL/trl_source/trl/trainer/dpo_config.py b/ICL/RL/trl_source/trl/trainer/dpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..21a0377cb68873923f54ef9bf678546f5d5a5556 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/dpo_config.py @@ -0,0 +1,775 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from transformers import TrainingArguments + + +class FDivergenceType(Enum): + """ + Types of f-divergence functions for DPO loss regularization. + + + + Using `FDivergenceType` for `f_divergence_type` in [`DPOConfig`] is deprecated and will be removed in version + 0.29.0. Use a string instead. + + + + Attributes: + REVERSE_KL: Reverse KL divergence. + JS_DIVERGENCE: Jensen-Shannon divergence. + ALPHA_DIVERGENCE: Alpha divergence. + + Examples: + ```python + >>> from trl.trainer.dpo_config import DPOConfig, FDivergenceType + + >>> config = DPOConfig( + ... f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE, + ... f_alpha_divergence_coef=0.5, # used only with ALPHA_DIVERGENCE + ... ) + ``` + """ + + REVERSE_KL = "reverse_kl" + JS_DIVERGENCE = "js_divergence" + ALPHA_DIVERGENCE = "alpha_divergence" + + +class FDivergenceConstants: + """Constants for f-divergence types and their parameters. + + Attributes: + ALPHA_DIVERGENCE_COEF_KEY (`str`): Key for the alpha divergence coefficient. + ALPHA_DIVERGENCE_COEF_DEFAULT (`float`): Default value for the alpha divergence coefficient. + """ + + ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef" + ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0 + + +@dataclass +class DPOConfig(TrainingArguments): + r""" + Configuration class for the [`DPOTrainer`]. + + This class includes only the parameters that are specific to DPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the + [`DPOTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + + > Parameters that control the data preprocessing + + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + pad_token (`str`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the full sequence (prompt + completion). + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and + `"keep_start"`. + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened + batch structure. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute the log probabilities from the reference model. Setting this to `True` allows + training without needing the reference model during training, which can help reduce GPU memory usage. If + set to `False` (default), the reference model will be used during training to compute log probabilities + on-the-fly. + precompute_ref_batch_size (`int`, *optional*): + Batch size to use when precomputing reference model log probabilities. This can be set higher than the + training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for + training and `per_device_eval_batch_size` for evaluation. + > Parameters that control the training + + loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"hinge"`: hinge loss on the normalized likelihood from the + [SLiC](https://huggingface.co/papers/2305.10425) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. + - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. + - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust + DPO](https://huggingface.co/papers/2403.00409) paper. + - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. + - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) + paper. + - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"aot_unpaired"`: AOT loss for unpaired datasets from the + [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the + [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. + - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss). + + Multiple loss types can be combined using comma separation (e.g., `["sigmoid", "bco_pair", "sft"]` for + [MPO](https://huggingface.co/papers/2411.10442)). The `loss_weights` parameter can be used to specify + corresponding weights for each loss type. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher ฮฒ means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), ฮฒ is the regularization parameter denoted by ฯ„ in + the [paper](https://huggingface.co/papers/2310.12036). + f_divergence_type (`str`, *optional*, defaults to `"reverse_kl"`): + Type of f-divergence regularization function to compute divergence between policy and reference model. + Supported values: + - `"reverse_kl"`: Reverse KL divergence. + - `"js_divergence"`: Jensen-Shannon divergence. + - `"alpha_divergence"`: Alpha divergence. + f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): + ฮฑ coefficient in the ฮฑ-divergence u^-ฮฑ regularization function for DPO loss. + label_smoothing (`float`, *optional*, defaults to `0.0`): + Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and [Robust + DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. + use_weighting (`bool`, *optional*, defaults to `False`): + Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827). + ld_alpha (`float`, *optional*): + ฮฑ parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting + of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose + part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between + `0.0` and `1.0`. + discopop_tau (`float`, *optional*, defaults to `0.05`): + ฯ„/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls + the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. + loss_weights (`list[float]`, *optional*): + List of loss weights for multi-loss combinations. Used when combining multiple loss types. Example: `[0.8, + 0.2, 1.0]` for [MPO](https://huggingface.co/papers/2411.10442). If not provided, defaults to equal weights + (`1.0`) for all loss types. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. + ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): + ฮฑ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `ฯ€_ref = ฮฑ * ฯ€_ฮธ + (1 - ฮฑ) * ฯ€_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `512`): + ฯ„ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + + > Deprecated parameters + + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model from + the model when the model does not have a `get_decoder` method in the case when `use_liger_kernel` is + `True`. + + + + This parameter is deprecated and will be removed in version 0.29.0. In the future the base model will be + retrieved via `get_decoder`; if your model does not support this, it will no longer be supported by the + [`DPOTrainer`]. + + + ref_model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the + [`DPOTrainer`] is provided as a string. + + + + This parameter is deprecated and will be removed in version 0.29.0. If you need different init kwargs for + the reference model, instantiate it yourself and pass it via the `ref_model` argument. + + + model_adapter_name (`str`, *optional*): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. Only the default adapter + will be supported going forward. + + + + This parameter is deprecated and will be removed in version 0.29.0. Only the default adapter will be + supported going forward. + + + ref_adapter_name (`str`, *optional*): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. If you used it to resume + training an adapter, you won't need this argument anymore in the next version and can rely on the trainer. + For now, it is still the only supported way to do this. + + + + This parameter is deprecated and will be removed in version 0.29.0. If you used it to resume training an + adapter, you won't need this argument anymore in the next version and can rely on the trainer. For now, it + is still the only supported way to do this. + + + force_use_ref_model (`bool`, *optional*, defaults to `False`): + If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set + this flag to `True`. + + + + This parameter is deprecated and will be removed in version 0.29.0. There is no need to pass this argument + anymore: if you provide a reference model, it will be used automatically. + + + generate_during_eval (`bool`, *optional*, defaults to `False`): + Whether to generate and log completions from both the model and the reference model to W&B or Comet during + evaluation. + + + + This parameter is deprecated and will be removed in version 0.29.0. Please use a callback instead; see + `https://gist.github.com/qgallouedec/a08da3457a3a76c5ca539d4a0b38e482`. + + + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Padding value to use for labels. + + + + This parameter is deprecated and will be removed in version 0.29.0. It will no longer be possible to set + this value. + + + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. We recommend filtering overlong prompts from your dataset before passing it + to the trainer instead of using this parameter. + + + + This parameter is deprecated and will be removed in version 0.29.0. We recommend filtering overlong prompts + from your dataset before passing it to the trainer instead of using this parameter. + + + max_completion_length (`int`, *optional*): + Maximum length of the completion. + + + + This parameter is deprecated and will be removed in version 0.29.0. We recommend using `max_length` instead + to control the maximum length of samples. + + + reference_free (`bool`, *optional*, defaults to `False`): + Whether to ignore the provided reference model and implicitly use a reference model that assigns equal + probability to all responses. + + + + This parameter is deprecated and will be removed in version 0.29.0. If you want a reference-free objective, + use [`experimental.cpo.CPOTrainer`] instead. + + + rpo_alpha (`float`, *optional*): + ฮฑ parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the + weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the + DPO loss. The paper recommends `rpo_alpha=1.0`. + + + + This parameter is deprecated and will be removed in version 0.29.0. This is equivalent to including `"sft"` + in `loss_type`; we recommend adding `"sft"` to `loss_type` and setting its weight in `loss_weights` to + `rpo_alpha`. + + + tools (`list[dict] | None`, *optional*): + List of tools (callable functions) that will be accessible to the model. If the template does not support + function calling, this argument will have no effect. + + + + This parameter is deprecated and will be removed in version 0.29.0. In 0.29 this argument will be ignored; + tools should be provided via the dataset instead. For now, `DPOConfig.tools` remains the only supported way + to pass tools. + + + use_logits_to_keep (`bool`, *optional*, defaults to `False`): + If `True`, only a specified number of logits are computed in the forward pass. This can be useful for + saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios + when working with very long prompts where labels are ignored (-100). + + + + This parameter is deprecated and will be removed in version 0.29.0. The DPO trainer will no longer use this + setting. + + + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + # Parameters that control the model and reference model + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `DPOTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + + # Parameters that control the data preprocessing + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + pad_token: str | None = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: int | None = field( + default=1024, + metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` " + "and `'keep_start'`.", + "choices": ["keep_end", "keep_start"], + }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, " + "this is only supported with the `flash_attention_2` attention implementation, which can efficiently " + "handle the flattened batch structure." + }, + ) + precompute_ref_log_probs: bool = field( + default=False, + metadata={ + "help": "Whether to precompute the log probabilities from the reference model. Setting this to `True` " + "allows training without needing the reference model during training, which can help reduce GPU memory " + "usage. If set to `False` (default), the reference model will be used during training to compute log " + "probabilities on-the-fly." + }, + ) + precompute_ref_batch_size: int | None = field( + default=None, + metadata={ + "help": "Batch size to use when precomputing reference model log probabilities. This can be set higher " + "than the training batch size to speed up preprocessing. If `None`, defaults to " + "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." + }, + ) + + # Parameters that control the training + loss_type: list[str] = field( + default_factory=lambda: ["sigmoid"], + metadata={ + "help": "Type of loss to use. Possible values are: `'sigmoid'`, `'hinge'`, `'ipo'`, `'exo_pair'`, " + "`'nca_pair'`, `'robust'`, `'bco_pair'`, `'sppo_hard'`, `'aot'`, `'aot_unpaired'`, `'discopop'`, " + "`'apo_zero'`, `'apo_down'` and `'sft'`. Multiple loss types can be combined using comma separation " + "(e.g., `['sigmoid', 'bco_pair', 'sft']` for MPO). The `loss_weights` parameter can be used to specify " + "corresponding weights for each loss type." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. " + "Higher ฮฒ means less deviation from the reference model." + }, + ) + f_divergence_type: str = field( + default="reverse_kl", + metadata={ + "help": "Type of f-divergence regularization function to compute divergence between policy and reference " + "model.", + "choices": ["reverse_kl", "js_divergence", "alpha_divergence"], + }, + ) + f_alpha_divergence_coef: float = field( + default=1.0, + metadata={"help": "ฮฑ coefficient in the ฮฑ-divergence u^-ฮฑ regularization function for DPO loss."}, + ) + label_smoothing: float = field( + default=0.0, + metadata={ + "help": "Robust DPO label smoothing parameter from the cDPO report and Robust DPO paper that should " + "be between `0.0` and `0.5`." + }, + ) + use_weighting: bool = field( + default=False, + metadata={"help": "Whether to weight the loss as done in the WPO paper."}, + ) + ld_alpha: float | None = field( + default=None, + metadata={ + "help": "ฮฑ parameter from the LD-DPO paper, which controls the weighting of the verbose token " + "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is " + "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.", + }, + ) + discopop_tau: float = field( + default=0.05, + metadata={ + "help": "ฯ„/temperature parameter from the DiscoPOP paper, which controls the shape of log ratio modulated " + "loss. The paper recommends the default value `discopop_tau=0.05`." + }, + ) + loss_weights: list[float] | None = field( + default=None, + metadata={ + "help": "List of loss weights for multi-loss combinations. Used when combining multiple loss types. " + "Example: `[0.8, 0.2, 1.0]` for MPO. If not provided, defaults to equal weights (`1.0`) for all loss " + "types." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "ฮฑ parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`ฯ€_ref = ฮฑ * ฯ€_ฮธ + (1 - ฮฑ) * ฯ€_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "ฯ„ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + + # Deprecated parameters + base_model_attribute_name: str | None = field( + default=None, + metadata={ + "help": "Name of the attribute in the model that contains the base model. This is used to get the base " + "model from the model when the model does not have a `get_decoder` method in the case when " + "`use_liger_kernel` is `True`. Deprecated: the base model will be retrieved via `get_decoder`; models " + "without it won't be supported by the DPO trainer." + }, + ) + force_use_ref_model: bool | None = field( + default=None, + metadata={ + "help": "Deprecated. There is no need to pass this argument anymore: if you provide a reference model, it " + "will be used automatically." + }, + ) + generate_during_eval: bool | None = field( + default=None, + metadata={ + "help": "Deprecated. Please use a callback instead; see " + "`https://gist.github.com/qgallouedec/a08da3457a3a76c5ca539d4a0b38e482`." + }, + ) + label_pad_token_id: int | None = field( + default=None, + metadata={"help": "Deprecated. It will no longer be possible to set this value."}, + ) + max_completion_length: int | None = field( + # This default value is used to determine whether the user has set it or not, since `None` is a valid value for + # this parameter. This is overridden in `__post_init__` to preserve the old default value of `None`. + default=-1, + metadata={"help": "Deprecated. Use `max_length` instead to control the maximum length of samples."}, + ) + max_prompt_length: int | None = field( + # This default value is used to determine whether the user has set it or not, since `None` is a valid value for + # this parameter. This is overridden in `__post_init__` to preserve the old default value of `512`. + default=-1, + metadata={ + "help": "Deprecated. We recommend filtering overlong prompts from your dataset before passing it to the " + "trainer instead of using this parameter." + }, + ) + model_adapter_name: str | None = field( + default=None, + metadata={"help": "Deprecated. Only the default adapter will be supported going forward."}, + ) + ref_adapter_name: str | None = field( + default=None, + metadata={ + "help": "Deprecated. If you used it to resume training an adapter, you won't need this argument anymore " + "in the next version and can rely on the trainer. For now, it is still the only supported way to do " + "this." + }, + ) + ref_model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument " + "of the `DPOTrainer` is provided as a string. Deprecated: if you need different init kwargs for the " + "reference model, instantiate it yourself and pass it via the `ref_model` argument." + }, + ) + reference_free: bool | None = field( + default=None, + metadata={ + "help": "Whether to ignore the provided reference model and implicitly use a reference model that assigns " + "equal probability to all responses. Deprecated: if you want a reference-free objective, use " + "`CPOTrainer` instead." + }, + ) + rpo_alpha: float | None = field( + default=None, + metadata={ + "help": "ฮฑ parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. " + "If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends " + "`rpo_alpha=1.0`. Deprecated: this is equivalent to including `'sft'` in `loss_type`; we recommend adding " + "'sft' to `loss_type` and setting its weight in `loss_weights` to `rpo_alpha`." + }, + ) + tools: list[dict] | None = field( + default=None, + metadata={ + "help": "List of tools (callable functions) that will be accessible to the model. If the template does " + "not support function calling, this argument will have no effect. Deprecated: in 0.29 this argument " + "will be ignored; tools should be provided via the dataset instead. For now, `DPOConfig.tools` remains " + "the only supported way to pass tools." + }, + ) + use_logits_to_keep: bool | None = field( + default=None, + metadata={ + "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be " + "useful for saving memory and speeding up training by not computing the logits for all tokens, especially " + "in scenarios when working with very long prompts where labels are ignored (-100). Deprecated: the DPO " + "trainer will no longer use this setting." + }, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + if self.base_model_attribute_name is not None: + warnings.warn( + "`base_model_attribute_name` is deprecated and will be removed in version 0.29.0. The base model " + "will be retrieved via `get_decoder`; if your model does not support this, it will no longer be " + "supported by the DPO trainer.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.base_model_attribute_name = "model" + + if self.force_use_ref_model is not None: + warnings.warn( + "`force_use_ref_model` is deprecated and will be removed in version 0.29.0. There is no need to pass " + "this argument anymore: if you provide a reference model, it will be used automatically.", + FutureWarning, + stacklevel=3, + ) + + if self.generate_during_eval is not None: + warnings.warn( + "`generate_during_eval` is deprecated and will be removed in version 0.29.0. Please use a callback " + "instead. See the example at `https://gist.github.com/qgallouedec/a08da3457a3a76c5ca539d4a0b38e482`.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.generate_during_eval = False + + if self.label_pad_token_id is not None: + warnings.warn( + "`label_pad_token_id` is deprecated and will be removed in version 0.29.0. It will no longer be " + "possible to set this value.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.label_pad_token_id = -100 + + if self.max_completion_length != -1: + warnings.warn( + "`max_completion_length` is deprecated and will be removed in version 0.29.0. We recommend using " + "`max_length` instead to control the maximum length of samples.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.max_completion_length = None + + if self.max_prompt_length != -1: + warnings.warn( + "`max_prompt_length` is deprecated and will be removed in version 0.29.0. We recommend filtering out " + "overlong prompts from your dataset before passing it to the trainer instead of using this parameter.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.max_prompt_length = 512 + + if self.model_adapter_name is not None: + warnings.warn( + "`model_adapter_name` is deprecated and will be removed in version 0.29.0. Only the default adapter " + "will be supported going forward.", + FutureWarning, + stacklevel=3, + ) + + if self.ref_adapter_name is not None: + warnings.warn( + "`ref_adapter_name` is deprecated and will be removed in version 0.29.0. If you used it to resume " + "training an adapter, you won't need this argument anymore in the next version and can rely on the " + "trainer. For now, it is still the only supported way to do this.", + FutureWarning, + stacklevel=3, + ) + + if self.ref_model_init_kwargs is not None: + warnings.warn( + "`ref_model_init_kwargs` is deprecated and will be removed in version 0.29.0. If you need different " + "init kwargs for the reference model, instantiate it yourself and pass it via the `ref_model` " + "argument.", + FutureWarning, + stacklevel=3, + ) + + if self.reference_free is not None: + warnings.warn( + "`reference_free` is deprecated and will be removed in version 0.29.0. If you want a reference-free " + "objective, use `CPOTrainer` instead.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.reference_free = False + + if self.rpo_alpha is not None: + warnings.warn( + "`rpo_alpha` is deprecated and will be removed in version 0.29.0. It is equivalent to including " + "`'sft'` in `loss_type`; we recommend adding `'sft'` to `loss_type` and setting its weight in " + "`loss_weights` to `rpo_alpha`.", + FutureWarning, + stacklevel=3, + ) + + if self.tools is not None: + warnings.warn( + "`tools` is deprecated and will be removed in version 0.29.0. In 0.29 this argument will be ignored; " + "tools should be provided via the dataset instead but for now, `DPOConfig.tools` remains the only " + "supported way to pass tools.", + FutureWarning, + stacklevel=3, + ) + + if self.use_logits_to_keep is not None: + warnings.warn( + "`use_logits_to_keep` is deprecated and will be removed in version 0.29.0. The DPO trainer will no " + "longer use this setting.", + FutureWarning, + stacklevel=3, + ) + else: # keep the old default + self.use_logits_to_keep = False + + if isinstance(self.f_divergence_type, FDivergenceType): + warnings.warn( + "`f_divergence_type` will require a string in 0.29.0; `FDivergenceType` is deprecated. Use one of: " + "`'reverse_kl'`, `'js_divergence'`, `'alpha_divergence'`.", + FutureWarning, + stacklevel=3, + ) + self.f_divergence_type = self.f_divergence_type.value + + # Normalize loss_type to string format for internal use + if hasattr(self.loss_type, "__len__") and len(self.loss_type) == 1: + self.loss_type = self.loss_type[0] + + # Validate loss_type + if self.loss_weights is not None: + loss_types = self.loss_type if isinstance(self.loss_type, list) else [self.loss_type] + if len(self.loss_weights) != len(loss_types): + raise ValueError( + f"Length of loss_weights list ({self.loss_weights}) must match number of loss types " + f"({loss_types})." + ) + + if "aot_pair" in self.loss_type: + warnings.warn( + "The loss type 'aot_pair' has been renamed to 'aot_unpaired' and is deprecated. " + "It will be removed in version 0.29.0. Please use 'aot_unpaired' in `loss_type` instead.", + FutureWarning, + stacklevel=3, + ) + if isinstance(self.loss_type, str): + self.loss_type = "aot_unpaired" + else: + self.loss_type = ["aot_unpaired" if lt == "aot_pair" else lt for lt in self.loss_type] + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/trainer/dpo_trainer.py b/ICL/RL/trl_source/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..95aa3f4e696ab0a19dfaf1ed723620e1b0abf81d --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/dpo_trainer.py @@ -0,0 +1,2046 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import random +import textwrap +import warnings +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from accelerate import PartialState, logging +from accelerate.utils import tqdm +from datasets import Dataset, IterableDataset +from packaging.version import Version +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoProcessor, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.integrations import ( + is_comet_available, + is_mlflow_available, + is_wandb_available, +) +from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_liger_kernel_available, is_peft_available + +from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_extract_prompt +from ..models import create_reference_model, prepare_deepspeed +from ..models.utils import peft_module_casting_to_bf16, prepare_fsdp +from .base_trainer import BaseTrainer +from .callbacks import SyncRefModelCallback +from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType +from .utils import ( + RunningMoments, + cap_exp, + create_model_from_path, + disable_dropout_in_model, + empty_cache, + flush_left, + flush_right, + get_config_model_id, + log_table_to_comet_experiment, + pad, + pad_to_length, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + get_peft_model, + prepare_model_for_kbit_training, + ) + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss + + +if is_wandb_available(): + import wandb + +if is_mlflow_available(): + import mlflow + + +logger = logging.get_logger(__name__) + + +def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor: + """Shift input ids one token to the right, and pad with pad_token_id""" + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + return shifted_input_ids + + +@dataclass +class DataCollatorForPreference(DataCollatorMixin): + """ + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they are + not all of the same length. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl import DataCollatorForPreference + + >>> collator = DataCollatorForPreference(pad_token_id=0) + >>> examples = [ + ... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, + ... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}, + ... ] + >>> collator(examples) + {'prompt_input_ids': tensor([[1, 2, 3], + [0, 7, 8]]), + 'prompt_attention_mask': tensor([[1, 1, 1], + [0, 1, 1]]), + 'chosen_input_ids': tensor([[ 4, 5], + [ 9, 10]]), + 'chosen_attention_mask': tensor([[1, 1], + [1, 1]]), + 'rejected_input_ids': tensor([[ 6, 0, 0], + [11, 12, 13]]), + 'rejected_attention_mask': tensor([[1, 0, 0], + [1, 1, 1]]) + } + ``` + """ + + pad_token_id: int + return_tensors: str = "pt" + + def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]: + # Convert to tensor + prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] + prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] + chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] + chosen_attention_mask = [torch.ones_like(input_ids) for input_ids in chosen_input_ids] + rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + rejected_attention_mask = [torch.ones_like(input_ids) for input_ids in rejected_input_ids] + if "pixel_values" in examples[0]: + pixel_values = [torch.tensor(example["pixel_values"]) for example in examples] + if "pixel_attention_mask" in examples[0]: + pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples] + if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: + ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples]) + ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples]) + + # Pad + output = {} + output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") + output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") + output["chosen_input_ids"] = pad(chosen_input_ids, padding_value=self.pad_token_id) + output["chosen_attention_mask"] = pad(chosen_attention_mask, padding_value=0) + output["rejected_input_ids"] = pad(rejected_input_ids, padding_value=self.pad_token_id) + output["rejected_attention_mask"] = pad(rejected_attention_mask, padding_value=0) + if "pixel_values" in examples[0]: + output["pixel_values"] = pad(pixel_values, padding_value=0.0) + if "pixel_attention_mask" in examples[0]: + output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0) + if "image_sizes" in examples[0]: + output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples]) + if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: + output["ref_chosen_logps"] = ref_chosen_logps + output["ref_rejected_logps"] = ref_rejected_logps + if "token_type_ids" in examples[0]: + token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples] + output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left") + + return output + + +class DPOTrainer(BaseTrainer): + """ + Trainer for Direct Preference Optimization (DPO) method. + + This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. + + Args: + model (`str | PreTrainedModel`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in + `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + ref_model ([`~transformers.PreTrainedModel`]) + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation + and loss. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized. + args ([`DPOConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`DataCollatorForPreference`]. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can + be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoTokenizer.from_pretrained`]. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to + `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered + after the last eval batch to signal that the function needs to calculate and return the global summary + statistics rather than accumulating the batch-level statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "dpo"] + _name = "DPO" + _paper = { + "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model", + "id": "2305.18290", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{rafailov2023direct, + title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, + author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, + year = 2023, + booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, + url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, + editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, + }"""), + } + + def __init__( + self, + model: str | nn.Module | PreTrainedModel, + ref_model: PreTrainedModel | nn.Module | None = None, + args: DPOConfig | None = None, + data_collator: DataCollator | None = None, # type: ignore + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: "PeftConfig | None" = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = DPOConfig(f"{model_name}-DPO") + + # IterableDataset requires dispatch_batches=False because Accelerate's dispatch mode may try to concatenate + # batches from multiple processes, leading to mismatch errors. + if isinstance(train_dataset, IterableDataset): + if args.accelerator_config.dispatch_batches is True: + logger.warning( + "You are using an `IterableDataset` for training with `dispatch_batches=True`. `dispatch_batches` " + "is forced to `False` when using an `IterableDataset`. To remove this warning, unset " + "`dispatch_batches` in `DPOConfig` or set it to `False`." + ) + args.accelerator_config.dispatch_batches = False + + # Model and reference model + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + model_id = get_config_model_id(model.config) + if isinstance(ref_model, str): + warnings.warn( + "Passing `ref_model` as a string is deprecated and will be removed in version 0.29.0. Usually, you " + "can just omit `ref_model` and we'll initialize it to a copy of `model` for you. If you really need " + "to load the reference model from a different path, you can still do so by passing `ref_model` as a " + "model instance.", + FutureWarning, + stacklevel=2, + ) + model_init_kwargs = args.ref_model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + ref_model = create_model_from_path(ref_model, **model_init_kwargs) + else: + if args.ref_model_init_kwargs is not None: + logger.warning( + "You passed `ref_model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. " + "The `ref_model_init_kwargs` will be ignored." + ) + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you can simply omit the `ref_model` argument and it will be created for you." + ) + + if args.force_use_ref_model is None: + self.force_use_ref_model = ref_model is not None + else: + self.force_use_ref_model = args.force_use_ref_model + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(model_id) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + # Get the pad token: if not provided, use the one from the processing class or the eos token + # if the processing class does not have a pad token. + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + self.pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + if self.pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + + # PEFT configuration and model wrapping + model = self._prepare_peft_model(model, ref_model, peft_config, args) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available() or is_mlflow_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." + " Please install `wandb`, `mlflow` or `comet-ml` to resolve." + ) + + self.is_encoder_decoder = model.config.is_encoder_decoder + self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys() + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = args.model_adapter_name + self.ref_adapter_name = args.ref_adapter_name + self.reference_free = args.reference_free + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or args.precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Liger kernel + if args.use_liger_kernel: + if not is_liger_kernel_available(): + raise ImportError( + "You set `use_liger_kernel=True` but the liger kernel is not available. " + "Please install liger-kernel first: `pip install liger-kernel`" + ) + if args.loss_type not in ["sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"]: + raise ValueError( + "You set `use_liger_kernel=True` but the loss type is not from `[sigmoid, apo_zero, apo_down, sppo_hard, nca_pair`. " + "Please set `loss_type='[sigmoid | apo_zero | apo_down | sppo_hard | nca_pair]'` to use the liger kernel." + ) + self.dpo_loss_fn = LigerFusedLinearDPOLoss( + ignore_index=args.label_pad_token_id, + beta=args.beta, + use_ref_model=not args.reference_free, + average_log_prob=False, + loss_type=args.loss_type, + ) + + # Data collator + if data_collator is None: + data_collator = DataCollatorForPreference(pad_token_id=self.pad_token_id) + + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.max_length = args.max_length + self.truncation_mode = args.truncation_mode + self.precompute_ref_log_probs = args.precompute_ref_log_probs + self.use_logits_to_keep = args.use_logits_to_keep + + if args.padding_free: + if model.config._attn_implementation != "flash_attention_2": + logger.warning( + "Padding-free training is enabled, but the attention implementation is not set to " + "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " + "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " + "other implementations may lead to unexpected behavior. To ensure compatibility, set " + "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " + "attention mechanism can handle flattened sequences." + ) + if args.per_device_train_batch_size == 1: + logger.warning( + "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " + "of 1 annihilate the benefits of padding-free training. Please consider increasing the batch size " + "to at least 2." + ) + self.padding_free = args.padding_free + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + self.beta = args.beta + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type] + self.loss_weights = args.loss_weights + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.use_weighting = args.use_weighting + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + for loss_type in self.loss_type: + if ( + loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] + and args.label_smoothing > 0 + ): + logger.warning( + f"You are using the {loss_type} loss type that does not support label smoothing. The " + "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this " + "warning.", + ) + if loss_type == "kto_pair": + raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + self.f_divergence_type = args.f_divergence_type + self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} + self.dataset_num_proc = args.dataset_num_proc + + # Dataset preparation + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + if args.sync_ref_model: + raise ValueError( + "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + if self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." + ) + + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + if "bco_pair" in self.loss_type: + self.running = RunningMoments(self.accelerator) + + def _prepare_peft_model( + self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig + ) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " + "merge and unload the existing adapter, save the resulting base model, and then pass that base " + "model along with the new `peft_config` to the trainer." + ) + + if ref_model is not None and not self.force_use_ref_model: + raise ValueError( + "You passed a ref_model and a peft_config with `force_use_ref_model=False`. For training PEFT adapters with DPO there is no need to pass a reference" + " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with in DPOTrainer's init, and unset force_use_ref_model" + " if you want to use a different ref_model." + ) + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + else: + model = self._prepare_gradient_checkpointing(model, args) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + else: + model = self._prepare_gradient_checkpointing(model, args) + + return model + + def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): + """Prepare the gradienting checkpointing for the model.""" + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + if args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + return model + + def _prepare_dataset( + self, + dataset: Dataset | IterableDataset, + processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin, + args: DPOConfig, + dataset_name: str, + ) -> Dataset | IterableDataset: + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size + map_kwargs["num_proc"] = args.dataset_num_proc + map_kwargs["writer_batch_size"] = 10 + + with PartialState().main_process_first(): + # Extract prompt if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" + dataset = dataset.map(maybe_extract_prompt, **map_kwargs) + + is_chat = is_conversational(next(iter(dataset))) + + # Apply the chat template if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" + dataset = dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs + ) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + dataset = dataset.map( + self.tokenize_row if not self.is_vision_model else self.process_row, + remove_columns=["chosen", "rejected"], + fn_kwargs={ + "processing_class": processing_class, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) + "add_special_tokens": False, + "is_chat": is_chat, + }, + **map_kwargs, + ) + + return dataset + + @staticmethod + def tokenize_row( + features: dict[str, str], + processing_class: PreTrainedTokenizerBase, + max_prompt_length: int | None = None, + max_completion_length: int | None = None, + add_special_tokens: bool = True, + is_chat: bool = False, + ) -> dict[str, list[int]]: + """ + Tokenize a row of the dataset. + + Args: + features (`dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. + processing_class ([`~transformers.PreTrainedTokenizerBase`]): + Processing class used to process the data. + max_prompt_length (`int` or `None`): + Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + add_special_tokens (`bool`): + Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, + the prompt sequence will have a bos token prepended and an eos token appended. In any case, the + completion sequences will have an eos token appended. + is_chat (`bool`): + Whether the data is conversational. If `True`, the completion sequences will not have an eos token + appended. + + Returns: + `dict[str, list[int]]`: + Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and + `"rejected_input_ids". + + Example: + ```python + >>> from transformers import GPT2Tokenizer + + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + >>> DPOTrainer.tokenize_row( + ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False + ... ) + {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} + ``` + """ + tokenizer = processing_class # the processing class is a tokenizer + prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] + rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token_id is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token_id is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + # For conversational data, the chat template already includes proper EOS tokens + if not is_chat: + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if max_completion_length is not None: + chosen_input_ids = chosen_input_ids[:max_completion_length] + rejected_input_ids = rejected_input_ids[:max_completion_length] + + return { + "prompt_input_ids": prompt_input_ids, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + } + + @staticmethod + def process_row( + features: dict[str, str], + processing_class: PreTrainedTokenizerBase, + max_prompt_length: int | None = None, + max_completion_length: int | None = None, + add_special_tokens: bool = True, + is_chat: bool = False, + ) -> dict[str, list[int]]: + """ + Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. + + Note: Unlike `tokenize_row`, this method does not truncate prompts even if `max_prompt_length` is set. For + vision models, prompts contain image tokens that must exactly match the image features (pixel_values). + Truncating these tokens would cause a mismatch, leading to errors during the forward pass, like "Image features + and image tokens do not match". Users should filter their datasets to ensure prompts are an appropriate length + before training. + """ + if max_prompt_length is not None: + warnings.warn( + "max_prompt_length is not supported for vision models and will be ignored. " + "Truncating prompts would cause image token/feature mismatch errors.", + stacklevel=2, + ) + processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor + processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) + + prompt_input_ids = processed_features["input_ids"][0] + pixel_values = processed_features["pixel_values"][0] + chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] + rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token_id is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token_id is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + if not is_chat: + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + + # Truncate completion sequences only. + # Note: We do not truncate prompt_input_ids for vision models because the prompts contain image tokens + # that must exactly match the image features (pixel_values). Truncating would cause errors like + # "Image features and image tokens do not match: tokens: X, features: Y". Users should filter overlong + # prompts from their dataset before training (the recommended approach for the deprecated max_prompt_length). + if max_completion_length is not None: + chosen_input_ids = chosen_input_ids[:max_completion_length] + rejected_input_ids = rejected_input_ids[:max_completion_length] + + output = { + "prompt_input_ids": prompt_input_ids, + "pixel_values": pixel_values, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + } + + if "pixel_attention_mask" in processed_features: + output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] + if "image_sizes" in processed_features: + output["image_sizes"] = processed_features["image_sizes"][0] + if "token_type_ids" in processed_features: + output["token_type_ids"] = processed_features["token_type_ids"][0] + + return output + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. + # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. + if self._signature_columns is None: + self._signature_columns = [ + "prompt_input_ids", + "chosen_input_ids", + "rejected_input_ids", + "image_sizes", + "token_type_ids", + "ref_chosen_logps", + "ref_rejected_logps", + ] + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size + dataloader_params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + + ref_chosen_logps = [] + ref_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) + ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( + (ref_chosen_logp, ref_rejected_logp) + ) + ref_chosen_logps.append(ref_chosen_logp.cpu()) + ref_rejected_logps.append(ref_rejected_logp.cpu()) + + # Unnecessary cache clearing to avoid OOM + empty_cache() + self.accelerator.free_memory() + + all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() + all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() + + self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) + self.train_dataset = self.train_dataset.add_column( + name="ref_rejected_logps", column=all_ref_rejected_logps + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size + dataloader_params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + ref_chosen_logps = [] + ref_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) + ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( + (ref_chosen_logp, ref_rejected_logp) + ) + ref_chosen_logps.append(ref_chosen_logp.cpu()) + ref_rejected_logps.append(ref_rejected_logp.cpu()) + + all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() + all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() + + eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) + eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) + + # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> tuple[torch.Tensor, torch.Tensor]: + """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" + compte_ref_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), compte_ref_context_manager: + if self.ref_model is None: + with self.null_ref_context(): + ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) + else: + ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) + return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] + + @staticmethod + def concatenated_inputs( + batch: dict[str, list | torch.LongTensor], padding_value: int + ) -> dict[str, torch.LongTensor]: + """ + Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and + completion sequences. + + Args: + batch (`dict[str, list | torch.LongTensor]`): + A batch of input data. The batch must contain the following keys: + + - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input + IDs. + - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen + completion input IDs. + - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected + completion input IDs. + - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. + - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. + + padding_value (`int`): + The padding value to use for the concatenated completion sequences (`chosen_input_ids` and + `rejected_input_ids`). + + Returns: + `dict[str, torch.LongTensor]`: A dictionary containing: + + - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. + - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * + batch_size, max_completion_length)`. + - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, + prompt_length)`. + - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * + batch_size, max_completion_length)`. + - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. + - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if + `"prompt_pixel_attention_mask"` are present. + + Notes: + The completion input IDs and attention masks are padded to the maximum completion length of the chosen or + rejected sequences. + """ + output = {} + + # For the prompt, the input_ids are the same for both the chosen and rejected responses + output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) + output["prompt_attention_mask"] = torch.cat( + [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 + ) + if "pixel_values" in batch: + output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) + + if "pixel_attention_mask" in batch: + output["pixel_attention_mask"] = torch.cat( + [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 + ) + if "image_sizes" in batch: + output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) + if "token_type_ids" in batch: + output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"])) + + # Concatenate the chosen and rejected completions + max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + output["completion_input_ids"] = torch.cat( + ( + pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), + pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), + ), + ) + output["completion_attention_mask"] = torch.cat( + ( + pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), + pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), + ), + ) + + return output + + def dpo_loss( + self, + chosen_logps: torch.FloatTensor, + rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + loss_type: str = "sigmoid", + model_output: dict[str, torch.FloatTensor] = None, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + chosen_logps (`torch.FloatTensor`): + Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. + rejected_logps (`torch.FloatTensor`): + Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. + ref_chosen_logps (`torch.FloatTensor`): + Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. + ref_rejected_logps (`torch.FloatTensor`): + Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. + loss_type (`str`, defaults to `"sigmoid"`): + The type of loss to compute. One of: + - `"sigmoid"`: Sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"hinge"`: Hinge loss on the normalized likelihood from the + [SLiC](https://huggingface.co/papers/2305.10425) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + - `"exo_pair"`: Pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. + - `"nca_pair"`: Pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. + - `"robust"`: Unbiased estimate of the DPO loss that is robust to preference noise from the [Robust + DPO](https://huggingface.co/papers/2403.00409) paper. + - `"bco_pair"`: Pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. + - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) + paper. + - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"aot_unpaired"`: AOT loss for unpaired datasets from the + [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the + [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. + - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"sft"`: Negative log-likelihood loss (standard supervised fine-tuning loss). + model_output (`dict[str, torch.FloatTensor]`, *optional*): + The output of the model's forward pass. This is used to compute auxiliary losses if enabled. + + Returns: + A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. The losses tensor contains the DPO + loss for each example in the batch. The `chosen_rewards` and `rejected_rewards` tensors contain the rewards + for the chosen and rejected responses, respectively. + """ + device = self.accelerator.device + + # Get the log ratios for the chosen and rejected responses + chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) + rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) + + if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE: + # The alpha-divergence formula: (1 - u^-alpha) / alpha + # The divergence difference between the chosen and rejected sample is: + # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha + # = (u[l]^-alpha - u[w]^-alpha) / alpha + # where u[w] and u[l] are the policy/reference probability ratios + # for the chosen and rejected samples, respectively. + alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT + if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: + alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) + logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef + else: + logratios = chosen_logps - rejected_logps + if self.reference_free: + ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) + else: + ref_logratios = ref_chosen_logps - ref_rejected_logps + + logratios = logratios.to(self.accelerator.device) + ref_logratios = ref_logratios.to(self.accelerator.device) + logits = logratios - ref_logratios + + if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE: + # The js-divergence formula: log(2 * u / (1 + u)) + # The divergence difference between the chosen and rejected sample is: + # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) + # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) + # where u[w] and u[l] are the policy/reference probability ratios + # for the chosen and rejected samples, respectively. + logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) + + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the + # labels and calculates a conservative DPO loss. + if loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + + elif loss_type == "robust": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + + F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) / (1 - 2 * self.label_smoothing) + + elif loss_type == "exo_pair": + # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 + import math + + if self.label_smoothing == 0: + self.label_smoothing = 1e-3 + losses = (self.beta * logits).sigmoid() * ( + F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) + ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) + + elif loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + + elif loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + + elif loss_type == "bco_pair": + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() + self.running.update(rewards) + delta = self.running.mean + losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( + -(self.beta * rejected_logratios - delta) + ) + + elif loss_type == "sppo_hard": + # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, + # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. + # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is + # set to 1 for the winner and 0 for the loser. + a = chosen_logps - ref_chosen_logps + b = rejected_logps - ref_rejected_logps + losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 + + elif loss_type == "nca_pair": + chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta + rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta + losses = ( + -F.logsigmoid(chosen_rewards) + - 0.5 * F.logsigmoid(-chosen_rewards) + - 0.5 * F.logsigmoid(-rejected_rewards) + ) + + elif loss_type == "aot_unpaired": + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) + rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) + delta = chosen_logratios_sorted - rejected_logratios_sorted + losses = ( + -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * delta) * self.label_smoothing + ) + + elif loss_type == "aot": + logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logratios_sorted, _ = torch.sort(logratios, dim=0) + ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) + delta = logratios_sorted - ref_logratios_sorted + losses = ( + -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * delta) * self.label_smoothing + ) + + elif loss_type == "apo_zero": + # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are better than your model's default output + losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood + losses = losses_chosen + losses_rejected + + elif loss_type == "apo_down": + # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are worse than your model's default output. + # Decrease chosen likelihood and decrease rejected likelihood more + losses_chosen = F.sigmoid(self.beta * chosen_logratios) + losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) + losses = losses_chosen + losses_rejected + + elif loss_type == "discopop": + # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) + # This loss was discovered with LLM discovery + logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logits = logratios - ref_logratios + logits = logits * self.beta + # Modulate the mixing coefficient based on the log ratio magnitudes + log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) + logistic_component = -F.logsigmoid(logits) + exp_component = torch.exp(-logits) + # Blend between logistic and exponential component based on log ratio modulation + losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation + + elif loss_type == "sft": + # SFT loss is the negative log likelihood loss on chosen responses + # This acts as the generation loss component in MPO + sft_loss = model_output["nll_loss"] + # Create losses tensor with same shape as other losses (per-sample) + batch_size = chosen_logps.shape[0] + losses = sft_loss.expand(batch_size) + # For SFT, we don't have preference rewards, so use zeros + chosen_rewards = torch.zeros_like(chosen_logps) + rejected_rewards = torch.zeros_like(rejected_logps) + + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " + "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_unpaired', 'discopop', 'apo_zero', " + "'apo_down', 'sft']" + ) + + chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() + rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() + + return losses, chosen_rewards, rejected_rewards + + def _compute_loss_liger( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> dict[str, torch.Tensor]: + unwrapped_model = self.accelerator.unwrap_model(model) + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id) + + model_kwargs = {} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if "image_sizes" in concatenated_batch: + model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + + prompt_attention_mask = concatenated_batch["prompt_attention_mask"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + + if self.is_encoder_decoder: + # 1. Get encoder outputs + encoder_outputs = unwrapped_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + # 2. Prepare decoder inputs + decoder_input_ids = shift_tokens_right( + concatenated_batch["completion_input_ids"], + unwrapped_model.config.decoder_start_token_id, + ) + # 3. Get decoder outputs + decoder_outputs = unwrapped_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + hidden_states = decoder_outputs.last_hidden_state + + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_encoder_outputs = unwrapped_ref_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_decoder_outputs = unwrapped_ref_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + ref_hidden_states = ref_decoder_outputs.last_hidden_state + elif not self.reference_free: + with self.null_ref_context(): + ref_encoder_outputs = unwrapped_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_decoder_outputs = unwrapped_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + ref_hidden_states = ref_decoder_outputs.last_hidden_state + + labels = concatenated_batch["completion_input_ids"] + loss_mask = completion_attention_mask.bool() + else: + # For decoder-only models + input_ids = torch.cat( + (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 + ) + attention_mask = torch.cat( + (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), + dim=1, + ) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) + + # Flush and truncate + if self.max_length is not None and self.max_length < attention_mask.size(1): + if self.truncation_mode == "keep_start": + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + attention_mask = attention_mask[:, : self.max_length] + input_ids = input_ids[:, : self.max_length] + loss_mask = loss_mask[:, : self.max_length] + elif self.truncation_mode == "keep_end": + # Flush right before truncating left, then flush left + # [[0, 0, x, x, x, x], -> [[0, 0, x, x], + # [0, x, x, x, 0, 0]] [0, x, x, x]] + attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + loss_mask = loss_mask[:, -self.max_length :] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + else: + raise ValueError( + f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " + "'keep_start']." + ) + else: + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + + # Add logits_to_keep optimization + if self.use_logits_to_keep: + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 + model_kwargs["logits_to_keep"] = logits_to_keep + + model_kwargs["output_hidden_states"] = True + + # Add padding-free training support + if self.padding_free: + input_ids = input_ids[attention_mask.bool()].unsqueeze(0) + loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) + position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 + model_kwargs["position_ids"] = position_ids + else: + model_kwargs["attention_mask"] = attention_mask + + # Get the base model outputs (before LM head) + if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None: + base_model = unwrapped_model.get_decoder() + else: + base_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name) + base_model = getattr(unwrapped_model, base_attr, unwrapped_model) + + outputs = base_model( + input_ids, + use_cache=False, + **model_kwargs, + ) + hidden_states = outputs.last_hidden_state[:, :-1] + + # Get reference hidden states if needed + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) + if hasattr(unwrapped_ref_model, "get_decoder") and unwrapped_ref_model.get_decoder() is not None: + ref_base_model = unwrapped_ref_model.get_decoder() + else: + ref_attr = getattr(unwrapped_ref_model, "base_model_prefix", self.args.base_model_attribute_name) + ref_base_model = getattr(unwrapped_ref_model, ref_attr, unwrapped_ref_model) + + ref_outputs = ref_base_model( + input_ids, + use_cache=False, + **model_kwargs, + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + elif not self.reference_free: + if hasattr(unwrapped_model, "get_decoder") and unwrapped_model.get_decoder() is not None: + ref_base_model = unwrapped_model.get_decoder() + else: + ref_attr = getattr(unwrapped_model, "base_model_prefix", self.args.base_model_attribute_name) + ref_base_model = getattr(unwrapped_model, ref_attr, unwrapped_model) + with self.null_ref_context(): + ref_outputs = ref_base_model( + input_ids, + use_cache=False, + **model_kwargs, + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + + masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id) + labels = masked_input_ids[:, 1:] # Shift right for casual LM + + # Get the LM head + lm_head = unwrapped_model.get_output_embeddings() + + # Get reference model weights if needed + ref_weight = None + ref_bias = None + if not self.reference_free: + if self.ref_model is not None: + unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_lm_head = unwrapped_ref_model.get_output_embeddings() + else: + with self.null_ref_context(): + ref_lm_head = unwrapped_model.get_output_embeddings() + ref_weight = ref_lm_head.weight + ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + + # Compute loss using Liger kernel + loss_output = self.dpo_loss_fn( + lm_head.weight, + hidden_states, + labels, + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + ref_input=ref_hidden_states if not self.reference_free else None, + ref_weight=ref_weight if not self.reference_free else None, + ref_bias=ref_bias if not self.reference_free else None, + ) + ( + loss, + (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), + ) = loss_output + + output = { + "loss": loss, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + "mean_chosen_logits": chosen_logits_mean, + "mean_rejected_logits": rejected_logits_mean, + "nll_loss": nll_loss, + "chosen_rewards": aux_outputs[0], + "rejected_rewards": aux_outputs[1], + } + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor], is_ref_model: bool = False + ) -> dict[str, torch.Tensor]: + """ + Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + + Args: + model: + Model to run the forward pass on. + batch: + Batch of input data. + is_ref_model: + Whether this method is being called for the reference model. If `True`, length desensitization is not + applied. + """ + num_examples = batch["prompt_input_ids"].shape[0] + + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id) + + model_kwargs = {"use_cache": False} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if "image_sizes" in concatenated_batch: + model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + + prompt_input_ids = concatenated_batch["prompt_input_ids"] + prompt_attention_mask = concatenated_batch["prompt_attention_mask"] + completion_input_ids = concatenated_batch["completion_input_ids"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + if self.is_encoder_decoder: + labels = completion_input_ids + labels[completion_attention_mask == 0] = self.label_pad_token_id + outputs = model( + input_ids=prompt_input_ids, + attention_mask=prompt_attention_mask, + labels=labels, # we need the labels for the logits to be returned + **model_kwargs, + ) + logits = outputs.logits + loss_mask = completion_attention_mask.bool() + else: + # Concatenate the prompt and completion inputs + input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) + attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) + if "token_type_ids" in concatenated_batch: + prompt_token_type_ids = concatenated_batch["token_type_ids"] + token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) + + # Flush and truncate + if self.max_length is not None and self.max_length < attention_mask.size(1): + if self.truncation_mode == "keep_start": + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + if "token_type_ids" in concatenated_batch: + attention_mask, input_ids, loss_mask, token_type_ids = flush_left( + attention_mask, input_ids, loss_mask, token_type_ids + ) + else: + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + attention_mask = attention_mask[:, : self.max_length] + input_ids = input_ids[:, : self.max_length] + loss_mask = loss_mask[:, : self.max_length] + elif self.truncation_mode == "keep_end": + # Flush right before truncating left, then flush left + # [[0, 0, x, x, x, x], -> [[0, 0, x, x], + # [0, x, x, x, 0, 0]] [0, x, x, x]] + if "token_type_ids" in concatenated_batch: + attention_mask, input_ids, loss_mask, token_type_ids = flush_left( + attention_mask, input_ids, loss_mask, token_type_ids + ) + token_type_ids = token_type_ids[:, -self.max_length :] + else: + attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + loss_mask = loss_mask[:, -self.max_length :] + if "token_type_ids" in concatenated_batch: + attention_mask, input_ids, loss_mask, token_type_ids = flush_left( + attention_mask, input_ids, loss_mask, token_type_ids + ) + else: + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + else: + raise ValueError( + f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " + "'keep_start']." + ) + else: + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + if "token_type_ids" in concatenated_batch: + attention_mask, input_ids, loss_mask, token_type_ids = flush_left( + attention_mask, input_ids, loss_mask, token_type_ids + ) + else: + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + + if "token_type_ids" in concatenated_batch: + model_kwargs["token_type_ids"] = token_type_ids + + if self.use_logits_to_keep: + # Compute logits_to_keep based on loss_mask pattern: + # [[0, 0, 0, x, x, x, x], + # [0, 0, 0, x, x, x, 0]] + # ^ start computing logits from here ([:, -(7-3+1):]) + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label + model_kwargs["logits_to_keep"] = logits_to_keep + + model_kwargs["output_hidden_states"] = True + + if self.padding_free: + # Flatten the input_ids, position_ids, and loss_mask + # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] + # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] + input_ids = input_ids[attention_mask.bool()].unsqueeze(0) + loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) + position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 + model_kwargs["position_ids"] = position_ids + else: + model_kwargs["attention_mask"] = attention_mask + + outputs = model(input_ids, **model_kwargs) + logits = outputs.logits + + # Offset the logits by one to align with the labels + labels = torch.roll(input_ids, shifts=-1, dims=1) + loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() + + if self.use_logits_to_keep: + # Align labels with logits + # logits: -, -, [x2, x3, x4, x5, x6] + # ^ --------- ^ after logits[:, :-1, :] + # labels: [y0, y1, y2, y3, y4, y5, y6] + # ^ --------- ^ with logits_to_keep=4, [:, -4:] + # loss_mask: [0, 0, 0, 1, 1, 1, 1] + labels = labels[:, -logits_to_keep:] + loss_mask = loss_mask[:, -logits_to_keep:] + + if logits.shape[:2] != labels.shape[:2]: + # for LLaVA, the returned logits include the image tokens (placed before the text tokens) + seq_len = labels.shape[1] + logits = logits[:, -seq_len:] + + # Compute the log probabilities of the labels + labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later + per_token_logps = selective_log_softmax(logits, labels) + per_token_logps[~loss_mask] = 0 + per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) + + if self.padding_free: + # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) + batch_size, seq_len = attention_mask.shape + per_token_logps_ = torch.zeros( + batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype + ) + per_token_logps_[attention_mask.bool()] = per_token_logps + per_token_logps = per_token_logps_ + + all_logps = per_token_logps[:, 1:].sum(-1) + + output = {} + + if self.use_weighting: + with torch.no_grad(): + # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 + logprobs = F.log_softmax(logits, dim=-1) + weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space + per_token_logps_adjusted = per_token_logps - weights_adjustment_factor + all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_weights = all_weights[:num_examples] + rejected_weights = all_weights[num_examples:] + output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) + + if self.args.rpo_alpha is not None or "sft" in self.loss_type: + # Only use the chosen logits for the RPO loss or SFT loss + chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] + chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] + + # Compute the log probabilities of the labels + output["nll_loss"] = F.cross_entropy( + torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 + ) + + if "ipo" in self.loss_type: + all_logps = all_logps / loss_mask.sum(-1) + + if self.args.ld_alpha is not None and not is_ref_model: + # Compute response lengths based on loss_mask + completion_lengths = loss_mask.sum(dim=1) + + chosen_lengths = completion_lengths[:num_examples] + rejected_lengths = completion_lengths[num_examples:] + public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper + public_lengths = torch.cat([public_lengths, public_lengths], dim=0) + + seq_len = per_token_logps.size(1) + position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) + + ld_mask = position_ids < public_lengths.unsqueeze(1) + mask = position_ids < completion_lengths.unsqueeze(1) + + front_mask = (ld_mask & mask).float() + rear_mask = (~ld_mask & mask).float() + front_logps = (per_token_logps * front_mask).sum(dim=1) + rear_logps = (per_token_logps * rear_mask).sum(dim=1) + + all_logps = front_logps + self.args.ld_alpha * rear_logps + + output["chosen_logps"] = all_logps[:num_examples] + output["rejected_logps"] = all_logps[num_examples:] + + # Compute the mean logits + if self.padding_free: + # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). + # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, + # and the second half to the rejected tokens. + # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. + split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] + mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() + mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() + else: + mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() + mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() + + output["mean_chosen_logits"] = mean_chosen_logits + output["mean_rejected_logits"] = mean_rejected_logits + + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output + + def get_batch_loss_metrics( + self, + model: PreTrainedModel | nn.Module, + batch: dict[str, list | torch.LongTensor], + train_eval: Literal["train", "eval"] = "train", + ) -> tuple[torch.Tensor, dict[str, float]]: + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + if self.args.use_liger_kernel: + model_output = self._compute_loss_liger(model, batch) + losses = model_output["loss"] + chosen_rewards = model_output["chosen_rewards"] + rejected_rewards = model_output["rejected_rewards"] + else: + model_output = self.concatenated_forward(model, batch) + + # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model + if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: + ref_chosen_logps = batch["ref_chosen_logps"] + ref_rejected_logps = batch["ref_rejected_logps"] + else: + ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) + + # Initialize combined losses + losses = 0 + chosen_rewards = 0 + rejected_rewards = 0 + + # Compute losses for each loss type + for idx, loss_type in enumerate(self.loss_type): + # Compute individual loss using standard DPO loss function + _losses, _chosen_rewards, _rejected_rewards = self.dpo_loss( + model_output["chosen_logps"], + model_output["rejected_logps"], + ref_chosen_logps, + ref_rejected_logps, + loss_type, + model_output, + ) + + # Add weighted contributions + weight = self.loss_weights[idx] if self.loss_weights else 1.0 + losses = losses + _losses * weight + chosen_rewards = chosen_rewards + _chosen_rewards * weight + rejected_rewards = rejected_rewards + _rejected_rewards * weight + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + if self.args.rpo_alpha is not None: + losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper + + if self.use_weighting: + losses = losses * model_output["policy_weights"] + + if self.aux_loss_enabled: + losses = losses + self.aux_loss_coef * model_output["aux_loss"] + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() + metrics[f"{prefix}rewards/margins"] = ( + self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() + ) + metrics[f"{prefix}logps/chosen"] = ( + self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() + ) + metrics[f"{prefix}logps/rejected"] = ( + self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() + ) + metrics[f"{prefix}logits/chosen"] = ( + self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() + ) + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() + ) + if self.args.rpo_alpha is not None or "sft" in self.loss_type: + metrics[f"{prefix}nll_loss"] = ( + self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() + ) + if self.aux_loss_enabled: + metrics[f"{prefix}aux_loss"] = ( + self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() + ) + + return losses.mean(), metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return loss, metrics + + return loss + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.pad_token_id, + ) + + # if ref_output in batch use that otherwise use the reference model + if "ref_output" in batch: + ref_output = batch["ref_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + ref_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.pad_token_id, + ) + else: + ref_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + ref_output = pad_to_length(ref_output, self.max_length, self.pad_token_id) + ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) + + return policy_output_decoded, ref_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return loss.detach(), None, None + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy", "Ref Model"], + data=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded, strict=True + ) + ], + ) + if "wandb" in self.args.report_to and self.accelerator.is_main_process: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + if "mlflow" in self.args.report_to and self.accelerator.is_main_process: + mlflow.log_table(data=table, artifact_file="game_log.json") + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/trainer/grpo_config.py b/ICL/RL/trl_source/trl/trainer/grpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..38dc4d5d115722794b748e7dc5843c6b4f4bfb49 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/grpo_config.py @@ -0,0 +1,899 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from transformers import TrainingArguments + + +@dataclass +class GRPOConfig(TrainingArguments): + r""" + Configuration class for the [`GRPOTrainer`]. + + This class includes only the parameters that are specific to GRPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`str`, `dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`GRPOTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents + the model from generating different logprobs for the same input. + cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`): + Whether to cast the language modeling head of the policy and reference models to float32. As recommended by + the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model + has untied word embedding and language modeling head layers i.e. `tie_word_embeddings` in the model config + is False. + + > Parameters that control the data preprocessing + + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that + requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. + num_generations (`int`, *optional*, defaults to `8`): + Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size + * gradient_accumulation_steps) must be evenly divisible by this value. + num_generations_eval (`int` or `None`, *optional*): + Number of generations to sample during evaluation. This allows using fewer generations during evaluation to + save computation. If `None`, uses the value of `num_generations`. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum length of the generated completion. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible + with vLLM generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. + + > Parameters that control generation + + generation_batch_size: (`int`, *optional*): + Batch size to use for generation. If `None`, it defaults to the effective training batch size: + `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one + generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`. + steps_per_generation: (`int`, *optional*): + Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive + with `generation_batch_size`. + temperature (`float`, defaults to `1.0`): + Temperature for sampling. The higher the temperature, the more random the completions. + top_p (`float`, *optional*, defaults to `1.0`): + Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to + `1.0` to consider all tokens. + top_k (`int`, *optional*, defaults to `0`): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, top-k-filtering is + disabled and all tokens are considered. + min_p (`float`, *optional*): + Minimum token probability, which will be scaled by the probability of the most likely token. It must be a + value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + generation_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or + `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the + generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict + with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. + chat_template_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the `apply_chat_template` function when generating completions. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. + Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat + tokens. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` + paged implementation will be used for generation instead of the default padded implementation. This + parameter is only effective when `use_vllm` is set to `False`. + cache_implementation (`str`, *optional*): + Implementation of the cache method for faster generation when `use_vllm` is set to `False`. + + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation + instead of the default model.generate(). Requires `vllm` to be installed. + vllm_mode (`str`, *optional*, defaults to `"server"`): + Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use + the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model + implementation. + vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. + + > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + + vllm_server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port + is occupied, there is no need to change it. + + > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. + vllm_max_model_length (`int`, *optional*): + Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus + `max_completion_length`; if omitted, it is inferred from the model config. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory usage low, but + waking the engine adds hostโ€“device transfer latency. + + > Parameters that control the training + + beta (`float`, *optional*, defaults to `0.0`): + KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving + training speed. [DeepSeek-R1 incentivizes reasoning in LLMs through reinforcement + learning](https://huggingface.co/papers/2501.12948) use a value of `0.001`. + num_iterations (`int`, *optional*, defaults to `1`): + Number of iterations per batch (denoted as ฮผ in the algorithm). + epsilon (`float`, *optional*, defaults to `0.2`): + Epsilon value for clipping. + delta (`float`, *optional*): + Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` (default), standard + GRPO clipping is used. Recommended to be greater than `1 + ฮต` when enabled. This method is introduced in + the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). + epsilon_high (`float`, *optional*): + Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound + specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. + When used with `loss_type='cispo'`, this corresponds to the ฮต_max param specified in the [ScaleRL + paper](https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`. + sapo_temperature_neg (`float`, *optional*, defaults to `1.05`): + Temperature for tokens with non-positive advantage scores used in the `sapo` loss function. This parameter + is introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). + sapo_temperature_pos (`float`, *optional*, defaults to `1.0`): + Temperature for tokens with positive advantage scores used in the `sapo` loss function. This parameter is + introduced in the [Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). + importance_sampling_level (`str`, *optional*, defaults to `"token"`): + Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` + keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the + log-probability ratios across valid tokens to produce a single ratio per sequence. The [GSPO + paper](https://huggingface.co/papers/2507.18071) shows that sequence-level sampling often yields more + stable training and better alignment with sequence-level rewards. + reward_weights (`list[float]`, *optional*): + Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are + weighted equally with weight `1.0`. + multi_objective_aggregation (`str`, *optional*, defaults to `"sum_then_normalize"`): + Method to aggregate multiple reward functions. Supported values are: + + - `"sum_then_normalize"` (default): First sums the weighted rewards from each reward function, then applies + reward scaling/normalization as specified by `scale_rewards` (see `scale_rewards` for details). + - `"normalize_then_sum"`: First normalizes/scales each reward function across generations (within each + group), then sums the normalized rewards using the specified weights. The aggregated reward is then + normalized at the batch level when forming advantages. This is the suggested approach from the paper + [GDPO: Group reward-Decoupled Normalization Policy Optimization for Multi-reward RL + Optimization](https://huggingface.co/papers/2601.05242). + scale_rewards (`str` or `bool`, *optional*, defaults to `"group"`): + Specifies the scaling strategy for rewards. Supported values are: + + - `True` or `"group"` (default): rewards are scaled by the standard deviation within each group, ensuring + unit variance within a group. + - `"batch"`: rewards are scaled by the standard deviation across the entire batch, as recommended in the + [PPO Lite paper](https://huggingface.co/papers/2508.08221). + - `False` or `"none"`: no scaling is applied. The [Dr. GRPO + paper](https://huggingface.co/papers/2503.20783) recommends not scaling rewards, as scaling by the + standard deviation introduces a question-level difficulty bias. + loss_type (`str`, *optional*, defaults to `"dapo"`): + Specifies the loss formulation to use. Supported values are: + + - `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to + length biasโ€”this approach tends to prefer shorter completions with positive advantages and longer ones + with negative advantages. + - `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was + introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias. + The value of the constant corresponds to `max_completion_length`. + - `"dapo"` (default): Aggregates token-level losses by normalizing with the number of active token in the + global accumulated batch. This method was introduced in the [DAPO + paper](https://huggingface.co/papers/2503.14476) to eliminate length bias. + - `"bnpo"`: Aggregates token-level losses by normalizing with the number of active token in the local + batch. Note that normalization is performed over the local batch only, so results may slightly vary + depending on the local batch size, despite a constant effective batch size. When using + `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. + - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The + clipped weights are then multiplied with the advantages and policy model's log probs. Individual token + losses are aggregated by normalizing with the number of active tokens in the global accumulated batch. + This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). + - `"sapo"`: Soft Adaptive Policy Optimization loss, as introduced in the [Soft Adaptive Policy Optimization + paper](https://huggingface.co/papers/2511.20347). Replaces hard clipping with a smooth, + temperature-controlled gate that adaptively attenuates off-policy updates while preserving useful + learning signals. + - `"luspo"`: Length-Unbiased Sequence Policy Optimization loss. A sequence-level loss that scales each + sequence's loss by its length. This is a modification of GSPO and requires + `importance_sampling_level="sequence"`. Introduced in the [LUSPO + paper](https://huggingface.co/papers/2602.05261). + mask_truncated_completions (`bool`, *optional*, defaults to `False`): + When enabled, truncated completions are excluded from the loss calculation, preventing them from being + incorrectly penalized and introducing noise during training. According to the + [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. + ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): + ฮฑ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `ฯ€_ref = ฮฑ * ฯ€_ฮธ + (1 - ฮฑ) * ฯ€_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `512`): + ฯ„ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + top_entropy_quantile (`float`, *optional*, defaults to `1.0`): + ฯ parameter from [Beyond the 80/20 Rule](https://huggingface.co/papers/2506.01939). Keeps in the policy + loss term only the top-ฯ quantile of tokens by entropy of the probability distribution at each sequence + position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token; + `1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with + `mask_truncated_completions=True`, only tokens from non-truncated completions are considered. + max_tool_calling_iterations (`int`, *optional*): + Maximum number of tool-calling turns when training an agent. If `None`, there is no limit and generation + stops when the model generates a response turn with no tool calls or when the total response length reaches + `max_model_length`. + vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): + Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM completion logprobs and + recomputed training logprobs. If set to `False`, no IS is applied regardless of + `vllm_importance_sampling_mode`. When `True`, the selected mode determines how the IS ratios are computed + and constrained. + vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`): + Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. Possible + values are: + + - `"token_truncate"`: Token-level truncated IS (default). Per-token ratios are clipped from above at C. + - `"token_mask"`: Token-level masked IS. Per-token ratios above C are set to zero. + - `"sequence_truncate"`: Sequence-level truncated IS. A single sequence ratio is clipped from above at + C and applied to all tokens in the sequence. + - `"sequence_mask"`: Sequence-level masked IS. Sequences with ratios above C are masked out. + vllm_importance_sampling_cap (`float`, *optional*, defaults to `3.0`): + Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes, importance + ratios are clipped from above at C. For `*_mask` modes, ratios larger than C are set to zero. + off_policy_mask_threshold (`float`, *optional*): + Threshold for off-policy sequence masking. If `None`, off-policy sequence masking is disabled. When set, + sequences with negative advantages and high KL divergence are masked out to stabilize training. This + parameter corresponds to the `delta` threshold in Equation 9 of the [DeepSeek-V3.2 + paper](https://huggingface.co/papers/2512.02556). It expects a positive value (e.g., 0.5). + use_bias_correction_kl (`bool`, *optional*, defaults to `False`): + Whether to use the unbiased KL divergence estimator with importance sampling correction. This corrects the + KL divergence estimate by multiplying it with the importance sampling ratio. This is described in the + [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556). + + > Parameters that control the logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, + it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or + `trackio`. + num_completions_to_print (`int`, *optional*): + Number of completions to print with `rich`. If `None`, all completions are logged. + log_unique_prompts (`bool`, *optional*, defaults to `False`): + Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all prompts are + logged. + log_completions_hub_repo (`str`, *optional*): + Hugging Face Hub repository to save the completions. Should be a complete repository name like + `'username/reponame'` or `'orgname/reponame'`, or just `'reponame'` in which case the repository will be + created in the currently-logged-in Hugging Face user's namespace. Note that this repository will be public + unless you set `hub_private_repo=True` or your organization's default is to create private repositories." + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + # Parameters that control the model and reference model + model_init_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `GRPOTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=False, + metadata={ + "help": "Whether to disable dropout in the model. This is useful for training with a reference model, as " + "it prevents the model from generating different logprobs for the same input." + }, + ) + cast_lm_head_to_fp32: bool = field( + default=False, + metadata={ + "help": "Whether to cast the language modeling head of the policy and reference, models to float32." + "As recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only " + "supported when the model has untied word embedding and language modeling head layers i.e. " + "`tie_word_embeddings` in the model config is False." + }, + ) + + # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on + # additional columns to compute the reward + remove_unused_columns: bool | None = field( + default=False, + metadata={ + "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " + "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." + }, + ) + num_generations: int | None = field( + default=8, + metadata={ + "help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size " + "* gradient_accumulation_steps) must be evenly divisible by this value." + }, + ) + num_generations_eval: int | None = field( + default=None, + metadata={ + "help": "Number of generations to sample during evaluation. This allows using fewer generations during " + "evaluation to save computation. If `None`, uses the value of `num_generations`." + }, + ) + max_completion_length: int | None = field( + default=256, + metadata={"help": "Maximum length of the generated completion."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + shuffle_dataset: bool | None = field( + default=True, + metadata={"help": "Whether to shuffle the training dataset."}, + ) + + # Parameters that control generation + generation_batch_size: int | None = field( + default=None, + metadata={ + "help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: " + "`per_device_train_batch_size * num_processes * steps_per_generation`." + }, + ) + steps_per_generation: int | None = field( + default=None, + metadata={"help": "Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`."}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: int = field( + default=0, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + min_p: float | None = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + generation_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or " + "`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the " + "generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that " + "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them." + }, + ) + chat_template_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating " + "completions." + }, + ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) + use_transformers_paged: bool = field( + default=False, + metadata={ + "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the " + "`transformers` paged implementation will be used for generation instead of the default padded " + "implementation. This parameter is only effective when `use_vllm` is set to `False`." + }, + ) + cache_implementation: str | None = field( + default=None, + metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, + ) + + # Parameters that control generation acceleration powered by vLLM + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for " + "generation instead of the default model.generate(). Requires `vllm` to be installed." + }, + ) + vllm_mode: str = field( + default="server", + metadata={ + "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `'server'` or " + "`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure " + "a TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same " + "process and share the training GPUs. This avoids the need for a separate server but may cause resource " + "contention with training." + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={ + "help": "Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: " + "Use the `transformers` backend for model implementation. `vllm`: Use the `vllm` library for " + "model implementation." + }, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={ + "help": "Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory " + "usage low, but waking the engine adds hostโ€“device transfer latency." + }, + ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."}, + ) + + # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + vllm_server_base_url: str | None = field( + default=None, + metadata={ + "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " + "and `vllm_server_port` are ignored." + }, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_port: int = field( + default=8000, + metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={ + "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " + "after the timeout, a `ConnectionError` is raised." + }, + ) + vllm_group_port: int = field( + default=51216, + metadata={ + "help": "Port number for the weight update group. This is used to communicate with the vLLM server. " + "Unless the port is occupied, there is no need to change it.", + }, + ) + + # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + vllm_gpu_memory_utilization: float = field( + default=0.3, + metadata={ + "help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag." + }, + ) + vllm_max_model_length: int | None = field( + default=None, + metadata={ + "help": "Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus " + "`max_completion_length`; if omitted, it is inferred from the model config." + }, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={ + "help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_tensor_parallel_size` flag." + }, + ) + + # Parameters that control the training + beta: float = field( + default=0.0, + metadata={ + "help": "KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and " + "improving training speed. [DeepSeek-R1 incentivizes reasoning in LLMs through reinforcement " + "learning](https://huggingface.co/papers/2501.12948) use a value of `0.001`." + }, + ) + num_iterations: int = field( + default=1, + metadata={"help": "Number of iterations per batch (denoted as ฮผ in the algorithm)."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Epsilon value for clipping."}, + ) + delta: float | None = field( + default=None, + metadata={ + "help": "Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` " + "(default), standard GRPO clipping is used. Recommended to be greater than `1 + ฮต` when enabled. This " + "method is introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)." + }, + ) + epsilon_high: float | None = field( + default=None, + metadata={ + "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " + "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. " + "When used with `loss_type='cispo'`, this corresponds to the ฮต_max param specified in the" + "[ScaleRL paper]https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`." + }, + ) + sapo_temperature_neg: float = field( + default=1.05, + metadata={ + "help": "Temperature for tokens with non-positive advantage scores used in the `sapo` loss function. " + "This parameter is introduced in the [Soft Adaptive Policy Optimization " + "paper](https://huggingface.co/papers/2511.20347)." + }, + ) + sapo_temperature_pos: float = field( + default=1.0, + metadata={ + "help": "Temperature for tokens with positive advantage scores used in the `sapo` loss function. " + "This parameter is introduced in the [Soft Adaptive Policy Optimization " + "paper](https://huggingface.co/papers/2511.20347)." + }, + ) + importance_sampling_level: str = field( + default="token", + metadata={ + "help": "Controls whether importance sampling ratios are computed at the `'token'` or `'sequence'` level. " + "`'token'` keeps the raw per-token log-probability ratios (one weight per token). `'sequence'` averages " + "the log-probability ratios across valid tokens to produce a single ratio per sequence. The GSPO paper " + "shows that sequence-level sampling often yields more stable training and better alignment with " + "sequence-level rewards." + }, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={ + "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " + "rewards are weighted equally with weight `1.0`." + }, + ) + multi_objective_aggregation: str = field( + default="sum_then_normalize", + metadata={ + "help": "Method to aggregate multiple reward functions. Supported values are: " + "`'sum_then_normalize'` (default): First sums the weighted rewards from each reward function, then " + "applies reward scaling/normalization as specified by `scale_rewards` (see `scale_rewards` for details). " + "`'normalize_then_sum'`: First normalizes/scales each reward function across generations (within each " + "group), then sums the normalized rewards using the specified weights. The aggregated reward is then " + "normalized at the batch level when forming advantages. This is the suggested approach from the paper " + "GDPO: Group reward-Decoupled Normalization Policy Optimization for Multi-reward RL Optimization." + }, + ) + scale_rewards: str = field( + default="group", + metadata={ + "help": "Specifies the scaling strategy for rewards. Supported values are: " + "`True` or `group'` (default): rewards are scaled by the standard deviation within each group, ensuring " + "unit variance within a group. " + "`'batch'`: rewards are scaled by the standard deviation across the entire batch, as recommended in the " + "PPO Lite paper. " + "`False` or `'none'`: no scaling is applied. The Dr. GRPO paper recommends not scaling rewards, as " + "scaling by the standard deviation introduces a question-level difficulty bias." + }, + ) + loss_type: str = field( + default="dapo", + metadata={ + "help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', and " + "'dr_grpo'. " + "'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length " + "biasโ€”this approach tends to prefer shorter completions with positive advantages and longer ones with " + "negative advantages. " + "'dapo' (default): Aggregates token-level losses by normalizing with the number of active token in the " + "global accumulated batch. This method was introduced in the DAPO paper to eliminate length bias. " + "'dr_grpo': Aggregates token-level losses by normalizing with a global constant. This method was " + "introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to " + "`max_completion_length`. " + "'bnpo': Aggregates token-level losses by normalizing with the number of active token in the local batch. " + "Note that normalization is performed over the local batch only, so results may slightly vary depending " + "on the local batch size, despite a constant effective batch size. When using " + "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss." + "'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. " + "The clipped weights are then multiplied with the advantages and policy model's log probs. " + "Individual token losses are aggregated by normalizing with the number of active tokens in " + "the global accumulated batch. This method was introduced in the " + "[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). " + "'sapo': Soft Adaptive Policy Optimization loss, as introduced in the " + "[Soft Adaptive Policy Optimization paper](https://huggingface.co/papers/2511.20347). " + "Replaces hard clipping with a smooth, temperature-controlled gate that adaptively attenuates " + "off-policy updates while preserving useful learning signals." + "'luspo': Length-Unbiased Sequence Policy Optimization loss. A sequence-level loss that scales each " + "sequence's loss by its length. This is a modification of GSPO and requires " + "`importance_sampling_level='sequence'`. Introduced in the [LUSPO " + "paper](https://huggingface.co/papers/2602.05261)." + }, + ) + mask_truncated_completions: bool = field( + default=False, + metadata={ + "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from " + "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is " + "a good practice for training stability." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "ฮฑ parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`ฯ€_ref = ฮฑ * ฯ€_ฮธ + (1 - ฮฑ) * ฯ€_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "ฯ„ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + top_entropy_quantile: float = field( + default=1.0, + metadata={ + "help": "ฯ parameter from Beyond the 80/20 Rule. Keeps in the policy loss term only the top-ฯ quantile of " + "tokens by entropy of the probability distribution at each sequence position, improving results. Range: " + "[0.0-1.0]. A value of `0.0` masks all but the highest entropy token; `1.0` keeps all tokens. The paper " + "recommends a value of `0.2`. If used with `mask_truncated_completions=True`, only tokens from " + "non-truncated completions are considered." + }, + ) + max_tool_calling_iterations: int | None = field( + default=None, + metadata={ + "help": "Maximum number of tool-calling turns when training an agent. If `None`, there is no limit and " + "generation stops when the model generates a response turn with no tool calls or when the total " + "response length reaches `max_model_length`." + }, + ) + vllm_importance_sampling_correction: bool = field( + default=True, + metadata={ + "help": "Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM " + "completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied " + "regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how " + "IS ratios are computed and constrained." + }, + ) + + vllm_importance_sampling_mode: str = field( + default="sequence_mask", + metadata={ + "help": "Specifies how Importance Sampling (IS) is performed when " + "vllm_importance_sampling_correction=True. Modes are defined along two orthogonal " + "dimensions: (1) constraint, which determines how to handle ratios above " + "vllm_importance_sampling_cap (C)โ€”either truncation (clip from above, ฯ โ† min(ฯ, C)) or " + "masking (set ratios above C to zero); and (2) granularity, which determines whether " + "ratios are computed per token or as a single sequence-level ratio applied to all tokens. " + "Supported options are: 'token_truncate', 'token_mask', 'sequence_truncate', and " + "'sequence_mask'." + }, + ) + + vllm_importance_sampling_cap: float = field( + default=3.0, + metadata={ + "help": "Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, " + "ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero." + }, + ) + off_policy_mask_threshold: float | None = field( + default=None, + metadata={ + "help": "Threshold for off-policy sequence masking. If `None`, off-policy sequence masking is disabled. " + "When set, sequences with negative advantages and high KL divergence are masked out to stabilize " + "training. This parameter corresponds to the `delta` threshold in Equation 9 of the [DeepSeek-V3.2 " + "paper](https://huggingface.co/papers/2512.02556). It expects a positive value (e.g., 0.5)." + }, + ) + use_bias_correction_kl: bool = field( + default=False, + metadata={ + "help": "Whether to use the unbiased KL divergence estimator with importance sampling correction. This " + "corrects the KL divergence estimate by multiplying it with the importance sampling ratio. " + "This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556)." + }, + ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={ + "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " + "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." + }, + ) + num_completions_to_print: int | None = field( + default=None, + metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, + ) + log_unique_prompts: bool = field( + default=False, + metadata={ + "help": "Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all " + "prompts are logged." + }, + ) + log_completions_hub_repo: str | None = field( + default=None, + metadata={ + "help": "Hugging Face Hub repository to save the completions. Should be a complete repository name like " + "`'username/reponame'` or `'orgname/reponame'`, or just `'reponame'` in which case the repository will " + "be created in the currently-logged-in Hugging Face user's namespace. Note that this repository will be " + "public unless you set `hub_private_repo=True` or your organization's default is to create private " + "repositories." + }, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() + + self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards) + + if self.log_completions_hub_repo is not None and not self.log_completions: + raise ValueError( + "log_completions_hub_repo is set, but log_completions is False. Enable log_completions to upload " + "completions to the Hub, or unset log_completions_hub_repo." + ) + + num_processes = self.world_size + # The current default effective batch size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + self.steps_per_generation = self.generation_batch_size // ( + self.per_device_train_batch_size * num_processes + ) + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.do_eval and self.eval_strategy != "no": + # Determine the number of generations to use for evaluation + num_generations = self.num_generations_eval or self.num_generations + + # Just ensure the value is divisible by the global batch size + if (self.per_device_eval_batch_size * num_processes) % num_generations != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by the number of generations used for evaluation ({num_generations})." + ) + + # The generation batch must contain full prompt groups (no partials), so it must be divisible by + # num_generations. + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) + + if self.num_generations < 2: + raise ValueError( + "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." + ) + + if self.delta is not None and self.use_liger_kernel: + raise ValueError("Liger kernel does not support two-sided GRPO loss yet.") diff --git a/ICL/RL/trl_source/trl/trainer/grpo_trainer.py b/ICL/RL/trl_source/trl/trainer/grpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a6d41a14261ac5a45d5f8d8c7d1d1ad34c62e9 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/grpo_trainer.py @@ -0,0 +1,2215 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import atexit +import copy +import importlib.resources as pkg_resources +import inspect +import os +import sys +import textwrap +import time +import warnings +from collections import defaultdict, deque +from collections.abc import Callable +from contextlib import nullcontext +from functools import partial +from pathlib import Path +from typing import Any + +import datasets +import pandas as pd +import torch +import torch.utils.data +import transformers +from accelerate.logging import get_logger +from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from datasets import Dataset, IterableDataset +from huggingface_hub import CommitScheduler, DatasetCard, DatasetCardData, create_repo +from packaging.version import Version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_trackio_available, + is_wandb_available, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available + +from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response +from ..data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, +) +from ..extras.profiling import profiling_context, profiling_decorator +from ..generation.vllm_generation import VLLMGeneration +from ..import_utils import is_jmespath_available, is_liger_kernel_available +from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ..models.utils import _ForwardRedirection, disable_gradient_checkpointing +from .base_trainer import BaseTrainer +from .callbacks import SyncRefModelCallback +from .grpo_config import GRPOConfig +from .utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + entropy_from_logits, + get_config_model_id, + identity, + nanmax, + nanmin, + nanstd, + pad, + print_prompt_completions_sample, + selective_log_softmax, + shuffle_sequence_dict, + shutdown_event_loop_in_daemon, + split_pixel_values_by_grid, + split_tensor_dict, + start_event_loop_in_daemon, + unsplit_pixel_values_by_grid, + use_adapter, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + + +if is_wandb_available(): + import wandb + +if is_trackio_available(): + import trackio + +logger = get_logger(__name__) + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] + +# What we call a rollout function is a callable that takes prompts (list) and the trainer instance as parameters and +# returns a dict of generation results. Those results must include "prompt_ids", "completion_ids", and "logprobs" +# fields. Any extra fields (per-completion) are forwarded to the reward functions. +RolloutFunc = Callable[[list[str], "GRPOTrainer"], dict[str, Any]] + + +class GRPOTrainer(BaseTrainer): + """ + Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the + paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language + Models](https://huggingface.co/papers/2402.03300). + + Example: + + ```python + from trl import GRPOTrainer + from trl.rewards import accuracy_reward + from datasets import load_dataset + + dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=accuracy_reward, + train_dataset=dataset, + ) + trainer.train() + ``` + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + - A [`~peft.PeftModel`] object. Only causal language models are supported. + reward_funcs (`RewardFunc | list[RewardFunc]`): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function, such as: + - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the + keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. + - A custom reward function: The function is provided with the prompts and the generated completions, + plus any additional columns in the dataset. It should return a list of rewards. Custom reward + functions can be either synchronous or asynchronous and can also return `None` when the reward is + not applicable to those samples. This is useful for multi-task training where different reward + functions apply to different types of samples. When a reward function returns `None` for a sample, + that reward function is excluded from the reward calculation for that sample. For more details, see + [Using a custom reward + function](#using-a-custom-reward-function). + + The trainer's state is also passed to the reward function. The trainer's state is an instance of + [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the + reward function's signature. + - A list of reward functions, where each item can independently be any of the above types. Mixing different + types within the list (e.g., a string model ID and a custom reward function) is allowed. + args ([`GRPOConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is + ignored. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. The padding side must be set to "left". If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A + padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, + `tokenizer.eos_token` will be used as the default. + reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is + `None`, the tokenizer for the model is automatically loaded using + [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward + functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes` + are ignored. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + tools (list of `Callable`, *optional*): + A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool + should be a standard Python function with properly type-hinted arguments and return values, and a + Google-style docstring describing its purpose, arguments, and return value. For more details, see: + https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, + type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool + use and that it has been fine-tuned for tool calling. + rollout_func (`RolloutFunc`, *optional*): + Function to use for generating completions. It receives the list of prompts allocated to the current + process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and + `"logprobs"` fields. Any other fields are forwarded to the reward functions. This feature is experimental + and may change or be removed at any time without prior notice. + """ + + _tag_names = ["trl", "grpo"] + _name = "GRPO" + _paper = { + "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", + "id": "2402.03300", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{shao2024deepseekmath, + title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, + author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, + year = 2024, + eprint = {arXiv:2402.03300}, + } + """), + } + + def __init__( + self, + model: "str | PreTrainedModel | PeftModel", + reward_funcs: RewardFunc | list[RewardFunc], + args: GRPOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: "PeftConfig | None" = None, + tools: list[Callable] | None = None, + rollout_func: RolloutFunc | None = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = GRPOConfig(f"{model_name}-GRPO") + + # Model + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it + # Inspect the forward method before we wrap the model with PEFT + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + + if is_peft_available() and is_peft_model(model) and args.beta != 0.0: + # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy + # of the "default" adapter, so that we can use it as the reference model during GRPO training. + model.add_adapter("ref", model.peft_config["default"]) + for name, param in model.named_parameters(): + if ".default." in name: + ref_name = name.replace(".default.", ".ref.") + ref_param = model.get_parameter(ref_name) + ref_param.data.copy_(param.data) + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError( + f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " + f"reward functions ({len(reward_funcs)})." + ) + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + + self.reward_processing_classes = reward_processing_classes + + # Rollout function + if rollout_func is not None and os.environ.get("TRL_EXPERIMENTAL_SILENCE", "0") != "1": + warnings.warn( + "You are importing from 'rollout_func', which is an experimental feature. This API may change or be " + "removed at any time without prior notice. Silence this warning by setting environment variable " + "TRL_EXPERIMENTAL_SILENCE=1.", + UserWarning, + stacklevel=2, + ) + self.rollout_func = rollout_func + + # Tools + if tools: + if not Version(transformers.__version__) >= Version("5.0.0"): + raise ImportError( + "Using tools with GRPOTrainer requires transformers version 5.0.0 or higher. Please use " + "transformers with `pip install --pre transformers` to use this feature." + ) + if not is_jmespath_available(): + raise ImportError( + "Using tools with GRPOTrainer requires the jmespath library for response parsing. Please install " + "it with `pip install jmespath` to use this feature." + ) + self.tools = tools or [] + self._sync_tool_dict = {} + self._async_tool_dict = {} + if self.tools: + for tool in self.tools: + if asyncio.iscoroutinefunction(tool): + self._async_tool_dict[tool.__name__] = tool + else: + self._sync_tool_dict[tool.__name__] = tool + + # Check for async functions to start an event loop on a daemon thread + self._has_async_funcs = any(asyncio.iscoroutinefunction(func) for func in self.reward_funcs + self.tools) + + if self._has_async_funcs: + self.async_loop_thread, self.async_loop, self.async_loop_ready_event = start_event_loop_in_daemon( + name="GRPOTrainer-AsyncLoop" + ) + # wait until the event loop is running in the daemon thread + self.async_loop_ready_event.wait() + atexit.register(shutdown_event_loop_in_daemon, self.async_loop_thread, self.async_loop) + + # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. + # While waiting for broader adoption, we provide this utility function to manually set the response schema for + # known chat templates. + # We need `getattr`` until the base class sets a default None value for response_schema + if tools and not getattr(processing_class, "response_schema", None): + processing_class = add_response_schema(processing_class) + # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template + # isn't, we replace it at initialization with a training-safe, prefix-preserving template. + if tools: + self.chat_template = get_training_chat_template(processing_class) + else: + self.chat_template = None + + # Training arguments + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.max_tool_calling_iterations = args.max_tool_calling_iterations or sys.maxsize + self.num_generations_eval = args.num_generations_eval or self.num_generations + self.chat_template_kwargs = args.chat_template_kwargs or {} + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode + self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap + self.use_liger_kernel = args.use_liger_kernel + self.loss_type = args.loss_type + self.multi_objective_aggregation = args.multi_objective_aggregation + self.scale_rewards = args.scale_rewards + self.importance_sampling_level = args.importance_sampling_level + self.off_policy_mask_threshold = args.off_policy_mask_threshold + if self.use_liger_kernel and self.off_policy_mask_threshold is not None: + raise ValueError("Liger kernel does not support off-policy sequence masking yet.") + self.mask_truncated_completions = args.mask_truncated_completions + self.top_entropy_quantile = args.top_entropy_quantile + if self.use_liger_kernel and self.top_entropy_quantile < 1.0: + raise NotImplementedError( + "Liger Kernels don't currently support masking token positions based on entropy." + ) + if self.use_liger_kernel and not self.importance_sampling_level == "token": + raise NotImplementedError( + "Liger Kernels currently only support token-level importance sampling. Please set" + "`importance_sampling_level` to 'token'." + ) + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + # See https://github.com/huggingface/trl/issues/3213 + raise NotImplementedError( + "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead." + ) + + if args.loss_type == "luspo" and args.importance_sampling_level != "sequence": + logger.warning( + "When using `'luspo'` loss, `importance_sampling_level` should be set to `'sequence'` to mirror the " + "paper's setup." + ) + + # Multi-step + self.num_iterations = args.num_iterations # = ๐œ‡ in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=identity, # No data collation is needed in GRPO + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + # In Trainer, `training_step` scales the loss by `gradient_accumulation_steps` only if `compute_loss_func` + # is None. For DAPO, loss scaling instead depends on the total number of completions tokens across the + # global accumulated batch. To control scaling ourselves, we must disable Trainerโ€™s built-in scaling. The + # simplest (though a bit hacky) way is to set `compute_loss_func` to any non-None value, which bypasses + # that behavior without rewriting `training_step`. + compute_loss_func="non-None value to disable scaling", + ) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_peft_model(model): + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. + self.ref_model = None + else: + # For deepspeed, fsdp or non-distributed models, create a reference model from scratch + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Cast LM Head To FP32 + if args.cast_lm_head_to_fp32: + + def _cast_lm_head_to_fp32(target_model: PreTrainedModel): + """Cast lm_head to fp32 while preserving embedding output dtype if tied.""" + + def cast_inputs_to_fp32(module, inputs): + # Preserve other positional args and kwargs untouched + if not inputs: + return inputs + return (inputs[0].to(torch.float32),) + inputs[1:] + + original_dtype_local = target_model.lm_head.weight.dtype + target_model.lm_head = target_model.lm_head.float() + target_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32) + + if target_model.config.tie_word_embeddings: + + def cast_outputs_to_original_dtype(module, args, output): + return output.to(original_dtype_local) + + # Only cast activations; weights are now fp32 (intentional for numerical stability of logits) + target_model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype) + + _cast_lm_head_to_fp32(model) + if self.ref_model is not None: + _cast_lm_head_to_fp32(self.ref_model) + + # Liger loss + if self.use_liger_kernel: + if not is_liger_kernel_available(): + raise ImportError( + "Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`." + ) + # redirect the model.module forward to the model forward to ensure pre-forward hooks are called + self._forward_redirection = _ForwardRedirection() + + self.liger_grpo_loss = LigerFusedLinearGRPOLoss( + beta=self.beta, + epsilon_low=self.epsilon_low, + epsilon_high=self.epsilon_high, + temperature=self.temperature, + use_ref_model=self.beta != 0.0, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self._current_train_step_time = 0.0 + self.log_completions = args.log_completions + self.log_unique_prompts = args.log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + # Keep logs sized to the generation batch to record only outputs from the latest model update. + self._logs = { + "images": deque(maxlen=args.generation_batch_size), + "prompt": deque(maxlen=args.generation_batch_size), + "completion": deque(maxlen=args.generation_batch_size), + "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + "advantages": deque(maxlen=args.generation_batch_size), + } + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + + if self.use_vllm: + # Initialize vLLM generation backend + # Wrap rollout_func to capture trainer context if provided + rollout_func = None + if self.rollout_func is not None: + + def rollout_func(prompts): + return self.rollout_func(prompts, self) + + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + # vLLM configuration + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + # Server mode configuration + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + # Colocate mode configuration + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size + * args.vllm_tensor_parallel_size + * args.steps_per_generation, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + # Generation configuration + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + max_completion_length=self.max_completion_length, + generation_kwargs=args.generation_kwargs, + # Chat/tool configuration + chat_template=self.chat_template, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + rollout_func=rollout_func, + ) + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + else: + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "repetition_penalty": self.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + if self.beta == 0.0: + raise ValueError( + "You passed `sync_ref_model=True` while `beta=0.0`, which means the reference model is not used " + "during training. Consequently, GRPOTrainer does not create a `ref_model` instance, and there is " + "nothing to synchronize. Please set `sync_ref_model=False`, or set `beta` to a non-zero value." + ) + if is_peft_model(model): + raise NotImplementedError( + "You passed `sync_ref_model=True` while using a PEFT model, which is currently not supported. " + "With PEFT, GRPOTrainer does not keep a separate reference model in memory; instead, it recovers " + "reference behavior by temporarily disabling the adapter. As a result, there is no standalone " + "`ref_model` instance to synchronize. Use `sync_ref_model=False`, or opt for full fine-tuning if " + "you need a synced reference model. If you need `sync_ref_model` to work with PEFT, please open a " + "feature request at https://github.com/huggingface/trl/issues." + ) + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + if self.accelerator.is_main_process and self.log_completions: + os.makedirs(os.path.join(self.args.output_dir, "completions"), exist_ok=True) + if self.args.log_completions_hub_repo is not None: + repo_id = self.args.log_completions_hub_repo + create_repo(repo_id, private=self.args.hub_private_repo, repo_type="dataset", exist_ok=True) + template_path = pkg_resources.files("trl").joinpath("templates/completions_dataset_card.md") + card_data = DatasetCardData( + pretty_name="TRL Completion logs", + tags=["trl", "trl-logs", "completions"], + ) + card = DatasetCard.from_template( + card_data=card_data, + template_path=str(template_path), + repo_id=repo_id, + hub_model_id=self.args.hub_model_id, + ) + card.push_to_hub(repo_id) + self.commit_scheduler = CommitScheduler( + repo_id=repo_id, + repo_type="dataset", + folder_path=f"{self.args.output_dir}/completions", + every=2, # minutes + allow_patterns=["*.parquet"], + ) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't + # work. Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt", "image", "images"] + + # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. + # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an + # *generation* batch (i.e., `per_device_batch_size ร— steps_per_generation`). This allows us to generate completions + # once every steps_per_generation stepโ€”rather than once per accumulation stepโ€”which is significantly more + # efficient. The only change from the original implementation is multiplying the batch size by + # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the + # splitting internally. + # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line + # modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line + # apart from the super method, ensuring easier maintenance in the future. + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: + # Returns a sampler that + # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are + # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt + # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies + # in group formation. + # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to + # _prepare_inputs to see how the generations are stored and reused. + + # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the + # second row shows the second sampled batch, and so on. + # + # | GPU 0 | GPU 1 | + # + # global_step step <-โ”€โ”€โ”€> num_generations=2 + # <-โ”€โ”€โ”€โ”€โ”€โ”€โ”€> per_device_train_batch_size=3 + # grad_accum โ–ฒ โ–ฒ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss + # =2 โ–ผ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss + # | + # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss + # steps_per_gen=4 โ–ผ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss + # + # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss + # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss + # ... + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + # See _get_train_sampler for an explanation of the sampler. + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations_eval, + seed=self.args.seed, + ) + + @profiling_decorator + def _get_last_hidden_state( + self, + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + pixel_values=None, + image_grid_thw=None, + pixel_attention_mask=None, + image_sizes=None, + ): + if is_peft_model(unwrapped_model): + unwrapped_model = unwrapped_model.base_model.model + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + + # For Qwen models: + if image_grid_thw is not None and pixel_values is not None: + model_inputs["image_grid_thw"] = image_grid_thw + # For Gemma, SmolVLM2, LLaVa-Next etc.: + if pixel_values is not None: + model_inputs["pixel_values"] = pixel_values + # For SmolVLM2 + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask + # For LLaVa-Next + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state + # Exclude the last value: it corresponds to the next token pred + last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + return last_hidden_state + + def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor: + """ + Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold. + + Args: + entropies (`torch.Tensor`): + Tensor of shape (batch_size, seq_len) with per-token entropy values. + mask (`torch.Tensor`): + Binary mask of the same shape as `entropies`, where `1` indicates valid tokens and `0` padding. + threshold (`float`): + Quantile threshold between `0.0` and `1.0` to select high-entropy tokens. + + Returns: + `torch.Tensor`: + Boolean mask of shape (batch_size, seq_len), where `True` indicates tokens with entropy >= threshold + and `False` otherwise. + """ + local = entropies[mask.bool()].float() + + # Use a negative pad_value as a sentinel because entropy values are always >= 0. + # This guarantees that the sentinel cannot collide with any real entropy value. + pad_value = -1e9 + + # Pad across processes so that every rank has the same tensor length + padded = self.accelerator.pad_across_processes(local, dim=0, pad_index=pad_value) + gathered = self.accelerator.gather(padded) + + # Drop sentinel values (safe because no entropy can be negative) + gathered = gathered[gathered != pad_value] + + if gathered.numel() == 0: + return torch.zeros_like(entropies, dtype=torch.bool) + + entropy_threshold = torch.quantile(gathered, threshold) + masked_entropies = entropies * mask.float() + entropy_mask = masked_entropies >= entropy_threshold + return entropy_mask & mask.bool() # ensure padding tokens are always masked out + + @profiling_decorator + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + ) -> dict[str, torch.Tensor | None]: + """Compute log-probs and (optionally) entropies for each token.""" + batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + all_logps = [] + all_entropies = [] + for start in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[start : start + batch_size] + attention_mask_batch = attention_mask[start : start + batch_size] + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + logits = model(**model_inputs).logits + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + completion_ids = input_ids_batch[:, -logits_to_keep:] + logps = selective_log_softmax(logits, completion_ids) # compute logprobs + all_logps.append(logps) + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + all_entropies.append(entropies) + + logps = torch.cat(all_logps, dim=0) + entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + return logps, entropies + + def training_step(self, model, inputs, num_items_in_batch): + time_before = time.perf_counter() + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + time_after = time.perf_counter() + self._current_train_step_time += time_after - time_before + if self._step % self.current_gradient_accumulation_steps == 0: + self._metrics["train"]["step_time"].append(self._current_train_step_time) + self._current_train_step_time = 0.0 + return output + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + # Prepares inputs for model training/evaluation by managing completion generation and batch handling. + # During training: + # - Receives the local generation batch (Per-GPU batch size ร— steps per generation) + # from the modified training dataloader instead of the standard local batch + # - Generates completions once for the entire generation batch and splits it into batches of size + # `per_device_train_batch_size` + # - Buffers these completions and returns the appropriate slice for the current accumulation step + # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # During evaluation: + # - The input is treated as a standard local batch (no accumulation, no multiple iterations) + # - Completions are generated for each batch without buffering or reuse + # Returns a single local batch in both cases. + + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + # self._buffered_inputs=None can occur when resuming from a checkpoint + generation_batch = self._generate_and_score_completions(generation_batch) + generation_batch = split_pixel_values_by_grid(generation_batch) + generation_batch = shuffle_sequence_dict(generation_batch) + generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) + self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] + inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] + else: + # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence + # local generation batch == local eval batch + inputs = self._generate_and_score_completions(generation_batch) + return inputs + + @profiling_decorator + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + + # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + + # This allows for dynamic reward shaping based on training progress. + reward_kwargs["trainer_state"] = self.state + + async_funcs_info = [] # async custom functions for asyncio.gather + + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True) + ): + if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models + with profiling_context(self, reward_func_name): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + elif asyncio.iscoroutinefunction(reward_func): # Separate async reward funcs to run them in parallel later + async_funcs_info.append((i, reward_func, reward_func_name)) + else: + # Run synchronous reward function + with profiling_context(self, reward_func_name): + output_reward_func = reward_func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # Execute async custom functions in parallel using asyncio.gather + if async_funcs_info: + + async def _invoke_async(index, func, func_name): + with profiling_context(self, func_name): + output = await func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + output = [r if r is not None else torch.nan for r in output] + return index, output + + async def _run_async_funcs(): + coros = [_invoke_async(i, func, func_name) for (i, func, func_name) in async_funcs_info] + return await asyncio.gather(*coros) + + async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_loop).result() + for idx, output_reward_func in async_results: + rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = { + key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state" + } + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + logger.warning( + f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n" + "Please ensure that at least one reward function returns a valid reward." + ) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + return rewards_per_func + + def _generate_single_turn(self, prompts: list): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Generate completions using either vLLM or regular generation + if self.use_vllm: + # Sync weights if training step changed + if self.state.global_step != self._last_loaded_step: + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + # Generate using vLLM + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids, completion_ids, logprobs, extra_fields = self.vllm_generation.generate( + prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate") + ) + + elif self.use_transformers_paged: + if is_conversational({"prompt": prompts[0]}): + processor_outputs = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **self.chat_template_kwargs, + ) + else: + processor_outputs = self.processing_class(text=prompts) + + with ( + profiling_context(self, "transformers.generate_batch"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Cast to the appropriate dtype based on training configuration + if self.args.bf16: + unwrapped_model.to(torch.bfloat16) + elif self.args.fp16: + unwrapped_model.to(torch.float16) + if self.args.cast_lm_head_to_fp32: + unwrapped_model.lm_head.to(torch.float32) + with torch.inference_mode(): + # Continuous batching API expects 'inputs' arg only + all_outputs = unwrapped_model.generate_batch( + processor_outputs["input_ids"], generation_config=self.generation_config, progress_bar=False + ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode + completion_ids = [output.generated_tokens for output in all_outputs.values()] + prompt_ids = processor_outputs["input_ids"] + logprobs = None # not used in this case + extra_fields = {} # No extra fields for paged mode + + else: + # Regular generation path + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + padding=True, + padding_side="left", + return_tensors="pt", + return_dict=True, + **self.chat_template_kwargs, + ) + else: + generate_inputs = self.processing_class( + text=prompts, padding=True, padding_side="left", return_tensors="pt" + ) + generate_inputs = super()._prepare_inputs(generate_inputs) + + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config, disable_compile=True + ) + # Compute prompt length and extract completion ids + prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] + logprobs = None # not used in this case + extra_fields = {} # No extra fields for non-rollout_func paths + + return prompt_ids, completion_ids, logprobs, extra_fields + + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs): + # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt + tool_calls = [completion[0].get("tool_calls") for completion in completions] + idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] + tool_calls = [tool_calls[idx] for idx in idxs_with_tool] + tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere + tool_call_count = 0 + tool_failure_count = 0 + iteration_num = 0 + while idxs_with_tool and iteration_num < self.max_tool_calling_iterations: + prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls + + # Call the tools, and build the new prompt for generation + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + tool_call_list = tool_calls[idx] + prompt_completion_tool = prompt_completion_tools[idx] + # Append the last assistant message (which triggered tool_calls) to the prompt + prompt_completion_tool.append(completions[idx_with_tool][-1]) + async_coros = [] + tool_call_results = [] + for tool_call in tool_call_list: + tool_call_count += 1 + if tool_call["type"] == "function": + function = tool_call["function"] + name = function["name"] + try: + if name in self._sync_tool_dict: + tool_call_results.append((name, self._sync_tool_dict[name](**function["arguments"]))) + elif name in self._async_tool_dict: + async_coros.append((name, self._async_tool_dict[name](**function["arguments"]))) + except Exception as e: + tool_failure_count += 1 + result = {"error": str(e)} + tool_call_results.append((name, result)) + else: + tool_failure_count += 1 + name = tool_call.get("name", "unknown") + tool_call_results.append((name, {"error": f"Unsupported tool call type: {tool_call['type']}"})) + + if async_coros: + + async def _run_async_tools(async_coros): + coros = [coro for _, coro in async_coros] + results = await asyncio.gather(*coros, return_exceptions=True) + return [(name, result) for (name, _), result in zip(async_coros, results, strict=False)] + + async_results = asyncio.run_coroutine_threadsafe( + _run_async_tools(async_coros), self.async_loop + ).result() + + for name, result in async_results: + if isinstance(result, Exception): + tool_failure_count += 1 + tool_call_results.append((name, {"error": str(result)})) + else: + tool_call_results.append((name, result)) + + for name, result in tool_call_results: + tool_message = {"role": "tool", "name": name, "content": str(result)} + prompt_completion_tool.append(tool_message) + completions[idx_with_tool].append(tool_message) + + # Tokenize and filter samples whose length exceeds max allowed length. This is important, because both + # vLLM and transformers will error out if the input is longer than the model's max length. + pct_ids = self.processing_class.apply_chat_template( + prompt_completion_tools, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=False, + **self.chat_template_kwargs, + ) + if self.use_vllm and self.vllm_mode == "colocate": + max_model_len = self.llm.llm_engine.model_config.max_model_len + elif not self.use_vllm: + max_model_len = self.model.config.max_position_embeddings + else: + raise NotImplementedError( + f"Unsupported mode detected: use_vllm={self.use_vllm}, vllm_mode={self.vllm_mode}" + ) + overlong = [len(pct) >= max_model_len for pct in pct_ids] + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + if overlong[idx]: + prompt_length = len(prompt_ids[idx_with_tool]) + ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length] + completion_ids[idx_with_tool] = ct + tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) + # Keep only non-overlong items for further processing + idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] + prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] + if not idxs_with_tool: + break # all overlong, exit tool loop + + # Generate new completions after tool execution + prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn( + prompt_completion_tools + ) + + # Sanity check: from experience, this is useful to catch bugs in the chat template + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + if prompt_ids[idx_with_tool] != pct[: len(prompt_ids[idx_with_tool])]: + raise ValueError( + "The chat template is not prefix-preserving. Please update it to use a prefix-preserving " + "format." + ) + + # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_len = len(prompt_ids[idx_with_tool]) + completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:] + excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + # If exceeding max length, truncate post_tool_ids + post_tool_ids[idx] = post_tool_ids[idx][:-excess_length] + if logprobs is not None: + post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length] + excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + # If still exceeding max length, truncate completion_tool_ids as well + prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] + + # Update tool_mask: the tool result should be 0 and the post-tool 1 + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) + prompt_length = len(prompt_ids[idx_with_tool]) + completion_length = len(completion_ids[idx_with_tool]) + post_tool_length = len(post_tool_ids[idx]) + tool_length = prompt_completion_tool_length - prompt_length - completion_length + tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] + + # Update completion_ids with the new completions (after tool execution) + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_length = len(prompt_ids[idx_with_tool]) + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] + + # Decode post-tool completions + post_tool_completions = [ + parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids + ] + + # Add post-tool completions to the existing completions + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + if post_tool_completions[idx]: # {} if post-tool completions completely truncated + completions[idx_with_tool].append(post_tool_completions[idx]) + + # Check for further tool calls + tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] + idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] + tool_calls = [tool_call for tool_call in tool_calls if tool_call] + iteration_num += 1 + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count + + def _generate(self, prompts: list): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Copy the prompts to avoid modifying the original list + prompts = copy.deepcopy(prompts) + + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + + # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + if is_conversational({"prompt": prompts[0]}): + if ( + Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 + and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors + and hasattr(self.processing_class, "response_schema") # attribute not set by default for now + and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + ): + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + else: + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in contents] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Extract tool calls from the completions and (possibly) execute them + if self.tools: + ( + tool_mask, + completions, + completion_ids, + logprobs, + tool_call_count, + tool_failure_count, + ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs) + else: + # Support custom env_mask from rollout_func (e.g., for environment feedback masking) + # Internally treated as tool_mask - marks model tokens (1) vs external tokens (0) + tool_mask = extra_fields.pop("env_mask", None) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + if tool_mask is not None: # count only model-generated tokens (tool_mask=1) + completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device) + else: + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss + + # Log the metrics + if mode == "train": + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + if self.tools: + agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum() + tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item() + self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency) + agg_tool_failure_count = self.accelerator.gather(torch.tensor(tool_failure_count, device=device)).sum() + failure_frequency = ( + (agg_tool_failure_count / agg_tool_call_count).item() if agg_tool_call_count > 0 else 0.0 + ) + self._metrics[mode]["tools/failure_frequency"].append(failure_frequency) + + return ( + prompt_ids, + completion_ids, + tool_mask, + completions, + total_completion_tokens, + logprobs, + extra_fields, + ) + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [ + prepare_multimodal_messages(prompt, image_list) + for prompt, image_list in zip(prompts, images, strict=True) + ] + + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + ) = self._generate(prompts) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + if tool_mask_list is not None: + tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") + else: + tool_mask = None + + # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + # Mask completion_mask for attention masking + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + # Also mask tool_mask for consistency in multi-turn training + if tool_mask is not None: + tool_mask = tool_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template( + {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + )["prompt"] + for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a + # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). + # Temporarily disable checkpointing to avoid this warning during inference. + with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + # If the generation and optimization steps are misalignedโ€”i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)โ€”then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. + # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. + generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + mask = completion_mask if tool_mask is None else completion_mask * tool_mask + per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * mask + + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + logps_diff = per_sequence_logps_diff + else: + logps_diff = per_token_logps_diff + + vllm_importance_sampling_ratio = torch.exp(logps_diff) + + # vllm_importance_sampling_ratio.shape: + # token_* modes: (B, T) (per-token ratio) + # sequence_* modes: (B, 1) (per-sequence ratio) + + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + vllm_importance_sampling_ratio = torch.clamp( + vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 + ) + else: + raise ValueError( + f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." + ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + # When training a PEFT adapter, how we obtain the reference depends on the setup: + # - New adapter: disabling adapters yields the base model. + # - Re-training an existing adapter: an initial copy is loaded under the name "ref". + model = self.accelerator.unwrap_model(self.model) + with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Merge extra_fields from rollout_func into inputs for reward functions + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + inp[key] = values + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + + if self.multi_objective_aggregation == "sum_then_normalize": + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1) + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0) + if self.scale_rewards in ["group", "none"]: + # If self.scale_rewards = "none", we'll only use std_rewards to check for zero std for logging + if num_generations > 1: + std_rewards = rewards.view(-1, num_generations).std(dim=1) + std_rewards = std_rewards.repeat_interleave(num_generations, dim=0) + else: # doesn't occur during training, but could occur in eval when num_generations_eval=1 + std_rewards = torch.zeros_like(rewards) + elif self.scale_rewards == "batch": + # Compute global std + if rewards.numel() > 1: + std_rewards = rewards.std().expand_as(rewards) + else: # doesn't occur during training, but could occur in eval when num_generations_eval=batch_size=1 + std_rewards = torch.zeros_like(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." + ) + + advantages = rewards - mean_grouped_rewards + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging + + elif self.multi_objective_aggregation == "normalize_then_sum": + grouped = rewards_per_func.view(-1, num_generations, len(self.reward_funcs)) + mean_k = torch.nanmean(grouped, dim=1, keepdim=True) + std_k = nanstd(grouped, dim=1, keepdim=True) if num_generations > 1 else torch.zeros_like(mean_k) + reward_k = (grouped - mean_k) / (std_k + 1e-4) + reward_k = reward_k.view(-1, len(self.reward_funcs)) + rewards = (reward_k * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + std_rewards = rewards.std().expand_as(rewards) if rewards.numel() > 1 else torch.zeros_like(rewards) + advantages = (rewards - rewards.mean()) / (std_rewards + 1e-4) + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging + + else: + raise ValueError( + f"Invalid multi_objective_aggregation: {self.multi_objective_aggregation}. Must be " + "'sum_then_normalize' or 'normalize_then_sum'." + ) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(rewards.mean().item()) + self._metrics[mode]["reward_std"].append(rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_process_advantages.tolist()) + + if images is not None: + self._logs["images"].extend(gather_object(images)) + + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + mask = completion_mask.bool() if tool_mask is None else (completion_mask * tool_mask).bool() + delta = delta[mask] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + if sequence_level_is: + flat_is_ratio = vllm_importance_sampling_ratio.flatten() + else: + flat_is_ratio = vllm_importance_sampling_ratio[mask] + + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if self.use_vllm and self.vllm_importance_sampling_correction: + output["importance_sampling_ratio"] = vllm_importance_sampling_ratio + if sampling_per_token_logps is not None: + output["sampling_per_token_logps"] = sampling_per_token_logps + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + if tool_mask is not None: + output["tool_mask"] = tool_mask + return output + + def compute_liger_loss(self, unwrapped_model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Get the last hidden state of the model + last_hidden_state = self._get_last_hidden_state( + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + inputs.get("pixel_values"), + inputs.get("image_grid_thw"), + inputs.get("pixel_attention_mask"), + inputs.get("image_sizes"), + ) + + # Apply tool_mask (from env_mask) for loss computation in multi-turn training scenarios + loss_mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] + # Compute loss and metrics using liger grpo loss + loss, metrics = self.liger_grpo_loss( + _input=last_hidden_state, + lin_weight=unwrapped_model.lm_head.weight, + selected_token_ids=completion_ids, + # The attention_mask parameter in liger loss is actually used as a loss mask (not model attention) + attention_mask=loss_mask, + advantages=inputs["advantages"], + bias=unwrapped_model.lm_head.bias, + old_per_token_logps=inputs.get("old_per_token_logps"), + ref_per_token_logps=inputs.get("ref_per_token_logps"), + ) + # Extract metrics from the liger_grpo_loss output + # KL divergence is the first metric when beta is non-zero + mean_kl = metrics[0] if self.beta != 0.0 else None + clip_ratio = metrics[-1] + + mode = "train" if self.model.training else "eval" + if self.beta != 0.0: + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item()) + self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item()) + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + return loss / normalizer + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + if self.use_liger_kernel: + # Compute the loss using the liger grpo loss + unwrapped_model = self.accelerator.unwrap_model(model) + return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs) + else: + return self._compute_loss(model, inputs) + + @staticmethod + def get_off_policy_mask( + advantages: torch.Tensor, + per_token_logps: torch.Tensor, + sampling_per_token_logps: torch.Tensor, + mask: torch.Tensor, + off_policy_threshold: float, + ) -> torch.Tensor: + """ + Computes the Off-Policy Sequence Mask from DeepSeek-V3.2 paper. Returns a (B, 1) tensor where 1.0 indicates + "Keep" and 0.0 indicates "Drop". + """ + # forward KL div: log(pi_old) - log(pi_theta) + kl_div = sampling_per_token_logps - per_token_logps.detach() + # Sequence-level Mean KL (ignoring prompt+padding) + seq_kl_sum = (kl_div * mask).sum(dim=1, keepdim=True) + avg_seq_kl = seq_kl_sum / mask.sum(dim=1, keepdim=True).clamp(min=1.0) + # Keep if (Advantage >= 0) OR (KL <= delta) + is_pos_adv = advantages >= 0 + is_low_kl = avg_seq_kl <= off_policy_threshold + return (is_pos_adv | is_low_kl).to(dtype=mask.dtype) # (B, 1) + + def _compute_loss(self, model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + mask = completion_mask if "tool_mask" not in inputs else completion_mask * inputs["tool_mask"] + + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + ) + + if self.top_entropy_quantile < 1.0: + entropy_mask = self.get_high_entropy_mask(entropies, mask, 1 - self.top_entropy_quantile) + else: + entropy_mask = None + + # Compute the loss + advantages = inputs["advantages"] + # In the base GRPO implementation, advantages are expected to have shape (B,). To support subclasses that + # provide advantages with shape (B, T) (e.g., MiniLLM), we *conditionally* unsqueeze the tensor. + if advantages.dim() == 1: + advantages = advantages.unsqueeze(1) + # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps, + # old_per_token_logps == per_token_logps. In this case we can skip its computation + # (see _generate_and_score_completions) and instead use per_token_logps.detach(). + # The exception is when using vLLM, where we always compute old_per_token_logps + # for importance sampling + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps + + if self.off_policy_mask_threshold is not None: + # OPSM should use inference-time logprobs to detect both sources of off-policyness: + # 1. Drift from gradient updates (always present) + # 2. Drift from training-inference mismatch (when using vLLM) + # When using vLLM, prioritize sampling_per_token_logps, otherwise use old_per_token_logps + sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps) + + off_policy_mask = self.get_off_policy_mask( + advantages=advantages, + per_token_logps=per_token_logps, + sampling_per_token_logps=sampling_per_token_logps, + mask=mask, + off_policy_threshold=self.off_policy_mask_threshold, + ) + + log_ratio = per_token_logps - old_per_token_logps + if self.importance_sampling_level == "token": + log_importance_weights = log_ratio + elif self.importance_sampling_level == "sequence": + log_importance_weights = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0) + log_importance_weights = log_importance_weights.unsqueeze(-1) + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'." + ) + + coef_1 = torch.exp(log_importance_weights) + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + # Importance sampling correction for the KL divergence + if self.args.use_bias_correction_kl: + per_token_kl = per_token_kl * coef_1 + + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on + # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + if self.loss_type == "cispo": + clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() + per_token_loss = -clamped_ratios * advantages * per_token_logps + elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + per_token_loss1 = coef_1 * advantages + per_token_loss2 = coef_2 * advantages + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + elif self.loss_type == "sapo": + temperatures = torch.where(advantages > 0, self.args.sapo_temperature_pos, self.args.sapo_temperature_neg) + soft_coef_1 = torch.sigmoid(temperatures * (coef_1 - 1)) * 4 / temperatures + per_token_loss = -soft_coef_1 * advantages + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + if self.off_policy_mask_threshold is not None: + per_token_loss = per_token_loss * off_policy_mask + + if entropy_mask is not None: + per_token_loss = per_token_loss * entropy_mask + + if self.use_vllm and self.vllm_importance_sampling_correction: + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + mode = "train" if self.model.training else "eval" + if self.loss_type in ["grpo", "sapo"]: + loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer + elif self.loss_type == "bnpo": + loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + loss = loss / normalizer + elif self.loss_type in ["cispo", "dapo"]: + normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes + loss = (per_token_loss * mask).sum() / normalizer + elif self.loss_type == "luspo": + # Unless importance_sampling_level="token" (not recommended here), per_token_loss is expected to be (B, 1) + loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() + normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + loss = loss / normalizer + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # Log the metrics + completion_token_count = mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * mask).sum() / completion_token_count + + if self.beta != 0.0: + mean_kl = masked_batch_mean(per_token_kl) + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + mean_entropy = masked_batch_mean(entropies) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + elif self.loss_type == "cispo": + is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages > 0) + cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) + gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) + self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) + + return loss + + # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and + # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels. + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + loss = loss.mean().detach() + return loss, None, None + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + if is_rich_available(): + print_prompt_completions_sample( + self._logs["prompt"], + self._logs["completion"], + self._logs["rewards"], + self._logs["advantages"], + self.state.global_step, + self.num_completions_to_print, + ) + + logging_backends = [] + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + logging_backends.append(wandb) + if self.args.report_to and "trackio" in self.args.report_to: + logging_backends.append(trackio) + + table = { + "step": [self.state.global_step] * len(self._logs["prompt"]), + "prompt": self._logs["prompt"], + "completion": self._logs["completion"], + **self._logs["rewards"], + "advantage": self._logs["advantages"], + } + + df_base = pd.DataFrame(table) + df_base.to_parquet( + os.path.join( + self.args.output_dir, + "completions", + f"completions_{self.state.global_step:05d}.parquet", + ) + ) + + images_raw = self._logs["images"] or [] + + for logging_backend in logging_backends: + if images_raw: + images = [] + for image_list in self._logs["images"]: + images.append([logging_backend.Image(image) for image in image_list]) + df = pd.concat( + [df_base, pd.Series(images, name="image")], + axis=1, + copy=False, + ) + else: + df = df_base + + if self.log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + + logging_backend.log({"completions": logging_backend.Table(dataframe=df)}) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/trainer/kto_config.py b/ICL/RL/trl_source/trl/trainer/kto_config.py new file mode 100644 index 0000000000000000000000000000000000000000..849648a8f74f245df98cab016a7b87fa1b628f93 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/kto_config.py @@ -0,0 +1,36 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass + +from ..import_utils import suppress_experimental_warning + + +with suppress_experimental_warning(): + from ..experimental.kto import KTOConfig as _KTOConfig + + +@dataclass +class KTOConfig(_KTOConfig): + def __post_init__(self): + warnings.warn( + "The `KTOConfig` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.kto import KTOConfig`. For more information, see " + "https://github.com/huggingface/trl/issues/4223. Promoting KTO to the stable API is a high-priority task. " + "Until then, this current path (`from trl import KTOConfig`) will remain, but API changes may occur.", + FutureWarning, + stacklevel=3, + ) + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/trainer/kto_trainer.py b/ICL/RL/trl_source/trl/trainer/kto_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..110e29f0a8062efe09a87cddf2a7574d4f8d999e --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/kto_trainer.py @@ -0,0 +1,36 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass + +from ..import_utils import suppress_experimental_warning + + +with suppress_experimental_warning(): + from ..experimental.kto import KTOTrainer as _KTOTrainer + + +@dataclass +class KTOTrainer(_KTOTrainer): + def __init__(self, *args, **kwargs): + warnings.warn( + "The `KTOTrainer` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.kto import KTOTrainer`. For more information, see " + "https://github.com/huggingface/trl/issues/4223. Promoting KTO to the stable API is a high-priority task. " + "Until then, this current path (`from trl import KTOTrainer`) will remain, but API changes may occur.", + FutureWarning, + stacklevel=2, + ) + super().__init__(*args, **kwargs) diff --git a/ICL/RL/trl_source/trl/trainer/model_config.py b/ICL/RL/trl_source/trl/trainer/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..29860a629a650c8574a19319b128a35d21c69a01 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/model_config.py @@ -0,0 +1,188 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + + +@dataclass +class ModelConfig: + """ + Configuration class for the models. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + model_name_or_path (`str`, *optional*): + Model checkpoint for weights initialization. + model_revision (`str`, *optional*, defaults to `"main"`): + Specific model version to use. It can be a branch name, a tag name, or a commit id. + dtype (`Literal["auto", "bfloat16", "float16", "float32"]`, *optional*, defaults to `"float32"`): + Override the default `torch.dtype` and load the model under this dtype. Possible values are + + - `"bfloat16"`: `torch.bfloat16` + - `"float16"`: `torch.float16` + - `"float32"`: `torch.float32` + - `"auto"`: Automatically derive the dtype from the model's weights. + + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether to allow for custom models defined on the Hub in their own modeling files. This option should only + be set to `True` for repositories you trust and in which you have read the code, as it will execute code + present on the Hub on your local machine. + attn_implementation (`str`, *optional*): + Which attention implementation to use. More information in the [Kernels Hub Integrations + Guide](kernels_hub). + use_peft (`bool`, *optional*, defaults to `False`): + Whether to use PEFT for training. + lora_r (`int`, *optional*, defaults to `16`): + LoRA R value. + lora_alpha (`int`, *optional*, defaults to `32`): + LoRA alpha. + lora_dropout (`float`, *optional*, defaults to `0.05`): + LoRA dropout. + lora_target_modules (`str | list[str]`, *optional*): + LoRA target modules. + lora_target_parameters (`str | list[str]`, *optional*): + List of target parameters for LoRA. + lora_modules_to_save (`list[str]`, *optional*): + Model layers to unfreeze & train. + lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`): + Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling). + use_rslora (`bool`, *optional*, defaults to `False`): + Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/โˆšr`, instead of + the original default value of `lora_alpha/r`. + use_dora (`bool`, *optional*, defaults to `False`): + Enable [Weight-Decomposed Low-Rank Adaptation (DoRA)](https://huggingface.co/papers/2402.09353). This + technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is + handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can + improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports linear and Conv2D + layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for + inference. + load_in_8bit (`bool`, *optional*, defaults to `False`): + Whether to use 8 bit precision for the base model. Works only with LoRA. + load_in_4bit (`bool`, *optional*, defaults to `False`): + Whether to use 4 bit precision for the base model. Works only with LoRA. + bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`): + Quantization type (`"fp4"` or `"nf4"`). + use_bnb_nested_quant (`bool`, *optional*, defaults to `False`): + Whether to use nested quantization. + """ + + model_name_or_path: str | None = field( + default=None, + metadata={"help": "Model checkpoint for weights initialization."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."}, + ) + dtype: str | None = field( + default="float32", + metadata={ + "help": "Override the default `torch.dtype` and load the model under this dtype. It defaults to `'float32'`.", + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether to allow for custom models defined on the Hub in their own modeling files. This option " + "should only be set to `True` for repositories you trust and in which you have read the code, as it will " + "execute code present on the Hub on your local machine." + }, + ) + attn_implementation: str | None = field( + default=None, + metadata={ + "help": "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in " + "which case you must install this manually by running `pip install flash-attn --no-build-isolation`." + }, + ) + use_peft: bool = field( + default=False, + metadata={"help": "Whether to use PEFT for training."}, + ) + lora_r: int = field( + default=16, + metadata={"help": "LoRA R value."}, + ) + lora_alpha: int = field( + default=32, + metadata={"help": "LoRA alpha."}, + ) + lora_dropout: float = field( + default=0.05, + metadata={"help": "LoRA dropout."}, + ) + lora_target_modules: list[str] | None = field( + default=None, + metadata={"help": "LoRA target modules."}, + ) + lora_target_parameters: list[str] | None = field( + default=None, + metadata={"help": "List of target parameters for LoRA."}, + ) + lora_modules_to_save: list[str] | None = field( + default=None, + metadata={"help": "Model layers to unfreeze & train."}, + ) + lora_task_type: str = field( + default="CAUSAL_LM", + metadata={"help": "Task type to pass for LoRA (use 'SEQ_CLS' for reward modeling)."}, + ) + use_rslora: bool = field( + default=False, + metadata={ + "help": "Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/โˆšr`, " + "instead of the original default value of `lora_alpha/r`." + }, + ) + use_dora: bool = field( + default=False, + metadata={ + "help": "Enable Weight-Decomposed Low-Rank Adaptation (DoRA). This technique decomposes the updates of " + "the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " + "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " + "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a " + "bigger overhead than pure LoRA, so it is recommended to merge weights for inference." + }, + ) + load_in_8bit: bool = field( + default=False, + metadata={"help": "Whether to use 8 bit precision for the base model. Works only with LoRA."}, + ) + load_in_4bit: bool = field( + default=False, + metadata={"help": "Whether to use 4 bit precision for the base model. Works only with LoRA."}, + ) + bnb_4bit_quant_type: str = field( + default="nf4", + metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]}, + ) + use_bnb_nested_quant: bool = field( + default=False, + metadata={"help": "Whether to use nested quantization."}, + ) + bnb_4bit_quant_storage: str | None = field( + default=None, + metadata={"help": "Quantization storage dtype"}, + ) + + def __post_init__(self): + if self.load_in_8bit and self.load_in_4bit: + raise ValueError("You can't use 8 bit and 4 bit precision at the same time") + + if hasattr(self.lora_target_modules, "__len__") and len(self.lora_target_modules) == 1: + self.lora_target_modules = self.lora_target_modules[0] diff --git a/ICL/RL/trl_source/trl/trainer/reward_config.py b/ICL/RL/trl_source/trl/trainer/reward_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b90665d46ccf4d0726fddec90f9470bdd0e90df1 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/reward_config.py @@ -0,0 +1,184 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class RewardConfig(TrainingArguments): + r""" + Configuration class for the [`RewardTrainer`]. + + This class includes only the parameters that are specific to Reward training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want + to include the load balancing/auxiliary loss as a part of the final loss, remember to set + `output_router_logits=True` in this dictionary. + chat_template_path (`str`, *optional*): + If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory + or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must + ensure that any special tokens referenced in the template are added to the tokenizer and that the model's + embedding layer is resized accordingly. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + + > Parameters that control the data preprocessing + + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + eos_token (`str`, *optional*): + Token used to indicate the end of a turn or sequence. If `None`, it defaults to + `processing_class.eos_token`. + pad_token (`str`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence + exceeds this value. If `None`, no filtering is applied. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + + > Parameters that control the training + + center_rewards_coefficient (`float`, *optional*): + Coefficient to incentivize the reward model to output mean-zero rewards (proposed by + https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-4, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + # Parameters that control the model + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `RewardTrainer` is provided as a string. If you're training a MoE architecture and want to include " + "the load balancing/auxiliary loss as a part of the final loss, remember to set " + "`output_router_logits=True` in this dictionary." + }, + ) + chat_template_path: str | None = field( + default=None, + metadata={ + "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local " + "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, " + "you must ensure that any special tokens referenced in the template are added to the tokenizer and " + "that the model's embedding layer is resized accordingly." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + + # Parameters that control the data preprocessing + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + eos_token: str | None = field( + default=None, + metadata={ + "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`." + }, + ) + pad_token: str | None = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: int | None = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from " + "the right. If `None`, no truncation is applied." + }, + ) + pad_to_multiple_of: int | None = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + + # Parameters that control the training + center_rewards_coefficient: float | None = field( + default=None, + metadata={ + "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " + "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." + }, + ) + activation_offloading: bool = field( + default=False, + metadata={"help": "Whether to offload the activations to the CPU."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/trainer/reward_trainer.py b/ICL/RL/trl_source/trl/trainer/reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c309060e3c732e18ad7d3bc4a93e32914a65614 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/reward_trainer.py @@ -0,0 +1,680 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import logging +import os +import re +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +import transformers +from accelerate import PartialState +from accelerate.logging import get_logger +from accelerate.utils import is_peft_model +from datasets import Dataset, IterableDataset +from packaging.version import Version +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + TrainerCallback, + set_seed, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.modeling_layers import GenericForSequenceClassification +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ..chat_template_utils import clone_chat_template +from ..data_utils import is_conversational +from ..models import get_act_offloading_ctx_manager +from .base_trainer import BaseTrainer +from .reward_config import RewardConfig +from .utils import create_model_from_path, disable_dropout_in_model, get_config_model_id, pad, remove_none_values + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + + +logger = get_logger(__name__) + + +# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly +# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to +# avoid confusing users. + + +# Old approach using logging filter (for transformers < 4.57.0) +@contextmanager +def suppress_from_pretrained_warning(logger: logging.Logger): + pattern = re.compile( + r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: " + r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and " + r"inference\.$" + ) + + class _Filter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return not pattern.search(record.getMessage()) + + f = _Filter() + logger.addFilter(f) + try: + yield + finally: + logger.removeFilter(f) + + +# New approach using scoped override (for transformers >= 4.57.0) +@contextmanager +def ignore_seqcls_score_missing_key(): + # Scoped override: ignore only the expected seq-clf head key. + old = getattr(GenericForSequenceClassification, "_keys_to_ignore_on_load_missing", None) + merged = list(old) if old is not None else [] + pattern = r"^score\.weight$" + if pattern not in merged: + merged.append(pattern) + GenericForSequenceClassification._keys_to_ignore_on_load_missing = merged + try: + yield + finally: + GenericForSequenceClassification._keys_to_ignore_on_load_missing = old + + +# Version-aware wrapper that chooses the appropriate approach +@contextmanager +def suppress_seqcls_warning(): + # Use the new approach for transformers >= 4.57.0, old approach for earlier versions + # The old approach is needed for 4.56.2 to avoid meta tensor issues with device_map=None + if Version(transformers.__version__) >= Version("4.57.0"): + with ignore_seqcls_score_missing_key(): + yield + else: + # Get the transformers logger + transformers_logger = logging.getLogger("transformers.modeling_utils") + with suppress_from_pretrained_warning(transformers_logger): + yield + + +def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]: + return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names + + +@dataclass +class DataCollatorForPreference(DataCollatorMixin): + """ + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch. + + This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and + `"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch + corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`. + - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + + Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a + `"margin"` key with a tensor of margins. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl.trainer.reward_trainer import DataCollatorForPreference + + >>> collator = DataCollatorForPreference(pad_token_id=0) + >>> examples = [ + ... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}, + ... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[1, 2, 3], + [6, 7, 0], + [4, 5, 0], + [8, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1], + [1, 1, 0], + [1, 1, 0], + [1, 0, 0]])} + + >>> examples = [ + ... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5}, + ... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[1, 2, 3], + [6, 7, 0], + [4, 5, 0], + [8, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1], + [1, 1, 0], + [1, 1, 0], + [1, 0, 0]]), + 'margin': tensor([0.5, 0.0])} + ``` + """ + + pad_token_id: int + pad_to_multiple_of: int | None = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + # Convert to tensor + chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] + rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + if "margin" in examples[0]: + margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float) + input_ids = chosen_input_ids + rejected_input_ids + attention_mask = [torch.ones_like(ids) for ids in input_ids] + + output = {} + + # Pad + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["attention_mask"] = pad( + attention_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + if "margin" in examples[0]: + output["margin"] = margins + return output + + +class RewardTrainer(BaseTrainer): + """ + Trainer for Outcome-supervised Reward Models (ORM). + + This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from trl import RewardTrainer + from datasets import load_dataset + + dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + + trainer = RewardTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + train_dataset=dataset, + ) + trainer.train() + ``` + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in + `args.model_init_kwargs`. + - A sequence classification [`~transformers.PreTrainedModel`] object. + - A sequence classification [`~peft.PeftModel`] object. + args ([`RewardConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trainer.reward_trainer.DataCollatorForPreference`]. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and + explicit prompt). The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and + `rejected_input_ids` fields. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*): + Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with + [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be + set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the + default. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a + boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the + function needs to calculate and return the global summary statistics rather than accumulating the + batch-level statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded + model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration + to ensure that the reward head is properly trained. + """ + + _tag_names = ["trl", "reward-trainer"] + _name = "Reward" + _template_file = "rm_model_card.md" + + def __init__( + self, + model: "str | PreTrainedModel | PeftModel", + args: RewardConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: "PeftConfig | None" = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = RewardConfig(f"{model_name}-Reward") + + # IterableDataset requires dispatch_batches=False because Accelerate's dispatch mode may try to concatenate + # batches from multiple processes, leading to mismatch errors. + if isinstance(train_dataset, IterableDataset): + if args.accelerator_config.dispatch_batches is True: + logger.warning( + "You are using an `IterableDataset` for training with `dispatch_batches=True`. `dispatch_batches` " + "is forced to `False` when using an `IterableDataset`. To remove this warning, unset " + "`dispatch_batches` in `RewardConfig` or set it to `False`." + ) + args.accelerator_config.dispatch_batches = False + + # Model + # As AutoModelForSequenceClassification.from_pretrained() will add a random head for the model, set_seed must + # be done before loading the model to ensure reproducibility. + set_seed(args.seed) + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model_init_kwargs["num_labels"] = 1 # the only output of the model is the reward score + with suppress_seqcls_warning(): + model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + # Validate that the model has num_labels = 1 (required for reward models) + if getattr(model.config, "num_labels", None) != 1: + raise ValueError( + f"The model has `num_labels={model.config.num_labels}`, but reward models require `num_labels=1` " + "to output a single scalar reward per sequence. Please instantiate your model with `num_labels=1` " + "or pass a model name as a string to have it configured automatically." + ) + + # Processing class + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(get_config_model_id(model.config)) + + # Handle pad token for processors or tokenizers + if args.eos_token is not None: + eos_token = args.eos_token + eos_token_id = processing_class.convert_tokens_to_ids(eos_token) + if eos_token_id is None: + raise ValueError( + f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " + "in the vocabulary before using it as an EOS token." + ) + processing_class.eos_token_id = eos_token_id + + if args.chat_template_path is not None: + if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): + with open(args.chat_template_path, encoding="utf-8") as chat_template_file: + processing_class.chat_template = chat_template_file.read() + added_tokens = [] + else: + model, processing_class, added_tokens = clone_chat_template( + model, processing_class, args.chat_template_path + ) + else: + added_tokens = [] + + # PEFT configuration and model wrapping + if peft_config is not None: + if added_tokens: + # Ensure that the added tokens are trainable + if peft_config.trainable_token_indices is None: + peft_config.trainable_token_indices = {"embed_tokens": added_tokens} + elif "embed_tokens" not in peft_config.trainable_token_indices: + peft_config.trainable_token_indices["embed_tokens"] = added_tokens + else: + peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) + + # Ensure that the lm_head is trainable + if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: + logger.warning( + "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " + "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " + "tokens, leading to degraded generation quality. To fix this, add " + "`modules_to_save=['lm_head']` to your PEFT configuration." + ) + + if peft_config.modules_to_save is None: + peft_config.modules_to_save = ["lm_head"] + else: + peft_config.modules_to_save.append("lm_head") + + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + # Pad token (needed for SequenceClassification models) + # If not provided, use the one from the processing class or the eos token if the processing class does not have + # a pad token. + pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token + pad_token_id = processing_class.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + model.config.pad_token_id = pad_token_id + processing_class.pad_token_id = pad_token_id + + # Data collator + if data_collator is None: + data_collator = DataCollatorForPreference( + pad_token_id=pad_token_id, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + + # Dataset + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty. + self.can_return_loss = True + self.label_names = [] + + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + def _prepare_dataset( + self, + dataset: Dataset | IterableDataset, + processing_class: PreTrainedTokenizerBase, + args: RewardConfig, + dataset_name: str, + ) -> Dataset | IterableDataset: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) + + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = get_dataset_column_names(dataset) + is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + if not is_processed: + # Add EOS token to the end of the sequences if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if not example["chosen"].endswith(eos_token): + example["chosen"] = example["chosen"] + eos_token + if "rejected" in example and not example["rejected"].endswith(eos_token): + example["rejected"] = example["rejected"] + eos_token + return example + + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + **map_kwargs, + ) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize_fn(example, processing_class): + if "prompt" in example: # explicit prompt case + example["chosen"] = example["prompt"] + example["chosen"] + example["rejected"] = example["prompt"] + example["rejected"] + + if is_conversational(example): + chosen_input_ids = processing_class.apply_chat_template( + example["chosen"], + tools=example.get("tools"), + return_dict=True, + **example.get("chat_template_kwargs", {}), + )["input_ids"] + rejected_input_ids = processing_class.apply_chat_template( + example["rejected"], + tools=example.get("tools"), + return_dict=True, + **example.get("chat_template_kwargs", {}), + )["input_ids"] + output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids} + else: + output = { + "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"], + "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"], + } + return output + + dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + + # Filter samples that are longer than `max_length` + if args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens" + dataset = dataset.filter( + lambda example: len(example["chosen_input_ids"]) <= args.max_length + and len(example["rejected_input_ids"]) <= args.max_length, + **map_kwargs, + ) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). + if self._signature_columns is None: + self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"] + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + mode = "train" if self.model.training else "eval" + + # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing + inputs["use_cache"] = False + outputs = model(**inputs) + + # Split the rewards into chosen and rejected + rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2) + + # Calculate loss, optionally modulate with margin + if "margin" in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + + if self.args.center_rewards_coefficient is not None: + loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) + + if mode == "train": + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # Compute min, mean, max, accuracy and margin + with torch.no_grad(): + all_rewards = self.accelerator.gather(outputs.logits) + self._metrics[mode]["min_reward"].append(all_rewards.min().item()) + self._metrics[mode]["mean_reward"].append(all_rewards.mean().item()) + self._metrics[mode]["max_reward"].append(all_rewards.max().item()) + + mean_accuracy = (rewards_chosen > rewards_rejected).float().mean() + mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item() + self._metrics[mode]["accuracy"].append(mean_accuracy) + + mean_margin = (rewards_chosen - rewards_rejected).mean() + mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean() + self._metrics[mode]["margin"].append(mean_margin.item()) + + return (loss, outputs) if return_outputs else loss + + # Override training step to add activation offloading context. + def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/trainer/rloo_config.py b/ICL/RL/trl_source/trl/trainer/rloo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8e55b7ff4bbc70b573b99d74d63f5a92a0f03274 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/rloo_config.py @@ -0,0 +1,613 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from transformers import TrainingArguments + + +@dataclass +class RLOOConfig(TrainingArguments): + r""" + Configuration class for the [`RLOOTrainer`]. + + This class includes only the parameters that are specific to RLOO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`str`, `dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`RLOOTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents + the model from generating different logprobs for the same input. + + > Parameters that control the data preprocessing + + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that + requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. + num_generations (`int`, *optional*, defaults to `2`): + Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size + * gradient_accumulation_steps) must be evenly divisible by this value. + num_generations_eval (`int` or `None`, *optional*): + Number of generations to sample during evaluation. This allows using fewer generations during evaluation to + save computation. If `None`, uses the value of `num_generations`. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum length of the generated completion. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible + with vLLM generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. + + > Parameters that control generation + + generation_batch_size: (`int`, *optional*): + Batch size to use for generation. If `None`, it defaults to the effective training batch size: + `per_device_train_batch_size * num_processes * steps_per_generation`. In other words, there is one + generation batch processed per optimization step. Mutually exclusive with `steps_per_generation`. + steps_per_generation: (`int`, *optional*): + Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`. Mutually exclusive + with `generation_batch_size`. + temperature (`float`, defaults to `1.0`): + Temperature for sampling. The higher the temperature, the more random the completions. + top_p (`float`, *optional*, defaults to `1.0`): + Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to + `1.0` to consider all tokens. + top_k (`int`, *optional*, defaults to `0`): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, top-k-filtering is + disabled and all tokens are considered. + min_p (`float`, *optional*): + Minimum token probability, which will be scaled by the probability of the most likely token. It must be a + value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + generation_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to [`~transformers.GenerationConfig`] (if using transformers) or + `SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the + generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that conflict + with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. + chat_template_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the `apply_chat_template` function when generating completions. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. + Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat + tokens. + use_transformers_paged (`bool`, *optional*, defaults to `False`): + Whether to use the `transformers` paged implementation for generation. If set to `True`, the `transformers` + paged implementation will be used for generation instead of the default padded implementation. This + parameter is only effective when `use_vllm` is set to `False`. + cache_implementation (`str`, *optional*): + Implementation of the cache method for faster generation when `use_vllm` is set to `False`. + + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation + instead of the default model.generate(). Requires `vllm` to be installed. + vllm_mode (`str`, *optional*, defaults to `"server"`): + Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + vllm_model_impl (`str`, *optional*, defaults to `"vllm"`): + Model implementation to use for vLLM. Must be one of `"transformers"` or `"vllm"`. `"transformers"`: Use + the `transformers` backend for model implementation. `"vllm"`: Use the `vllm` library for model + implementation. + vllm_structured_outputs_regex (`str`, *optional*): + Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled. + + > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + + vllm_server_base_url (`str`, *optional*): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + vllm_group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. This is used to communicate with the vLLM server. Unless the port + is occupied, there is no need to change it. + + > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. + vllm_max_model_length (`int`, *optional*): + Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus + `max_completion_length`; if omitted, it is inferred from the model config. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`): + Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory usage low, but + waking the engine adds hostโ€“device transfer latency. + + > Parameters that control the training + + beta (`float`, *optional*, defaults to `0.05`): + KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training + speed. + num_iterations (`int`, *optional*, defaults to `1`): + Number of iterations per batch (denoted as ฮผ in the algorithm). + epsilon (`float`, *optional*, defaults to `0.2`): + Epsilon value for clipping. + epsilon_high (`float`, *optional*): + Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound + specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. + reward_weights (`list[float]`, *optional*): + Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are + weighted equally with weight `1.0`. + normalize_advantages (`bool`, *optional*, defaults to `False`): + Whether to normalize advantages. Normalization is done per generation batch to have mean `0.0` and standard + deviation of `1.0`. + reward_clip_range (`tuple[float, float]`, *optional*): + Clip range for rewards as (min, max). If `None`, no clipping is applied. + mask_truncated_completions (`bool`, *optional*, defaults to `False`): + When enabled, truncated completions are excluded from the loss calculation, preventing them from being + incorrectly penalized and introducing noise during training. According to the + [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. + ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): + ฮฑ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `ฯ€_ref = ฮฑ * ฯ€_ฮธ + (1 - ฮฑ) * ฯ€_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `512`): + ฯ„ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + + > Parameters that control the logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, + it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or + `trackio`. + num_completions_to_print (`int`, *optional*): + Number of completions to print with `rich`. If `None`, all completions are logged. + log_unique_prompts (`bool`, *optional*, defaults to `False`): + Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all prompts are + logged. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + # Parameters that control the model and reference model + model_init_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `RLOOTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=False, + metadata={ + "help": "Whether to disable dropout in the model. This is useful for training with a reference model, as " + "it prevents the model from generating different logprobs for the same input." + }, + ) + + # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because in RLOO we usually rely on + # additional columns to compute the reward + remove_unused_columns: bool | None = field( + default=False, + metadata={ + "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " + "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." + }, + ) + num_generations: int | None = field( + default=2, + metadata={ + "help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size " + "* gradient_accumulation_steps) must be evenly divisible by this value." + }, + ) + num_generations_eval: int | None = field( + default=None, + metadata={ + "help": "Number of generations to sample during evaluation. This allows using fewer generations during " + "evaluation to save computation. If `None`, uses the value of `num_generations`." + }, + ) + max_completion_length: int | None = field( + default=256, + metadata={"help": "Maximum length of the generated completion."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + shuffle_dataset: bool | None = field( + default=True, + metadata={"help": "Whether to shuffle the training dataset."}, + ) + + # Parameters that control generation + generation_batch_size: int | None = field( + default=None, + metadata={ + "help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: " + "`per_device_train_batch_size * num_processes * steps_per_generation`." + }, + ) + steps_per_generation: int | None = field( + default=None, + metadata={"help": "Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`."}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: int = field( + default=0, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + min_p: float | None = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + generation_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or " + "`SamplingParams` (if using vLLM) when sampling completions. This can be used to further customize the " + "generation behavior, such as setting `suppress_tokens`, `num_beams`, etc. If it contains keys that " + "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them." + }, + ) + chat_template_kwargs: dict | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating " + "completions." + }, + ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) + use_transformers_paged: bool = field( + default=False, + metadata={ + "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the " + "`transformers` paged implementation will be used for generation instead of the default padded " + "implementation. This parameter is only effective when `use_vllm` is set to `False`." + }, + ) + cache_implementation: str | None = field( + default=None, + metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, + ) + + # Parameters that control generation acceleration powered by vLLM + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for " + "generation instead of the default model.generate(). Requires `vllm` to be installed." + }, + ) + vllm_mode: str = field( + default="server", + metadata={ + "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `'server'` or " + "`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure " + "a TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same " + "process and share the training GPUs. This avoids the need for a separate server but may cause resource " + "contention with training." + }, + ) + vllm_model_impl: str = field( + default="vllm", + metadata={ + "help": "Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: " + "Use the `transformers` backend for model implementation. `vllm`: Use the `vllm` library for " + "model implementation." + }, + ) + vllm_enable_sleep_mode: bool = field( + default=False, + metadata={ + "help": "Enable vLLM sleep mode to offload weights/cache during the optimizer step. Keeps GPU memory " + "usage low, but waking the engine adds hostโ€“device transfer latency." + }, + ) + vllm_structured_outputs_regex: str | None = field( + default=None, + metadata={"help": "Regex for vLLM structured outputs. If `None` (default), structured outputs is disabled."}, + ) + + # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + vllm_server_base_url: str | None = field( + default=None, + metadata={ + "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " + "and `vllm_server_port` are ignored." + }, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_port: int = field( + default=8000, + metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={ + "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " + "after the timeout, a `ConnectionError` is raised." + }, + ) + vllm_group_port: int = field( + default=51216, + metadata={ + "help": "Port number for the weight update group. This is used to communicate with the vLLM server. " + "Unless the port is occupied, there is no need to change it.", + }, + ) + + # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + vllm_gpu_memory_utilization: float = field( + default=0.3, + metadata={ + "help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag." + }, + ) + vllm_max_model_length: int | None = field( + default=None, + metadata={ + "help": "Context window for vLLM. Set it to at least the maximum prompt length in the dataset plus " + "`max_completion_length`; if omitted, it is inferred from the model config." + }, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={ + "help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_tensor_parallel_size` flag." + }, + ) + + # Parameters that control the training + beta: float = field( + default=0.05, + metadata={ + "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving " + "training speed." + }, + ) + num_iterations: int = field( + default=1, + metadata={"help": "Number of iterations per batch (denoted as ฮผ in the algorithm)."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Epsilon value for clipping."}, + ) + epsilon_high: float | None = field( + default=None, + metadata={ + "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " + "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." + }, + ) + reward_weights: list[float] | None = field( + default=None, + metadata={ + "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " + "rewards are weighted equally with weight `1.0`." + }, + ) + normalize_advantages: bool = field( + default=False, + metadata={ + "help": "Whether to normalize advantages. Normalization is done per generation batch to have mean `0.0` " + "and standard deviation of `1.0`." + }, + ) + reward_clip_range: tuple[float, float] | None = field( + default=None, + metadata={"help": "Clip range for rewards as (min, max). If None, no clipping is applied."}, + ) + mask_truncated_completions: bool = field( + default=False, + metadata={ + "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from " + "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is " + "a good practice for training stability." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "ฮฑ parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`ฯ€_ref = ฮฑ * ฯ€_ฮธ + (1 - ฮฑ) * ฯ€_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "ฯ„ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={ + "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " + "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." + }, + ) + num_completions_to_print: int | None = field( + default=None, + metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, + ) + log_unique_prompts: bool = field( + default=False, + metadata={ + "help": "Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all " + "prompts are logged." + }, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() + + num_processes = self.world_size + # The current default effective batch size + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + self.steps_per_generation = self.generation_batch_size // ( + self.per_device_train_batch_size * num_processes + ) + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.do_eval and self.eval_strategy != "no": + # Determine the number of generations to use for evaluation + num_generations = self.num_generations_eval or self.num_generations + + # Just ensure the value is divisible by the global batch size + if (self.per_device_eval_batch_size * num_processes) % num_generations != 0: + raise ValueError( + f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be " + f"divisible by the number of generations used for evaluation ({num_generations})." + ) + + # The generation batch must contain full prompt groups (no partials), so it must be divisible by + # num_generations. + if self.generation_batch_size % self.num_generations != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations " + f"({self.num_generations})." + ) + + if self.num_generations < 2: + raise ValueError( + "RLOO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." + ) diff --git a/ICL/RL/trl_source/trl/trainer/rloo_trainer.py b/ICL/RL/trl_source/trl/trainer/rloo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8aa2d4e7c99bdb04f2f44dd721e5bdd7a8d725 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/rloo_trainer.py @@ -0,0 +1,1394 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import atexit +import copy +import inspect +import textwrap +import time +from collections import defaultdict, deque +from collections.abc import Callable +from contextlib import nullcontext +from functools import partial +from pathlib import Path +from typing import Any + +import datasets +import pandas as pd +import torch +import torch.utils.data +import transformers +from accelerate.logging import get_logger +from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from datasets import Dataset, IterableDataset +from packaging.version import Version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoModelForSequenceClassification, + AutoProcessor, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_trackio_available, + is_wandb_available, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available + +from ..data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, +) +from ..extras.profiling import profiling_context, profiling_decorator +from ..generation.vllm_generation import VLLMGeneration +from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ..models.utils import disable_gradient_checkpointing +from .base_trainer import BaseTrainer +from .callbacks import SyncRefModelCallback +from .rloo_config import RLOOConfig +from .utils import ( + RepeatSampler, + create_model_from_path, + disable_dropout_in_model, + entropy_from_logits, + get_config_model_id, + identity, + nanmax, + nanmin, + nanstd, + pad, + print_prompt_completions_sample, + selective_log_softmax, + shuffle_sequence_dict, + shutdown_event_loop_in_daemon, + split_pixel_values_by_grid, + split_tensor_dict, + start_event_loop_in_daemon, + unsplit_pixel_values_by_grid, + use_adapter, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + + +if is_wandb_available(): + import wandb + +if is_trackio_available(): + import trackio + + +logger = get_logger(__name__) + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]] + + +class RLOOTrainer(BaseTrainer): + """ + Trainer for the Reinforce Leave One Out (RLOO) method. This algorithm was initially proposed in the paper [Back to + Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in + LLMs](https://huggingface.co/papers/2402.14740). + + Example: + + ```python + from trl import RLOOTrainer + from trl.rewards import accuracy_reward + from datasets import load_dataset + + dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + + trainer = RLOOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=accuracy_reward, + train_dataset=dataset, + ) + trainer.train() + ``` + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + - A [`~peft.PeftModel`] object. Only causal language models are supported. + reward_funcs (`RewardFunc | list[RewardFunc]`): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function, such as: + - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the + keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. + - A custom reward function: The function is provided with the prompts and the generated completions, + plus any additional columns in the dataset. It should return a list of rewards. Custom reward + functions can be either synchronous or asynchronous and can also return `None` when the reward is + not applicable to those samples. This is useful for multi-task training where different reward + functions apply to different types of samples. When a reward function returns `None` for a sample, + that reward function is excluded from the reward calculation for that sample. For more details, see + [Using a custom reward + function](#using-a-custom-reward-function). + + The trainer's state is also passed to the reward function. The trainer's state is an instance of + [`~transformers.TrainerState`] and can be accessed by accessing the `trainer_state` argument to the + reward function's signature. + - A list of reward functions, where each item can independently be any of the above types. Mixing different + types within the list (e.g., a string model ID and a custom reward function) is allowed. + args ([`RLOOConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is + ignored. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. The padding side must be set to "left". If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A + padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, + `tokenizer.eos_token` will be used as the default. + reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is + `None`, the tokenizer for the model is automatically loaded using + [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward + functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes` + are ignored. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "rloo"] + _name = "RLOO" + _paper = { + "title": "Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs", + "id": "2402.14740", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{ahmadian2024back, + title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}}, + author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker}, + year = 2024, + booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024}, + pages = {12248--12267}, + publisher = {Association for Computational Linguistics}, + editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar}, + }"""), + } + + def __init__( + self, + model: "str | PreTrainedModel | PeftModel", + reward_funcs: RewardFunc | list[RewardFunc], + args: RLOOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: "PeftConfig | None" = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = RLOOConfig(f"{model_name}-RLOO") + + # Model + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Some models (SmolVLM/Idefics3) don't support `logits_to_keep` argument and error out if we pass it + # Inspect the forward method before we wrap the model with PEFT + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + if is_peft_available() and is_peft_model(model): + # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy + # of the "default" adapter, so that we can use it as the reference model during the training. + model.add_adapter("ref", model.peft_config["default"]) + for name, param in model.named_parameters(): + if ".default." in name: + ref_name = name.replace(".default.", ".ref.") + ref_param = model.get_parameter(ref_name) + ref_param.data.copy_(param.data) + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models + self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + self._has_async_reward_funcs = any(asyncio.iscoroutinefunction(func) for func in self.reward_funcs) + if self._has_async_reward_funcs: + self.async_reward_loop_thread, self.async_reward_loop, self.async_reward_loop_ready_event = ( + start_event_loop_in_daemon(name="RLOOTrainer-AsyncRewardLoop") + ) + # wait until the event loop is running in the daemon thread + self.async_reward_loop_ready_event.wait() + atexit.register(shutdown_event_loop_in_daemon, self.async_reward_loop_thread, self.async_reward_loop) + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError( + f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of " + f"reward functions ({len(reward_funcs)})." + ) + + for i, (reward_processing_class, reward_func) in enumerate( + zip(reward_processing_classes, reward_funcs, strict=True) + ): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config)) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + + self.reward_processing_classes = reward_processing_classes + + # Training arguments + self.max_completion_length = args.max_completion_length + self.num_generations = args.num_generations + self.num_generations_eval = args.num_generations_eval or self.num_generations + self.chat_template_kwargs = args.chat_template_kwargs or {} + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_transformers_paged = args.use_transformers_paged + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.normalize_advantages = args.normalize_advantages + self.mask_truncated_completions = args.mask_truncated_completions + self.reward_clip_range = args.reward_clip_range + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + # See https://github.com/huggingface/trl/issues/3213 + raise NotImplementedError( + "Iterable datasets are not yet supported in RLOOTrainer. Please use a standard dataset instead." + ) + + # Multi-step + self.num_iterations = args.num_iterations + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=identity, # No data collation is needed in RLOO + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + ) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_peft_model(model): + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. + self.ref_model = None + else: + # For deepspeed, fsdp or non-distributed models, create a reference model from scratch + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self._current_train_step_time = 0.0 + self.log_completions = args.log_completions + self.log_unique_prompts = args.log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + # Keep logs sized to the generation batch to record only outputs from the latest model update. + self._logs = { + "images": deque(maxlen=args.generation_batch_size), + "prompt": deque(maxlen=args.generation_batch_size), + "completion": deque(maxlen=args.generation_batch_size), + "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + "advantages": deque(maxlen=args.generation_batch_size), + } + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + + if self.use_vllm: + # Initialize vLLM generation backend + self.vllm_generation = VLLMGeneration( + model=self.model, + accelerator=self.accelerator, + is_fsdp_enabled=self.is_fsdp_enabled, + processing_class=self.processing_class, + # vLLM configuration + mode=args.vllm_mode, + structured_outputs_regex=args.vllm_structured_outputs_regex, + # Server mode configuration + server_base_url=args.vllm_server_base_url, + server_host=args.vllm_server_host, + server_port=args.vllm_server_port, + group_port=args.vllm_group_port, + server_timeout=args.vllm_server_timeout, + # Colocate mode configuration + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=args.vllm_gpu_memory_utilization, + max_model_length=args.vllm_max_model_length, + max_num_seqs=args.per_device_train_batch_size + * args.vllm_tensor_parallel_size + * args.steps_per_generation, + enable_sleep_mode=args.vllm_enable_sleep_mode, + model_impl=args.vllm_model_impl, + # Generation configuration + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + max_completion_length=self.max_completion_length, + generation_kwargs=args.generation_kwargs, + # Chat/tool configuration + chat_template_kwargs=self.chat_template_kwargs, + ) + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + else: + generation_kwargs = { + "max_new_tokens": self.max_completion_length, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "bos_token_id": tokenizer.bos_token_id, + "eos_token_id": tokenizer.eos_token_id, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "repetition_penalty": self.repetition_penalty, + "cache_implementation": args.cache_implementation, + } + if args.generation_kwargs is not None: + generation_kwargs.update(args.generation_kwargs) + self.generation_config = GenerationConfig(**generation_kwargs) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + if self.beta == 0.0: + raise ValueError( + "You passed `sync_ref_model=True` while `beta=0.0`, which means the reference model is not used " + "during training. Consequently, RLOOTrainer does not create a `ref_model` instance, and there is " + "nothing to synchronize. Please set `sync_ref_model=False`, or set `beta` to a non-zero value." + ) + if is_peft_model(model): + raise NotImplementedError( + "You passed `sync_ref_model=True` while using a PEFT model, which is currently not supported. " + "With PEFT, RLOOTrainer does not keep a separate reference model in memory; instead, it recovers " + "reference behavior by temporarily disabling the adapter. As a result, there is no standalone " + "`ref_model` instance to synchronize. Use `sync_ref_model=False`, or opt for full fine-tuning if " + "you need a synced reference model. If you need `sync_ref_model` to work with PEFT, please open a " + "feature request at https://github.com/huggingface/trl/issues." + ) + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't + # work. Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt", "image", "images"] + + # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. + # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an + # *generation* batch (i.e., `per_device_batch_size ร— steps_per_generation`). This allows us to generate completions + # once every steps_per_generation stepโ€”rather than once per accumulation stepโ€”which is significantly more + # efficient. The only change from the original implementation is multiplying the batch size by + # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the + # splitting internally. + # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line + # modification. As a result, some parts of the method aren't relevant to RLOO, but we keep them to stay one line + # apart from the super method, ensuring easier maintenance in the future. + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler: + # Returns a sampler that + # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are + # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt + # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies + # in group formation. + # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to + # _prepare_inputs to see how the generations are stored and reused. + + # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the + # second row shows the second sampled batch, and so on. + # + # | GPU 0 | GPU 1 | + # + # global_step step <-โ”€โ”€โ”€> num_generations=2 + # <-โ”€โ”€โ”€โ”€โ”€โ”€โ”€> per_device_train_batch_size=3 + # grad_accum โ–ฒ โ–ฒ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss + # =2 โ–ผ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss + # | + # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss + # steps_per_gen=4 โ–ผ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss + # + # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss + # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss + # ... + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + # See _get_train_sampler for an explanation of the sampler. + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations_eval, + seed=self.args.seed, + ) + + @profiling_decorator + def _get_per_token_logps_and_entropies( + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, + num_images=None, + pixel_attention_mask=None, + image_sizes=None, + token_type_ids=None, + ) -> dict[str, torch.Tensor | None]: + """Compute log-probs and (optionally) entropies for each token.""" + batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + all_logps = [] + all_entropies = [] + for start in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[start : start + batch_size] + attention_mask_batch = attention_mask[start : start + batch_size] + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes[start : start + batch_size] + if token_type_ids is not None: + model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + logits = model(**model_inputs).logits + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + completion_ids = input_ids_batch[:, -logits_to_keep:] + logps = selective_log_softmax(logits, completion_ids) # compute logprobs + all_logps.append(logps) + + if compute_entropy: + with torch.no_grad(): + entropies = entropy_from_logits(logits) + all_entropies.append(entropies) + + logps = torch.cat(all_logps, dim=0) + entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None + return logps, entropies + + def training_step(self, model, inputs, num_items_in_batch): + time_before = time.perf_counter() + output = super().training_step(model, inputs, num_items_in_batch) + self._step += 1 + time_after = time.perf_counter() + self._current_train_step_time += time_after - time_before + if self._step % self.current_gradient_accumulation_steps == 0: + self._metrics["train"]["step_time"].append(self._current_train_step_time) + self._current_train_step_time = 0.0 + return output + + @profiling_decorator + def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]: + # Prepares inputs for model training/evaluation by managing completion generation and batch handling. + # During training: + # - Receives the local generation batch (Per-GPU batch size ร— steps per generation) + # from the modified training dataloader instead of the standard local batch + # - Generates completions once for the entire generation batch and splits it into batches of size + # `per_device_train_batch_size` + # - Buffers these completions and returns the appropriate slice for the current accumulation step + # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # During evaluation: + # - The input is treated as a standard local batch (no accumulation, no multiple iterations) + # - Completions are generated for each batch without buffering or reuse + # Returns a single local batch in both cases. + + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + # self._buffered_inputs=None can occur when resuming from a checkpoint + generation_batch = self._generate_and_score_completions(generation_batch) + generation_batch = split_pixel_values_by_grid(generation_batch) + generation_batch = shuffle_sequence_dict(generation_batch) + generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation) + self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches] + inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] + else: + # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence + # local generation batch == local eval batch + inputs = self._generate_and_score_completions(generation_batch) + return inputs + + @profiling_decorator + def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): + device = self.accelerator.device + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + + # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + + # This allows for dynamic reward shaping based on training progress. + reward_kwargs["trainer_state"] = self.state + + async_funcs_info = [] # async custom functions for asyncio.gather + + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True) + ): + if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models + with profiling_context(self, reward_func_name): + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)] + texts = [ + apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions, strict=True)] + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + elif asyncio.iscoroutinefunction(reward_func): # Separate async reward funcs to run them in parallel later + async_funcs_info.append((i, reward_func, reward_func_name)) + else: + # Run synchronous reward function + with profiling_context(self, reward_func_name): + output_reward_func = reward_func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # Execute async custom functions in parallel using asyncio.gather + if async_funcs_info: + + async def _invoke_async_reward(index, func, func_name): + with profiling_context(self, func_name): + output = await func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + output = [r if r is not None else torch.nan for r in output] + return index, output + + async def _run_async_funcs(): + coros = [_invoke_async_reward(i, func, func_name) for (i, func, func_name) in async_funcs_info] + return await asyncio.gather(*coros) + + async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_reward_loop).result() + for idx, output_reward_func in async_results: + rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = { + key: value[nan_row_idx] for key, value in reward_kwargs.items() if key != "trainer_state" + } + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + logger.warning( + f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n" + "Please ensure that at least one reward function returns a valid reward." + ) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + return rewards_per_func + + def _generate_single_turn(self, prompts: list): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Generate completions using either vLLM or regular generation + if self.use_vllm: + # Sync weights if training step changed + if self.state.global_step != self._last_loaded_step: + with profiling_context(self, "sync_weights"): + self.vllm_generation.sync_weights() + self._last_loaded_step = self.state.global_step + + # Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them) + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + prompt_ids, completion_ids, _, _ = self.vllm_generation.generate( + prompts=prompts, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate") + ) + + elif self.use_transformers_paged: + if is_conversational({"prompt": prompts[0]}): + processor_outputs = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + **self.chat_template_kwargs, + ) + else: + processor_outputs = self.processing_class(text=prompts) + + with ( + profiling_context(self, "transformers.generate_batch"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Cast to the appropriate dtype based on training configuration + if self.args.bf16: + unwrapped_model.to(torch.bfloat16) + elif self.args.fp16: + unwrapped_model.to(torch.float16) + with torch.inference_mode(): + # Continuous batching API expects 'inputs' arg only + all_outputs = unwrapped_model.generate_batch( + processor_outputs["input_ids"], generation_config=self.generation_config, progress_bar=False + ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode + completion_ids = [output.generated_tokens for output in all_outputs.values()] + prompt_ids = processor_outputs["input_ids"] + + else: + # Regular generation path + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + padding=True, + padding_side="left", + return_tensors="pt", + return_dict=True, + **self.chat_template_kwargs, + ) + else: + generate_inputs = self.processing_class( + text=prompts, padding=True, padding_side="left", return_tensors="pt" + ) + generate_inputs = super()._prepare_inputs(generate_inputs) + + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + **generate_inputs, generation_config=self.generation_config, disable_compile=True + ) + # Compute prompt length and extract completion ids + prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] + prompt_length = prompt_ids.size(1) + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)] + + return prompt_ids, completion_ids + + def _generate(self, prompts: list): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Copy the prompts to avoid modifying the original list + prompts = copy.deepcopy(prompts) + + prompt_ids, completion_ids = self._generate_single_turn(prompts) + + # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + if is_conversational({"prompt": prompts[0]}): + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in contents] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss + + # Log the metrics + if mode == "train": + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + return prompt_ids, completion_ids, completions + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [ + prepare_multimodal_messages(prompt, image_list) + for prompt, image_list in zip(prompts, images, strict=True) + ] + + prompt_ids_list, completion_ids_list, completions = self._generate(prompts) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] + for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a + # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). + # Temporarily disable checkpointing to avoid this warning during inference. + with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + # Compute the per-token log probabilities for the current model + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + # When training a PEFT adapter, how we obtain the reference depends on the setup: + # - New adapter: disabling adapters yields the base model. + # - Re-training an existing adapter: an initial copy is loaded under the name "ref". + model = self.accelerator.unwrap_model(self.model) + with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + num_generations = self.num_generations if mode == "train" else self.num_generations_eval + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Apply reward clipping if specified + if self.reward_clip_range: + rewards = rewards.clamp(min=self.reward_clip_range[0], max=self.reward_clip_range[1]) + + # Include the KL penalty in the reward + if self.beta != 0.0: + per_token_kl = old_per_token_logps - ref_per_token_logps + # Apply sequence-level KL penalty to rewards (sum KL across tokens first, then apply to each sequence) + kl = (per_token_kl * completion_mask).sum(-1) + kl = gather(kl) # rewards are gathered, so kl must be too + rewards = rewards - self.beta * kl + + grouped_rewards = rewards.view(-1, num_generations) + mean_grouped_rewards = grouped_rewards.mean(dim=1) + if num_generations > 1: + std_rewards = grouped_rewards.std(dim=1) + else: # doesn't occur during training, but could occur in eval when num_generations_eval=1 + std_rewards = torch.zeros_like(mean_grouped_rewards) + + # RLOO advantages computation + grouped_sum = grouped_rewards.sum(dim=1, keepdim=True) # (num_prompts, 1) + if num_generations > 1: + baselines = (grouped_sum - grouped_rewards) / (num_generations - 1) # (num_prompts, num_generations) + baselines = baselines.view(-1) # Flatten back to match rewards shape + advantages = rewards - baselines + else: # this case doesn't occur during training, but could in eval when num_generations_eval=1 + advantages = torch.zeros_like(rewards) + + # Normalize advantages + if self.normalize_advantages: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-4) + + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + + # Calculate and log the mean KL divergence between current and reference model + if self.beta != 0.0: + mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + rewards = rewards_per_func.nansum(dim=1) + self._metrics[mode]["reward"].append(rewards.mean().item()) + self._metrics[mode]["reward_std"].append(rewards.std().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_process_advantages.tolist()) + + if images is not None: + self._logs["images"].extend(gather_object(images)) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "old_logps": old_logps, + "advantages": advantages, + } + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + return output + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The RLOOTrainer does not support returning outputs") + return self._compute_loss(model, inputs) + + def _compute_loss(self, model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Compute the per_token_logps and the entropy at each position in the completion + per_token_logps, entropies = self._get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=True, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), + pixel_attention_mask=inputs.get("pixel_attention_mask"), + image_sizes=inputs.get("image_sizes"), + token_type_ids=inputs.get("token_type_ids"), + ) + + logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS + old_logps = inputs["old_logps"] + log_ratio = logps - old_logps + + # Compute the loss + advantages = inputs["advantages"] + coef_1 = torch.exp(log_ratio) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + per_sequence_loss1 = coef_1 * advantages + per_sequence_loss2 = coef_2 * advantages + per_sequence_loss = -torch.min(per_sequence_loss1, per_sequence_loss2) + loss = per_sequence_loss.mean() + + # Log the metrics + mode = "train" if self.model.training else "eval" + + # Entropy + mean_entropy = (entropies * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) + is_region_clipped = is_low_clipped | is_high_clipped + gathered_low_clip = self.accelerator.gather(is_low_clipped.float().mean()) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(is_high_clipped.float().mean()) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(is_region_clipped.float().mean()) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + return loss + + # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and + # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels. + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + loss = loss.mean().detach() + return loss, None, None + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + if is_rich_available(): + print_prompt_completions_sample( + self._logs["prompt"], + self._logs["completion"], + self._logs["rewards"], + self._logs["advantages"], + self.state.global_step, + self.num_completions_to_print, + ) + + logging_backends = [] + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + logging_backends.append(wandb) + if self.args.report_to and "trackio" in self.args.report_to: + logging_backends.append(trackio) + + table = { + "step": [str(self.state.global_step)] * len(self._logs["prompt"]), + "prompt": self._logs["prompt"], + "completion": self._logs["completion"], + **self._logs["rewards"], + "advantage": self._logs["advantages"], + } + + df_base = pd.DataFrame(table) + images_raw = self._logs["images"] or [] + + for logging_backend in logging_backends: + if images_raw: + images = [] + for image_list in self._logs["images"]: + images.append([logging_backend.Image(image) for image in image_list]) + df = pd.concat( + [df_base, pd.Series(images, name="image")], + axis=1, + copy=False, + ) + else: + df = df_base + + if self.log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + + logging_backend.log({"completions": logging_backend.Table(dataframe=df)}) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/trainer/sft_config.py b/ICL/RL/trl_source/trl/trainer/sft_config.py new file mode 100644 index 0000000000000000000000000000000000000000..a0cdce445228daae1622d1a018c7aa81cd8c631d --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/sft_config.py @@ -0,0 +1,282 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class SFTConfig(TrainingArguments): + r""" + Configuration class for the [`SFTTrainer`]. + + This class includes only the parameters that are specific to SFT training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SFTTrainer`] is provided as a string. If you're training a MoE architecture and want to + include the load balancing/auxiliary loss as a part of the final loss, remember to set + `output_router_logits=True` in this dictionary. + chat_template_path (`str`, *optional*): + If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory + or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must + ensure that any special tokens referenced in the template are added to the tokenizer and that the model's + embedding layer is resized accordingly. + + > Parameters that control the data preprocessing + + dataset_text_field (`str`, *optional*, defaults to `"text"`): + Name of the column that contains text data in the dataset. + dataset_kwargs (`dict[str, Any]`, *optional*): + Dictionary of optional keyword arguments for the dataset preparation. The only supported key is + `skip_prepare_dataset`. When the model is a VLM, `skip_prepare_dataset` is automatically treated as `True` + regardless of the provided value, since preprocessing is done on the fly. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + eos_token (`str`, *optional*): + Token used to indicate the end of a turn or sequence. If `None`, it defaults to + `processing_class.eos_token`. + pad_token (`str`, *optional*): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + shuffle_dataset (`bool`, *optional*, defaults to `False`): + Whether to shuffle the dataset. + packing (`bool`, *optional*, defaults to `False`): + Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce + padding. Uses `max_length` to define sequence length. + packing_strategy (`str`, *optional*, defaults to `"bfd"`): + Strategy for packing sequences. Can be `"bfd"` (best-fit decreasing, truncates overflow), `"bfd-requeue"` + (best-fit decreasing, re-queues overflow tokens), or `"wrapped"` (aggressive, cuts mid-sequence). + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch structure. When + packing is enabled with strategy `"bfd"`, padding-free is enabled, regardless of the value of this + parameter. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + eval_packing (`bool`, *optional*): + Whether to pack the eval dataset. If `None`, uses the same value as `packing`. + + > Parameters that control the training + + completion_only_loss (`bool`, *optional*): + Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed + only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If + `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: + loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on the full + sequence for [language modeling](#language-modeling) datasets. + assistant_only_loss (`bool`, *optional*, defaults to `False`): + Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is computed only + on the assistant responses, which is supported only for [conversational](#conversational) datasets. If + `False`, loss is computed on the entire sequence. + loss_type (`str`, *optional*, defaults to `"nll"`): + Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic + Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)). + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=2e-5, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + # Parameters that control the model + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `SFTTrainer` is provided as a string. If you're training a MoE architecture and want to include the " + "load balancing/auxiliary loss as a part of the final loss, remember to set `output_router_logits=True` " + "in this dictionary." + }, + ) + chat_template_path: str | None = field( + default=None, + metadata={ + "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local " + "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, " + "you must ensure that any special tokens referenced in the template are added to the tokenizer and " + "that the model's embedding layer is resized accordingly." + }, + ) + + # Parameters that control the data preprocessing + dataset_text_field: str = field( + default="text", + metadata={"help": "Name of the column that contains text data in the dataset."}, + ) + dataset_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " + "`skip_prepare_dataset`. If the model is a VLM, `skip_prepare_dataset` value is ignored. When the model " + "is a VLM, `skip_prepare_dataset` is automatically treated as `True` regardless of the provided value, " + "since preprocessing is done on the fly." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + eos_token: str | None = field( + default=None, + metadata={ + "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`." + }, + ) + pad_token: str | None = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: int | None = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from " + "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "sequence length." + }, + ) + shuffle_dataset: bool = field( + default=False, + metadata={"help": "Whether to shuffle the dataset."}, + ) + packing: bool = field( + default=False, + metadata={ + "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency " + "and reduce padding. Uses `max_length` to define sequence length." + }, + ) + packing_strategy: str = field( + default="bfd", + metadata={ + "help": "Strategy for packing sequences. Can be `'bfd'` (best-fit decreasing, truncates overflow), " + "`'bfd-requeue'` (best-fit decreasing, re-queues overflow tokens), or `'wrapped'` (aggressive, cuts " + "mid-sequence).", + "choices": ["bfd", "bfd-requeue", "wrapped"], + }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this " + "is only supported with the FlashAttention 2 or 3, which can efficiently handle the flattened batch " + "structure. When packing is enabled with strategy `'bfd'`, padding-free is enabled, regardless of the " + "value of this parameter." + }, + ) + pad_to_multiple_of: int | None = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + eval_packing: bool | None = field( + default=None, + metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + ) + + # Parameters that control the training + completion_only_loss: bool | None = field( + default=None, + metadata={ + "help": ( + "Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is " + "computed only on the completion, which is supported only for prompt-completion datasets. If `False`, " + "loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: " + "loss is computed on the completion for prompt-completion datasets, and on the full sequence for " + "language modeling datasets." + ) + }, + ) + assistant_only_loss: bool = field( + default=False, + metadata={ + "help": ( + "Whether to compute loss only on the assistant part of the sequence. If set to `True`, loss is " + "computed only on the assistant responses, which is supported only for conversational datasets. If `False`, " + "loss is computed on the entire sequence." + ) + }, + ) + loss_type: str = field( + default="nll", + metadata={ + "help": ( + 'Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` ' + "(Dynamic Fine-Tuning, as described in https://huggingface.co/papers/2508.05629)." + ) + }, + ) + activation_offloading: bool = field( + default=False, + metadata={"help": "Whether to offload the activations to the CPU."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/trainer/sft_trainer.py b/ICL/RL/trl_source/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc53510dc10431ecfb217cf47b4487d82d87ea73 --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/sft_trainer.py @@ -0,0 +1,1335 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +import warnings +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +import transformers +from accelerate import PartialState +from accelerate.logging import get_logger +from accelerate.utils import is_peft_model +from datasets import Dataset, IterableDataset +from packaging.version import Version +from transformers import ( + AutoProcessor, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainingArguments, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ..chat_template_utils import clone_chat_template +from ..data_utils import ( + apply_chat_template, + is_conversational, + is_conversational_from_value, + maybe_convert_to_chatml, + pack_dataset, + prepare_multimodal_messages, + truncate_dataset, +) +from ..models import get_act_offloading_ctx_manager +from .base_trainer import BaseTrainer +from .sft_config import SFTConfig +from .utils import ( + create_model_from_path, + entropy_from_logits, + flush_left, + get_config_model_id, + pad, + remove_none_values, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, PeftType, get_peft_model + + +logger = get_logger(__name__) + + +FLASH_ATTENTION_VARIANTS = { + "flash_attention_2", + "flash_attention_3", + "kernels-community/flash-attn2", + "kernels-community/flash-attn3", + "kernels-community/vllm-flash-attn3", +} + + +def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]: + return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names + + +@dataclass +class DataCollatorForLanguageModeling(DataCollatorMixin): + """ + Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. + + This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key. + If the input contains a `"completion_mask"`, it is used to set the labels to `-100` for tokens that are not in the + completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not + in the assistant part of the sequence. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. + - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to + `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are + not in the assistant part of the sequence are set to -100. If `padding_free` is set to `False`, the following key + is also returned: + - `"attention_mask"`: Tensor of attention masks, padded to the maximum length of the batch. + If `padding_free` is set to `True`, the following key is also returned: + - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + completion_only_loss (`bool`, *optional*, defaults to `True`): + When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens + that are no in the completion. + padding_free (`bool`, *optional*, defaults to `False`): + If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be + generated accordingly and returned instead of the attention mask. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl.trainer.sft_trainer import DataCollatorForLanguageModeling + + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0) + >>> examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'labels': tensor([[ 1, 2, 3], + [ 4, 5, -100]])} + + >>> # With completion mask + >>> examples = [ + ... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + ... {"input_ids": [4, 5], "completion_mask": [0, 1]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'labels': tensor([[-100, 2, 3], + [-100, 5, -100]])} + + >>> # With padding_free + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), + 'position_ids': tensor([[0, 1, 2, 0, 1]]), + 'labels': tensor([[1, 2, 3, 4, 5]])} + ``` + """ + + pad_token_id: int + completion_only_loss: bool = True + padding_free: bool = False + pad_to_multiple_of: int | None = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + # Convert to tensor + input_ids = [torch.tensor(example["input_ids"]) for example in examples] + if "labels" in examples[0]: + labels = [torch.tensor(example["labels"]) for example in examples] + else: + labels = [torch.tensor(example["input_ids"]) for example in examples] + + # For padding-free, we should NOT create attention_mask as it causes FlashAttention to ignore position_ids and + # compute wrong cu_seq_lens from the all-1s mask + if self.padding_free: + if "seq_lengths" in examples[0]: + position_ids = self.get_position_ids_from_packed_seq_lengths( + [example["seq_lengths"] for example in examples] + ) + else: + position_ids = [torch.arange(len(ids)) for ids in input_ids] + else: + attention_mask = [torch.ones_like(ids) for ids in input_ids] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.tensor(example["completion_mask"]) for example in examples] + if "assistant_masks" in examples[0]: + assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples] + + # If padding_free, flatten everything into a single sequence + output = {} + if self.padding_free: + input_ids = [torch.cat(input_ids, dim=0)] + labels = [torch.cat(labels, dim=0)] + position_ids = [torch.cat(position_ids, dim=0)] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.cat(completion_mask, dim=0)] + if "assistant_masks" in examples[0]: + assistant_masks = [torch.cat(assistant_masks, dim=0)] + + # Pad + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["labels"] = pad( + labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.padding_free: + output["position_ids"] = pad( + position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"][output["position_ids"] == 0] = -100 + else: + output["attention_mask"] = pad( + attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = pad( + completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion + if "assistant_masks" in examples[0]: + assistant_masks = pad( + assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"][assistant_masks == 0] = -100 + return output + + @staticmethod + def get_position_ids_from_packed_seq_lengths(batch_seq_lengths: list[list[int]]) -> list[torch.Tensor]: + """ + Get position IDs for packed sequences. + + Args: + batch_seq_lengths (`list[list[int]]`): + A list of lists containing the lengths of each individual document in the packed batch. + + Return: + `list[torch.Tensor]`: + A list of tensors containing the position IDs for each packed sequence. + """ + # Get lengths per row + example_lengths = [sum(seq_lengths) for seq_lengths in batch_seq_lengths] + # Flat list of lengths + batch_seq_lengths = torch.tensor( + [seq_length for seq_lengths in batch_seq_lengths for seq_length in seq_lengths] + ) + position_ids = torch.ones(sum(example_lengths), dtype=batch_seq_lengths.dtype) + position_ids[0] = 0 + # Reset position ids to 0 at the start of each sequence + position_ids[batch_seq_lengths[:-1].cumsum(0)] = -(batch_seq_lengths[:-1] - 1) + position_ids = position_ids.cumsum(0) + # Split back into one tensor per example + return list(position_ids.split(example_lengths)) + + +@dataclass +class DataCollatorForVisionLanguageModeling(DataCollatorMixin): + """ + Data collator for vision-language modeling tasks. + + Unlike text-only datasetsโ€”where the collator typically receives pre-tokenized inputs ready for batching, + vision-language data processing involves converting images into pixel values. This conversion is disk-intensive, + making upfront preprocessing of the entire dataset impractical. Therefore, this collator performs tokenization and + image processing on-the-fly to efficiently prepare batches. + + Each input example should be a dictionary containing at least: + - An `"images"` key holding the image data. + - [language modeling](#language-modeling) type: either a `"messages"` key for conversational inputs or a `"text"` + key for standard text inputs. + - [prompt-completion](#prompt-completion) type: keys `"prompt"` and `"completion"` for the prompt and completion. + + The collator outputs a dictionary including: + - `"input_ids"`: Tensor of token IDs. + - `"attention_mask"`: Tensor indicating attention mask. + - `"pixel_values"`: Tensor representing image pixel values. + - `"labels"`: Tensor for training labels. + + Additional keys may be present depending on the processor, such as `"image_grid_thw"`. + + Args: + processor ([`~transformers.ProcessorMixin`]): + The processor used to tokenize text and process images. It must be a subclass of + [`~transformers.ProcessorMixin`] and include a `tokenizer` with a defined `pad_token_id`. + max_length (`int` or `None`, optional, defaults to `None`): + Maximum sequence length for input tokens. If `None`, no truncation is applied. + completion_only_loss (`bool`, *optional*, defaults to `False`): + Whether to compute loss only on the completion part of the sequence. When `True`, the labels for the prompt + part are set to -100. It requires the dataset type to be prompt-completion. + pad_to_multiple_of (`int` or `None`, optional, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + dataset_text_field (`str`, optional, defaults to `"text"`): + Name of the column that contains text data in the dataset. This parameter is only relevant for [standard + datasets format](dataset_formats#standard). + return_tensors (`str`, optional, defaults to `"pt"`): + The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. + + Example: + ```python + >>> from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling + >>> from transformers import AutoProcessor + + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> collator = DataCollatorForVisionLanguageModeling(processor) + >>> examples = [ + ... {"images": [Image.open("image_0.png")], "messages": [{"role": "user", "content": "What is this?"}]}, + ... {"images": [Image.open("image_1.png")], "messages": [{"role": "user", "content": "Describe this image."}]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), + 'pixel_values': tensor([[-0.9893, 0.1785, 1.5362, ..., -0.0582, 0.8661, -0.2431], + [-0.2302, 0.9522, -1.1061, ..., 0.0555, 1.3354, -0.6412], + [ 1.2150, 0.9084, 0.7041, ..., 0.2404, -0.8403, -0.5133], + ..., + [ 0.6895, 0.2807, 0.2515, ..., -0.2004, -1.2100, 0.0555], + [ 0.8209, -0.9748, 1.5654, ..., 1.6055, -0.4706, 0.5817], + [-1.0915, 0.4559, 0.9230, ..., 0.5106, 0.0982, -0.1720]]), + 'image_grid_thw': tensor([[1, 4, 4], + [1, 4, 4]]), + 'labels': tensor([[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 3838, 374, + 419, 30, 151645, 198], + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, + 151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151653, 74785, 419, + 2168, 13, 151645, 198]])} + ``` + """ + + processor: ProcessorMixin + max_length: int | None = None + completion_only_loss: bool = False # default not used in practice; SFTTrainer always passes the relevant value + pad_to_multiple_of: int | None = None + dataset_text_field: str = "text" + return_tensors: str = "pt" + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + if "messages" in examples[0] or self.dataset_text_field in examples[0]: + if self.completion_only_loss: + raise ValueError( + "The `completion_only_loss` argument is not supported for language modeling datasets." + ) + return self._collate_language_modeling(examples) + elif "prompt" in examples[0] and "completion" in examples[0]: + return self._collate_prompt_completion(examples) + else: + raise KeyError(f"Unexpected input keys in examples: {list(examples[0].keys())}.") + + def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + images = [example["images"] for example in examples] + # Transformers requires at least one image in the batch, otherwise it throws an error + if all(img_list == [] for img_list in images): + images = None + + if "messages" in examples[0]: # conversational case + messages = [prepare_multimodal_messages(example["messages"], example["images"]) for example in examples] + texts = self.processor.apply_chat_template(messages) + elif self.dataset_text_field in examples[0]: # standard case + texts = [example[self.dataset_text_field] for example in examples] + else: + raise KeyError( + "The input examples must contain either 'messages' for conversational data or 'text' for standard " + "data." + ) + + output = self.processor( + images=images, + text=texts, + padding=True, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + truncation=self.max_length is not None, + max_length=self.max_length, + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + labels = output["input_ids"].clone() + labels[output["attention_mask"] == 0] = -100 + # We mask only padding tokens (-100) in the labels. Vision tokens are left unchanged because their handling in + # loss computation has to be done by the model, and masking them here would be infeasible in practice as vision + # token definitions vary across architectures. + output["labels"] = labels + return output + + def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + if self.pad_to_multiple_of is not None: + raise NotImplementedError( + "Padding to a multiple of a value is not yet implemented for vision-language modeling and " + "prompt-completion data yet." + ) + images = [example["images"] for example in examples] + # Transformers requires at least one image in the batch, otherwise it throws an error + if all(img_list == [] for img_list in images): + images = None + if is_conversational(examples[0]): # conversational case + for example in examples: + example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"]) + example["completion"] = prepare_multimodal_messages(example["completion"], images=[]) + examples = [apply_chat_template(example, self.processor) for example in examples] + + prompts = [example["prompt"] for example in examples] + completions = [example["completion"] for example in examples] + + processed_prompts = self.processor( + images=images, + text=prompts, + padding=True, + padding_side="left", + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + processed_completions = self.processor( + text=completions, + padding=True, + padding_side="right", + return_tensors=self.return_tensors, + add_special_tokens=False, # to avoid adding the BOS, twice see https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens + ) + + # Concatenate prompts and completions + prompt_ids, prompt_mask = processed_prompts["input_ids"], processed_prompts["attention_mask"] + completion_ids, completion_mask = processed_completions["input_ids"], processed_completions["attention_mask"] + input_ids = torch.cat((prompt_ids, completion_ids), dim=1) + attention_mask = torch.cat((prompt_mask, completion_mask), dim=1) + completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1) + if "token_type_ids" in processed_prompts: # special case for Gemma + prompt_token_type_ids = processed_prompts["token_type_ids"] + completion_token_type_ids = processed_completions["token_type_ids"] + token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1) + + # Flush left to reduce padding + if "token_type_ids" in processed_prompts: + attention_mask, input_ids, completion_mask, token_type_ids = flush_left( + attention_mask, input_ids, completion_mask, token_type_ids + ) + else: + attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask) + + # Truncate if necessary + if self.max_length is not None: + input_ids = input_ids[:, : self.max_length] + attention_mask = attention_mask[:, : self.max_length] + completion_mask = completion_mask[:, : self.max_length] + if "token_type_ids" in processed_prompts: + token_type_ids = token_type_ids[:, : self.max_length] + + # Create labels and mask padding tokens + labels = input_ids.clone() + labels[attention_mask == 0] = -100 + if self.completion_only_loss: + labels[completion_mask == 0] = -100 + + # Build the output dictionary + output = processed_prompts # we take processed_prompts because it contains the images + output["input_ids"] = input_ids + output["attention_mask"] = attention_mask + output["labels"] = labels + if "token_type_ids" in processed_prompts: + output["token_type_ids"] = token_type_ids + return output + + +def dft_loss(outputs, labels, num_items_in_batch=None): + """ + DFT loss function, as presented in [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward + Rectification](https://huggingface.co/papers/2508.05629) + """ + labels = nn.functional.pad(labels, (0, 1), value=-100) + shift_labels = labels[..., 1:].contiguous() + loss_mask = shift_labels != -100 + shift_labels[~loss_mask] = 0 + logprobs = selective_log_softmax(outputs.logits, shift_labels) + per_token_loss = -logprobs.exp().detach() * logprobs + if num_items_in_batch is None: + num_items_in_batch = loss_mask.sum() + loss = (per_token_loss * loss_mask).sum() / num_items_in_batch + return loss + + +class SFTTrainer(BaseTrainer): + """ + Trainer for Supervised Fine-Tuning (SFT) method. + + This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from trl import SFTTrainer + from datasets import load_dataset + + dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") + + trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + train_dataset=dataset, + ) + trainer.train() + ``` + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + - A [`~peft.PeftModel`] object. Only causal language models are supported. + If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss + as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. + args ([`SFTConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model + and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss + function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) + used by [`Trainer`]. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean + `compute_result` argument. This will be triggered after the last eval batch to signal that the function + needs to calculate and return the global summary statistics rather than accumulating the batch-level + statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in + `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before + initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + formatting_func (`Callable`, *optional*): + Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly + converts the dataset into a [language modeling](#language-modeling) type. + """ + + _tag_names = ["trl", "sft"] + _name = "SFT" + + def __init__( + self, + model: "str | PreTrainedModel | PeftModel", + args: SFTConfig | TrainingArguments | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, + compute_loss_func: Callable | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None, + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: "PeftConfig | None" = None, + formatting_func: Callable[[dict], str] | None = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = SFTConfig(f"{model_name}-SFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + if Version(transformers.__version__) < Version("5.0.0"): + dict_args.pop("push_to_hub_token") + args = SFTConfig(**dict_args) + + # IterableDataset requires dispatch_batches=False because Accelerate's dispatch mode may try to concatenate + # batches from multiple processes, leading to mismatch errors. + if isinstance(train_dataset, IterableDataset): + if args.accelerator_config.dispatch_batches is True: + logger.warning( + "You are using an `IterableDataset` for training with `dispatch_batches=True`. `dispatch_batches` " + "is forced to `False` when using an `IterableDataset`. To remove this warning, unset " + "`dispatch_batches` in `SFTConfig` or set it to `False`." + ) + args.accelerator_config.dispatch_batches = False + + # Model + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config)) + + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + self._is_vlm = True + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + self._is_vlm = False + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if args.eos_token is not None: + eos_token = args.eos_token + eos_token_id = tokenizer.convert_tokens_to_ids(eos_token) + if eos_token_id is None: + raise ValueError( + f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " + "in the vocabulary before using it as an EOS token." + ) + tokenizer.eos_token_id = eos_token_id + + if args.chat_template_path is not None: + if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): + with open(args.chat_template_path, encoding="utf-8") as chat_template_file: + processing_class.chat_template = chat_template_file.read() + added_tokens = [] + else: + model, processing_class, added_tokens = clone_chat_template( + model, processing_class, args.chat_template_path + ) + else: + added_tokens = [] + + # Catch some wrong configurations related to VLMs + if self._is_vlm and args.packing: + raise ValueError( + "Packing is not supported for vision-language models. Please set `packing=False` in the SFTConfig." + ) + if self._is_vlm and args.padding_free: + raise ValueError( + "Padding-free training is yet not supported for vision-language models. Please set " + "`padding_free=False` in the `SFTConfig`." + ) + if self._is_vlm and args.assistant_only_loss: + raise ValueError( + "Assistant-only loss is not yet supported for vision-language models. Please set " + "`assistant_only_loss=False` in the `SFTConfig`." + ) + + # PEFT configuration and model wrapping + if peft_config is not None: + if added_tokens: + # Ensure that the added tokens are trainable + if peft_config.trainable_token_indices is None: + peft_config.trainable_token_indices = {"embed_tokens": added_tokens} + elif "embed_tokens" not in peft_config.trainable_token_indices: + peft_config.trainable_token_indices["embed_tokens"] = added_tokens + else: + peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) + + # Ensure that the lm_head is trainable + if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: + logger.warning( + "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's " + "`modules_to_save`. As a result, the model may not learn to generate outputs with these new " + "tokens, leading to degraded generation quality. To fix this, add " + "`modules_to_save=['lm_head']` to your PEFT configuration." + ) + + if peft_config.modules_to_save is None: + peft_config.modules_to_save = ["lm_head"] + else: + peft_config.modules_to_save.append("lm_head") + + if is_peft_available() and is_peft_model(model) and peft_config is not None: + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + + # Create PEFT model + if peft_config is not None: + model = get_peft_model(model, peft_config) + + # PEFT + DeepSpeed ZeRO-3 requires reentrant checkpointing. For more details, see + # https://github.com/huggingface/trl/issues/2514#issuecomment-2692152703 + if ( + is_peft_model(model) + and args.deepspeed_plugin is not None + and args.deepspeed_plugin.zero_stage == 3 + and args.gradient_checkpointing + ): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + use_reentrant = args.gradient_checkpointing_kwargs.get("use_reentrant") + if use_reentrant is False: + logger.warning( + "You are using PEFT with DeepSpeed ZeRO-3 and gradient checkpointing with `use_reentrant=False`. " + "`use_reentrant` is forced to `True` in this configuration to ensure correct training. To remove " + "this warning, unset `use_reentrant` in `gradient_checkpointing_kwargs` or set it to `True`." + ) + args.gradient_checkpointing_kwargs["use_reentrant"] = True + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) + + # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the + # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. + self.num_virtual_tokens = 0 + if is_peft_available() and is_peft_model(model): + if model.active_adapter in model.peft_config: + peft_model_config = model.peft_config[model.active_adapter] + self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) + + # Data collator + # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing + # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask. + self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd") + use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS + if self.padding_free: + if data_collator is not None: + raise ValueError("Passing a custom data collator is not supported when using padding-free.") + if args.packing and args.packing_strategy == "wrapped": + logger.warning( + "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not " + "recommended. Please refer to the documentation to understand why this is not recommended." + ) + if not use_flash_attention: + logger.warning( + "Padding-free training is enabled, but the attention implementation is not set to a supported " + "flash attention variant. Padding-free training flattens batches into a single sequence, and only " + "the following implementations are known to reliably support this: " + f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to " + "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model " + "configuration to one of these supported options or verify that your attention mechanism can " + "handle flattened sequences." + ) + + if args.per_device_train_batch_size == 1 and not args.packing: + logger.warning( + "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " + "of 1 annihilate the benefits of padding-free training. Please consider increasing the batch size " + "to at least 2." + ) + + # Decide whether to use completion-only loss: if not specified, then it is set to True if the dataset format + # is prompt-completion, and False if the dataset format is language modeling. + dataset_sample = next(iter(train_dataset)) + if args.completion_only_loss is None: + self.completion_only_loss = "prompt" in dataset_sample and "completion" in dataset_sample + else: + self.completion_only_loss = args.completion_only_loss + + self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample + if self._is_vision_dataset and not self._is_vlm: + raise ValueError( + "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided " + "model does not seem to be a vision-language model. Please check your model and dataset." + ) + + if data_collator is None and not self._is_vision_dataset: + # Get the pad token: if not provided, use the one from the processing class or the eos token + # if the processing class does not have a pad token. + pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token + pad_token_id = tokenizer.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + data_collator = DataCollatorForLanguageModeling( + pad_token_id=pad_token_id, + completion_only_loss=self.completion_only_loss, + padding_free=self.padding_free, + pad_to_multiple_of=args.pad_to_multiple_of, + ) + elif data_collator is None and self._is_vision_dataset: + data_collator = DataCollatorForVisionLanguageModeling( + processor=processing_class, + max_length=args.max_length, + completion_only_loss=self.completion_only_loss, + pad_to_multiple_of=args.pad_to_multiple_of, + dataset_text_field=args.dataset_text_field, + ) + + if args.packing and args.packing_strategy == "bfd" and not use_flash_attention: + logger.warning( + "You are using packing, but the attention implementation is not set to a supported flash attention " + "variant. Packing gathers multiple samples into a single sequence, and only the following " + f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. " + "Using other implementations may lead to cross-contamination between samples. To avoid this, either " + "disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration " + "to one of these supported options." + ) + if args.assistant_only_loss and not is_conversational(dataset_sample): + raise ValueError( + "You set `assistant_only_loss=True`, but the dataset is not conversational. This option is only " + "supported for conversational datasets." + ) + + # Dataset + # Skip dataset preparation if `skip_prepare_dataset=True` in `dataset_kwargs`, or if it's a VLM, where + # preprocessing (e.g., image-to-pixel conversion) is too costly and done on the fly instead. + skip_prepare_dataset = ( + args.dataset_kwargs is not None + and args.dataset_kwargs.get("skip_prepare_dataset", False) + or self._is_vision_dataset + ) + if not skip_prepare_dataset: + if self.completion_only_loss and formatting_func: + raise ValueError( + "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " + "Using a formatter converts the dataset to a language modeling type, conflicting with " + "completion-only loss. To resolve this, apply your formatting function before passing the " + "dataset, or disable `completion_only_loss` in `SFTConfig`." + ) + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) + if eval_dataset is not None: + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" + ) + + # Loss function + if not args.use_liger_kernel: # liger supports dft loss by just passing use_token_scaling=True + if args.loss_type == "nll": + pass # use the default loss + elif args.loss_type == "dft": + if compute_loss_func is not None: + raise ValueError( + "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. " + "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so " + "passing a `compute_loss_func` is not allowed." + ) + compute_loss_func = dft_loss + else: + raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_loss_func=compute_loss_func, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + def _prepare_dataset( + self, + dataset: Dataset | IterableDataset, + processing_class: PreTrainedTokenizerBase | ProcessorMixin, + args: SFTConfig, + packing: bool, + formatting_func: Callable[[dict], str] | None, + dataset_name: str, + ) -> Dataset | IterableDataset: + # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from + # sampled data. + if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform` + dataset = dataset.with_transform(remove_none_values) + + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = get_dataset_column_names(dataset) + is_processed = "input_ids" in column_names + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + # Apply the formatting function if any + if formatting_func is not None and is_processed: + logger.warning( + "You passed a dataset that is already processed (contains an `input_ids` field) together with a " + "formatting function. Therefore `formatting_func` will be ignored. Either remove the " + "`formatting_func` or pass a dataset that is not already processed.", + ) + + if formatting_func is not None and not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" + + def _func(example): + return {"text": formatting_func(example)} + + dataset = dataset.map(_func, batched=False, **map_kwargs) + + if not is_processed: + # Convert the dataset to ChatML if needed + first_example = next(iter(dataset)) + if is_conversational_from_value(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + column_names = get_dataset_column_names(dataset) + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in column_names else None, + **map_kwargs, + ) + + # Apply the chat template if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if "text" in example and not example["text"].endswith(eos_token): # language modeling case + example["text"] = example["text"] + eos_token + elif "completion" in example and not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example + + eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": eos_token}, + remove_columns="messages" if "messages" in column_names else None, # renamed to "text" + **map_kwargs, + ) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_loss): + if "prompt" in example: # prompt-completion case + output = {} + if is_conversational(example): + if self._is_vlm: + prompt = prepare_multimodal_messages(example["prompt"], images=[]) + completion = prepare_multimodal_messages(example["completion"], images=[]) + else: + prompt = example["prompt"] + completion = example["completion"] + prompt_ids = processing_class.apply_chat_template( + prompt, + tools=example.get("tools"), + add_generation_prompt=True, + tokenize=True, + return_dict=False, + **example.get("chat_template_kwargs", {}), + ) + # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists + # even for single examples, while for LLMs it returns lists of ints. + prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids + prompt_completion_processed = processing_class.apply_chat_template( + prompt + completion, + tools=example.get("tools"), + tokenize=True, + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + **example.get("chat_template_kwargs", {}), + ) + # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists + # even for single examples, while for LLMs it returns lists of ints. + prompt_completion_processed = { + k: v[0] if isinstance(v[0], list) else v + for k, v in prompt_completion_processed.items() + } + prompt_completion_ids = prompt_completion_processed["input_ids"] + if "assistant_masks" in prompt_completion_processed: + output["assistant_masks"] = prompt_completion_processed["assistant_masks"] + else: + prompt_ids = processing_class(text=example["prompt"])["input_ids"] + prompt_completion_ids = processing_class(text=example["prompt"] + example["completion"])[ + "input_ids" + ] + # Fix transformers inconsistency: for VLMs, processing_class returns lists of lists + # even for single examples, while for LLMs it returns lists of ints. + prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids + prompt_completion_ids = ( + prompt_completion_ids[0] + if isinstance(prompt_completion_ids[0], list) + else prompt_completion_ids + ) + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + # Create completion mask + completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) + output["input_ids"] = prompt_completion_ids + output["completion_mask"] = completion_mask + + else: # language modeling case + if is_conversational(example): + if self._is_vlm: + messages = prepare_multimodal_messages(example["messages"], images=[]) + else: + messages = example["messages"] + processed = processing_class.apply_chat_template( + messages, + tools=example.get("tools"), + tokenize=True, + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + **example.get("chat_template_kwargs", {}), + ) + # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists + # even for single examples, while for LLMs it returns lists of ints. + processed = {k: v[0] if isinstance(v[0], list) else v for k, v in processed.items()} + output = {k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed} + else: + output = {"input_ids": processing_class(text=example[dataset_text_field])["input_ids"]} + + if "assistant_masks" in output and 1 not in output["assistant_masks"]: + raise RuntimeError( + "You're using `assistant_only_loss=True`, but at least one example has no assistant " + "tokens. This usually means the tokenizer's chat template doesn't generate assistant " + "masks โ€” it may be missing the `{% generation %}` keyword. Please check the template and " + "ensure it's correctly configured to support assistant masking." + ) + return output + + dataset = dataset.map( + tokenize_fn, + fn_kwargs={ + "processing_class": processing_class, + "dataset_text_field": args.dataset_text_field, + "assistant_only_loss": args.assistant_only_loss, + }, + **map_kwargs, + ) + + # Pack or truncate + if packing: + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + + columns = ["input_ids"] + if "completion_mask" in get_dataset_column_names(dataset): + columns.append("completion_mask") + if "assistant_masks" in get_dataset_column_names(dataset): + columns.append("assistant_masks") + + dataset = dataset.select_columns(columns) + + # Shuffle the dataset before packing. When using wrapped packing, it's important to shuffle before + # packing as well to avoid correlations between sequences packed together. + if args.shuffle_dataset: + dataset = dataset.shuffle(seed=args.seed) + + # Packing adds new column "seq_lengths" needed for document aware FlashAttention + dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) + elif args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + # For Liger kernel, ensure only the essential columns + if args.use_liger_kernel: + collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"} + column_names = get_dataset_column_names(dataset) + dataset = dataset.select_columns(collator_expected_keys.intersection(column_names)) + + if args.shuffle_dataset: + dataset = dataset.shuffle(seed=args.seed) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the + # dataset. So we need to override the default signature columns to include "completion_mask" as well. + if self._signature_columns is None: + if self._is_vision_dataset: + self._signature_columns = ["messages", "prompt", "completion", "images"] + else: + self._signature_columns = ["input_ids", "labels", "seq_lengths", "completion_mask", "assistant_masks"] + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + mode = "train" if self.model.training else "eval" + prediction_loss_only = inputs.pop("_prediction_loss_only", None) + + # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used. + # This can be removed when this issue is fixed. + # When using CP or SP, labels are pre-shifted, we must use shift_labels instead. + labels = inputs["labels"] if "shift_labels" not in inputs else None + + # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing + inputs["use_cache"] = False + + # Request token accuracy from Liger kernel and set token scaling if using DFT loss + if self.args.use_liger_kernel: + # Avoid materializing full logits during eval unless explicitly needed. + # By default, liger kernel only skips logits during training (self.training=True). + # When only loss is needed for eval (no compute_metrics), we can safely skip logits. + # prediction_step communicates whether logits are expected via `_prediction_loss_only`; + # this prevents skipping logits during `predict()` where outputs are requested. + # Keep logits when preprocess_logits_for_metrics is set, even if compute_metrics is None. + # to prevent massive vRAM spikes from the lm_head projection. + # See: https://github.com/huggingface/trl/issues/4679 + inputs["skip_logits"] = ( + self.model.training + or self.args.prediction_loss_only + or ( + self.compute_metrics is None + and self.preprocess_logits_for_metrics is None + and prediction_loss_only is not False + ) + ) + inputs["return_token_accuracy"] = True + inputs["use_token_scaling"] = self.args.loss_type == "dft" + + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + + # Compute entropy + if not self.args.use_liger_kernel: # liger doesn't return logits + with torch.no_grad(): + per_token_entropy = entropy_from_logits(outputs.logits) + # When using Prompt Tuning, skip the virtual tokens in logits before entropy computation, since they + # do not correspond to actual input tokens. + if ( + self.num_virtual_tokens > 0 + and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING + ): + per_token_entropy = per_token_entropy[:, self.num_virtual_tokens :] + if "attention_mask" in inputs: + attention_mask = inputs["attention_mask"] + entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() + elif "position_ids" in inputs: + entropy = torch.mean(per_token_entropy) + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + self._metrics[mode]["entropy"].append(entropy) + + if mode == "train": + # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q, + # cu_seq_lens_k, and max_length_k, max_length_q and position_ids. + if "attention_mask" in inputs: + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + elif "position_ids" in inputs: + local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) + num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + if self.args.use_liger_kernel: + if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None: + token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item() + self._metrics[mode]["mean_token_accuracy"].append(token_accuracy) + else: + # liger-kernel<=0.6.4 can omit token_accuracy even when requested; fixed for Gemma3 in + # https://github.com/linkedin/Liger-Kernel/pull/1010 + warnings.warn( + "liger-kernel did not return token_accuracy when requested. The mean_token_accuracy metric will " + "not be logged. This may indicate an outdated liger-kernel version. Consider upgrading to the " + "latest version. If the issue persists after upgrading, please report it to the liger-kernel " + "repository.", + stacklevel=2, + ) + else: + # Compute accuracy from logits using argmax (traditional method) + with torch.no_grad(): + if "shift_labels" in inputs: + # When using CP or SP, labels are pre-shifted. We must use these (and cannot manually shift) because: + # - The first discarded token from inputs["labels"] actually belongs to process n-1 + # - The last logits require the label from process n+1 + shift_logits = outputs.logits.contiguous() + shift_labels = inputs["shift_labels"] + else: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Prompt Tuning and P-Tuning output logits for virtual tokens but Prefix-Tuning does not. + if ( + self.num_virtual_tokens > 0 + and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING + ): + shift_logits = shift_logits[:, self.num_virtual_tokens :, :] + + # Get predictions + predictions = shift_logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + + # Compute the mean token accuracy and log it + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + + # Log auxiliary loss if enabled (applies to both Liger and non-Liger) + if self.aux_loss_enabled: + aux_loss = outputs.aux_loss + aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item() + self._metrics[mode]["aux_loss"].append(aux_loss) + + return (loss, outputs) if return_outputs else loss + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None): + # Preserve the eval loop intent so compute_loss can decide whether logits are needed. + inputs["_prediction_loss_only"] = prediction_loss_only + return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + + # Override training step to add activation offloading context. + def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/trainer/utils.py b/ICL/RL/trl_source/trl/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..399f5d579d79c5c3700c029c32e2737e21ac3c4d --- /dev/null +++ b/ICL/RL/trl_source/trl/trainer/utils.py @@ -0,0 +1,1336 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import dataclasses +import importlib.resources as pkg_resources +import json +import os +import random +import socket +import threading +from collections.abc import Mapping, Sequence, Sized +from contextlib import contextmanager +from dataclasses import dataclass +from importlib.metadata import version +from itertools import accumulate +from typing import TypeVar + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import transformers +from accelerate import Accelerator, PartialState, logging +from accelerate.state import AcceleratorState +from huggingface_hub import ModelCard, ModelCardData +from torch.utils.data import Sampler +from transformers import ( + AutoConfig, + BitsAndBytesConfig, + PretrainedConfig, + PreTrainedModel, + is_comet_available, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.models.auto.auto_factory import _BaseAutoModelClass +from transformers.utils import ( + is_peft_available, + is_rich_available, + is_torch_mlu_available, + is_torch_npu_available, + is_torch_xpu_available, +) + +from ..trainer.model_config import ModelConfig + + +if is_rich_available(): + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + +if is_comet_available(): + import comet_ml + +if is_peft_available(): + from peft import LoraConfig, PeftConfig, PeftModel + + +logger = logging.get_logger(__name__) + + +def _is_port_free(port: int, host: str = "127.0.0.1") -> bool: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((host, port)) + return True + except OSError: + return False + + +def _find_free_port() -> int: + candidates = (29500, 23456, 12355, 12345) + for p in candidates: + if _is_port_free(p): + return p + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def ensure_master_addr_port(addr: str | None = None, port: int | None = None) -> None: + """ + Ensure `MASTER_ADDR`/`MASTER_PORT` are set safely. + + - Respects existing environment variables. + - Defaults `MASTER_ADDR` to localhost if unset. + - Chooses a free TCP port if `MASTER_PORT` is unset to avoid collisions. + - If `MASTER_PORT` is set to `"0"` or `"auto"`, it is resolved to a free port. + """ + os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR") or addr or "localhost" + + env_port = os.environ.get("MASTER_PORT", "").strip().lower() + if port is None and env_port not in {"", "0", "auto"}: + try: + port = int(env_port) + except ValueError: + pass + + os.environ["MASTER_PORT"] = str(_find_free_port() if port in (None, 0) else port) + + +def pad( + tensors: list[torch.Tensor], + padding_value: int = 0, + padding_side: str = "right", + pad_to_multiple_of: int | None = None, +) -> torch.Tensor: + """ + Pads a list of tensors to the same shape along the first dimension. + + Args: + tensors (`list[torch.Tensor]`): + List of input tensors to pad. + padding_value (`int`): + Value to use for padding. Default is 0. + padding_side (`str`): + Side on which to add padding. Must be 'left' or 'right'. Default is 'right'. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + Returns: + `torch.Tensor`: + A single tensor containing the padded tensors. + + Examples: + ```python + >>> import torch + + >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]) + tensor([[1, 2, 3], + [4, 5, 0]]) + + >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]) + tensor([[[1, 2], + [3, 4]], + [[5, 6], + [0, 0]]]) + ``` + """ + # Determine the maximum shape for each dimension + output_shape = np.max([t.shape for t in tensors], 0).tolist() + + # Apply pad_to_multiple_of to the first (sequence) dimension + if pad_to_multiple_of is not None: + remainder = output_shape[0] % pad_to_multiple_of + if remainder != 0: + output_shape[0] += pad_to_multiple_of - remainder + + # Create an output tensor filled with the padding value + output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) + + for i, t in enumerate(tensors): + if padding_side == "left": + seq_start = output_shape[0] - t.shape[0] + elif padding_side == "right": + seq_start = 0 + else: + raise ValueError("padding_side must be 'left' or 'right'") + + # Define the slices + seq_slice = slice(seq_start, seq_start + t.shape[0]) + slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) + output[i][slices] = t + + return output + + +@dataclass +class RunningMoments: + """ + Calculates the running mean and standard deviation of a data stream. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 + """ + + accelerator: Accelerator + mean: float = 0 + std: float = 1 + var: float = 1 + count: float = 1e-24 + + @torch.no_grad() + def update(self, xs: torch.Tensor) -> tuple[float, float]: + """ + Updates running moments from batch's moments computed across ranks + """ + if self.accelerator.use_distributed: + xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + xs_mean, xs_var = xs_mean.float(), xs_var.float() + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += (delta * xs_count / tot_count).item() + new_var = tot_sum / tot_count + self.std = (new_var * tot_count / (tot_count - 1)).float().sqrt().item() + self.var = new_var.item() + self.count = tot_count + + return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() + + def save_to_json(self, json_path: str): + """Save the content of this instance in JSON format inside `json_path`.""" + # save everything except accelerator + if self.accelerator.is_main_process: + save_dict = dataclasses.asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if k != "accelerator"}) + json_string = json.dumps(save_dict, indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, accelerator: Accelerator, json_path: str): + """Create an instance from the content of `json_path`.""" + # load everything except accelerator + with open(json_path, encoding="utf-8") as f: + text = f.read() + return cls(accelerator=accelerator, **json.loads(text)) + + +@torch.no_grad() +def get_global_statistics( + accelerator, xs: torch.Tensor, mask=None, device="cpu" +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Computes element-wise mean and variance of the tensor across processes. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 + """ + xs = xs.to(accelerator.device) + sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) + sum_and_count = accelerator.reduce(sum_and_count) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) + sum_var = accelerator.reduce(sum_var) + global_var = sum_var / count + + return global_mean.to(device), global_var.to(device), count.item() + + +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int | float, dim: int = -1) -> torch.Tensor: + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0 + + +def get_quantization_config(model_args: ModelConfig) -> BitsAndBytesConfig | None: + if model_args.load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.dtype, # For consistency with model weights, we use the same value as `dtype` + bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, + bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, + ) + elif model_args.load_in_8bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + else: + quantization_config = None + + return quantization_config + + +def get_kbit_device_map() -> dict[str, int] | None: + if torch.cuda.is_available() or is_torch_xpu_available(): + return {"": PartialState().local_process_index} + else: + return None + + +def get_peft_config(model_args: ModelConfig) -> "PeftConfig | None": + if model_args.use_peft is False: + return None + + if not is_peft_available(): + raise ValueError( + "You need to have PEFT library installed in your environment, make sure to install `peft`. " + "Make sure to run `pip install -U peft`." + ) + + peft_config = LoraConfig( + task_type=model_args.lora_task_type, + r=model_args.lora_r, + target_modules=model_args.lora_target_modules, + target_parameters=model_args.lora_target_parameters, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, + bias="none", + use_rslora=model_args.use_rslora, + use_dora=model_args.use_dora, + modules_to_save=model_args.lora_modules_to_save, + ) + + return peft_config + + +def get_exp_cap(value, decimal=4): + """ + Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow. The formula is : + log(value.dtype.max) E.g. for float32 data type, the maximum exponent value is 88.7228 to 4 decimal points. + + Args: + value (`torch.Tensor`): + The input tensor to obtain the data type + decimal (`int`): + The number of decimal points of the output exponent cap. eg: direct calling exp(log(torch.float32.max)) + will result in inf so we cap the exponent to 88.7228 to avoid overflow. + """ + vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max + vdtype_log_max = torch.log(vdtype_max).to(value.device) + return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max + + +def cap_exp(value, cap=-1): + # Cap the exponent value below the upper-bound to avoid overflow, before calling torch.exp + cap = get_exp_cap(value) if cap < 0 else cap + return torch.exp(torch.clamp(value, max=cap)) + + +def prepare_deepspeed( + model: torch.nn.Module, per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False +) -> torch.nn.Module: + """ + Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based + on the model and batch size. + + Args: + model (`torch.nn.Module`): + The model to be prepared for DeepSpeed training. + per_device_train_batch_size (`int`): + The training batch size per device. + fp16 (`bool`, defaults to `False`): + Whether to use FP16 precision. + bf16 (`bool`, defaults to `False`): + Whether to use BF16 precision. + + Returns: + `torch.nn.Module`: + The model initialized and configured with DeepSpeed for training. + """ + import deepspeed + + deepspeed_plugin = AcceleratorState().deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size + config_kwargs = { + "train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"], + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if bf16: + config_kwargs["bf16"] = {"enabled": True} + elif fp16: + config_kwargs["fp16"] = {"enabled": True} + else: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0, + } + ) + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + +def empty_cache() -> None: + """Empties the cache of the available torch device. + + This function checks for the availability of different torch devices (XPU, MLU, NPU, CUDA) and empties the cache of + the first available device it finds. + + If none of the specific devices are available, it defaults to emptying the CUDA cache. + """ + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + + +def generate_model_card( + base_model: str | None, + model_name: str, + hub_model_id: str, + dataset_name: str | None, + tags: list[str], + wandb_url: str | None, + trainer_name: str, + trainer_citation: str | None = None, + template_file: str | None = None, + paper_title: str | None = None, + paper_id: str | None = None, + comet_url: str | None = None, +) -> ModelCard: + """ + Generate a [`~huggingface_hub.ModelCard`] from a template. + + Args: + base_model (`str` or `None`): + Base model name. + model_name (`str`): + Model name. + hub_model_id (`str`): + Hub model ID as `username/model_id`. + dataset_name (`str` or `None`): + Dataset name. + tags (`list[str]`): + Tags. + wandb_url (`str` or `None`): + Weights & Biases run URL. + comet_url (`str` or `None`): + Comet experiment URL. + trainer_name (`str`): + Trainer name. + trainer_citation (`str` or `None`, defaults to `None`): + Trainer citation as a BibTeX entry. + template_file (`str` *optional*): + Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`. + paper_title (`str` or `None`, defaults to `None`): + Paper title. + paper_id (`str` or `None`, defaults to `None`): + ArXiv paper ID as `YYMM.NNNNN`. + + Returns: + [`~huggingface_hub.ModelCard`]: + A ModelCard object. + """ + card_data = ModelCardData( + base_model=base_model, + datasets=dataset_name, + library_name="transformers", + licence="license", + model_name=model_name, + tags=["generated_from_trainer", *tags], + ) + template_file = template_file or "lm_model_card.md" + card = ModelCard.from_template( + card_data, + template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")), + base_model=base_model, + model_name=model_name, + hub_model_id=hub_model_id, + dataset_name=dataset_name, + wandb_url=wandb_url, + comet_url=comet_url, + trainer_name=trainer_name, + trainer_citation=trainer_citation, + paper_title=paper_title, + paper_id=paper_id, + trl_version=version("trl"), + transformers_version=version("transformers"), + pytorch_version=version("torch"), + datasets_version=version("datasets"), + tokenizers_version=version("tokenizers"), + ) + return card + + +def get_comet_experiment_url() -> str | None: + """ + If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`. + """ + if not is_comet_available(): + return None + + if comet_ml.get_running_experiment() is not None: + return comet_ml.get_running_experiment().url + + return None + + +def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: + """ + If Comet integration is enabled logs a table to the Comet experiment if it is currently running. + + Args: + name (`str`): + Table name. + table (`pandas.DataFrame`): + The Pandas DataFrame containing the table to log. + """ + if not is_comet_available(): + raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml") + + experiment = comet_ml.get_running_experiment() + if experiment is not None: + experiment.log_table(tabular_data=table, filename=name) + + +def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Shift non-zero elements in the mask and corresponding tensors to the left. + + This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. + For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across + all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: + + ``` + [[0, 0, x, x, x, x], -> [[x, x, x, x], + [0, x, x, x, 0, 0]] [x, x, x, 0]] + ``` + + Args: + mask (`torch.Tensor`): + 2D tensor (binary mask) with shape `(N, M)`. + *tensors (`torch.Tensor`): + One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, + with non-zero values shifted and excess zero columns truncated in the same manner. + + Returns: + `torch.Tensor`: + Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. + `*torch.Tensor` + Updated tensors, processed in the same way as the mask. + + Example: + ```python + >>> mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) + >>> tensor = torch.tensor([[9, 9, 2, 3, 4], [9, 5, 6, 9, 9]]) + >>> new_mask, new_tensor = flush_left(mask, tensor) + >>> print(new_mask) + tensor([[1, 1, 1], + [1, 1, 0]]) + + >>> print(new_tensor) + tensor([[2, 3, 4], + [5, 6, 0]]) + ``` + """ + _, M = mask.shape + + # Create copy of mask and tensors + mask_copy = mask.clone() + tensors = [t.clone() for t in tensors] + + # Shift non-zero values to the left + first_non_zero = mask_copy.argmax(dim=1) + pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) + idx_roll = (pos + first_non_zero.unsqueeze(1)) % M + mask_roll = mask_copy.gather(1, idx_roll) + rolled_tensors = [t.gather(1, idx_roll) for t in tensors] + + # Truncate trailing columns that are all zeros in mask_roll + col_sums = mask_roll.sum(dim=0) + empty_cols = col_sums == 0 + first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M + flushed_mask = mask_roll[:, :first_empty_col] + flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors] + + if not flushed_tensors: + return flushed_mask + return flushed_mask, *flushed_tensors + + +def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Shift non-zero elements in the mask and corresponding tensors to the right. See `flush_left` for details. + """ + _, M = mask.shape + + # Create copy of mask and tensors + mask_copy = mask.clone() + tensors = [t.clone() for t in tensors] + + # Shift non-zero values to the right + flipped_mask = torch.fliplr(mask_copy) + first_non_zero = flipped_mask.argmax(dim=1) + pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) + idx_roll = (pos - first_non_zero.unsqueeze(1)) % M + mask_roll = mask_copy.gather(1, idx_roll) + rolled_tensors = [t.gather(1, idx_roll) for t in tensors] + + # Truncate leading columns that are all zeros in mask_roll + col_sums = mask_roll.sum(dim=0) + non_empty_cols = col_sums != 0 + first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M + flushed_mask = mask_roll[:, first_non_empty_col:] + flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors] + + if not flushed_tensors: + return flushed_mask + return flushed_mask, *flushed_tensors + + +def selective_log_softmax(logits, index) -> torch.Tensor: + """ + A memory-efficient implementation of the common `log_softmax -> gather` operation. + + This function is equivalent to the following naive implementation: + ```python + logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + ``` + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. + index (`torch.Tensor`): + Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + + Returns: + `torch.Tensor`: + Gathered log probabilities with the same shape as `index`. + """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, index, strict=True): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps + + +def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Tensor: + """ + Compute the Shannon entropy (in nats) for each row of *logits* in a memory-efficient way. + + Instead of materializing the full softmax for all rows at once, the logits are flattened to shape (N, num_classes), + where N is the product of all leading dimensions. Computation is then performed in chunks of size `chunk_size` + along this flattened dimension, reducing peak memory usage. The result is reshaped back to match the input's + leading dimensions. + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading dimensions + are preserved in the output. + chunk_size (`int`, *optional*, defaults to `128`): + Number of rows from the flattened logits to process per iteration. Smaller values reduce memory usage at + the cost of more iterations. + + Returns: + `torch.Tensor`: + Entropy values with shape `logits.shape[:-1]`. + """ + original_shape = logits.shape[:-1] # all dims except num_classes + num_classes = logits.shape[-1] + + # Flatten all leading dimensions into one + flat_logits = logits.reshape(-1, num_classes) + + entropies = [] + for chunk in flat_logits.split(chunk_size, dim=0): + logps = F.log_softmax(chunk, dim=-1) + chunk_entropy = -(torch.exp(logps) * logps).sum(-1) + entropies.append(chunk_entropy) + + entropies = torch.cat(entropies, dim=0) + return entropies.reshape(original_shape) + + +def print_prompt_completions_sample( + prompts: list, + completions: list, + rewards: dict[str, list[float]], + advantages: list[float], + step: int, + num_samples: int = None, +) -> None: + """ + Print out a sample of model completions to the console with multiple reward metrics. + + This function creates a nicely formatted table showing prompt-completion pairs, useful for monitoring model outputs + during training. It requires the `rich` library to be installed. + + Args: + prompts (`list`): + List of prompts. Can be either strings or lists of messages. + completions (`list`): + List of completions corresponding to the prompts. Can be either strings or lists of messages. + rewards (`dict[str, list[float]]`): + Dictionary where keys are reward names and values are lists of rewards. + advantages (`list[float]`): + List of advantages corresponding to the prompts and completions. + step (`int`): + Current training step number, used in the output title. + num_samples (`int`, *optional*): + Number of random samples to display. If `None` (default), all items will be displayed. + + Example: + ```python + >>> from trl.trainer.utils import print_prompt_completions_sample + + >>> prompts = ["The sky is", "The sun is"] + >>> completions = [" blue.", " in the sky."] + >>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} + >>> advantages = [0.987, 0.654] + >>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42) + โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ Step 42 โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ + โ”‚ โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ โ”‚ + โ”‚ โ”ƒ Prompt โ”ƒ Completion โ”ƒ Correctness โ”ƒ Format โ”ƒ Advantage โ”ƒ โ”‚ + โ”‚ โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ โ”‚ + โ”‚ โ”‚ The sky is โ”‚ blue. โ”‚ 0.12 โ”‚ 0.79 โ”‚ 0.99 โ”‚ โ”‚ + โ”‚ โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค โ”‚ + โ”‚ โ”‚ The sun is โ”‚ in the sky. โ”‚ 0.46 โ”‚ 0.10 โ”‚ 0.65 โ”‚ โ”‚ + โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ + โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ + ``` + """ + if not is_rich_available(): + raise ImportError( + "The function `print_prompt_completions_sample` requires the `rich` library. Please install it with " + "`pip install rich`." + ) + console = Console() + table = Table(show_header=True, header_style="bold white", expand=True) + + # Add columns + table.add_column("Prompt", style="bright_yellow") + table.add_column("Completion", style="bright_green") + for reward_name in rewards.keys(): + table.add_column(reward_name, style="bold cyan", justify="right") + table.add_column("Advantage", style="bold magenta", justify="right") + + def format_entry(entry) -> Text: + t = Text() + if isinstance(entry, list) and all(isinstance(m, dict) for m in entry): + for j, msg in enumerate(entry): + role = msg.get("role", "") + if "content" in msg: + # Chat message + t.append(f"{role.upper()}\n", style="bold red") + t.append(msg["content"]) + elif "name" in msg and "args" in msg: + # Tool call + t.append(f"{role.upper()}\n", style="bold red") + t.append(f"{msg['name']}({msg['args']})") + else: + # Fallback + t.append(str(msg)) + if j < len(entry) - 1: + t.append("\n\n") + else: + t.append(str(entry)) + return t + + # Some basic input validation + if num_samples is not None: + if num_samples >= len(prompts): + num_samples = None + elif num_samples <= 0: + return + + # Subsample data if num_samples is specified + if num_samples is not None: + indices = random.sample(range(len(prompts)), num_samples) + prompts = [prompts[i] for i in indices] + completions = [completions[i] for i in indices] + rewards = {key: [val[i] for i in indices] for key, val in rewards.items()} + advantages = [advantages[i] for i in indices] + + for i in range(len(prompts)): + reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals + table.add_row( + format_entry(prompts[i]), + format_entry(completions[i]), + *reward_values, + f"{advantages[i]:.2f}", + ) + table.add_section() # Adds a separator between rows + + panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") + console.print(panel) + + +class RepeatSampler(Sampler): + """ + Sampler that repeats the indices of a dataset in a structured manner. + + Args: + data_source (`Sized`): + Dataset to sample from. + mini_repeat_count (`int`): + Number of times to repeat each index per batch. + batch_size (`int`, *optional*, defaults to `1`): + Number of unique indices per batch. + repeat_count (`int`, *optional*, defaults to `1`): + Number of times to repeat the full sampling process. + shuffle (`bool`, *optional*, defaults to `True`): + Whether to shuffle the dataset. + seed (`int`, *optional*): + Random seed for reproducibility (only affects this sampler). + + Example: + ```python + >>> sampler = RepeatSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4) + >>> list(sampler) + [4, 4, 3, 3, 0, 0, + 4, 4, 3, 3, 0, 0, + 4, 4, 3, 3, 0, 0, + 4, 4, 3, 3, 0, 0, + 1, 1, 2, 2, 6, 6, + 1, 1, 2, 2, 6, 6, + 1, 1, 2, 2, 6, 6, + 1, 1, 2, 2, 6, 6] + ``` + + ```txt + mini_repeat_count = 3 + - - - + [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | + 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | + 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, | + repeat_count = 2 + 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | + 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | + 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] | + --------- --------- --------- --------- + --------- --------- --------- --------- + --------- --------- --------- --------- + batch_size = 12 + ``` + """ + + def __init__( + self, + data_source: Sized, + mini_repeat_count: int, + batch_size: int = 1, + repeat_count: int = 1, + shuffle: bool = True, + seed: int | None = None, + ): + self.data_source = data_source + self.mini_repeat_count = mini_repeat_count + self.batch_size = batch_size + self.repeat_count = repeat_count + self.num_samples = len(data_source) + self.shuffle = shuffle + self.seed = seed + + if shuffle: + self.generator = torch.Generator() # Create a local random generator + if seed is not None: + self.generator.manual_seed(seed) + + def __iter__(self): + if self.shuffle: + # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7) + indexes = torch.randperm(self.num_samples, generator=self.generator).tolist() + else: + indexes = list(range(self.num_samples)) + + # [2, 4, 3, 1, 0, 6, 5] + # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) + indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)] + + # [[2, 4, 3], [1, 0, 6], [5]] + # -> [[2, 4, 3], [1, 0, 6]] + indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size] + + for chunk in indexes: + for _ in range(self.repeat_count): + for index in chunk: + for _ in range(self.mini_repeat_count): + yield index + + def __len__(self) -> int: + return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count + + +# torch.nanstd doesn't exist, so we define it here +def nanstd(tensor: torch.Tensor, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> torch.Tensor: + """ + Compute the standard deviation of a tensor, ignoring NaNs. + + Args: + tensor (`torch.Tensor`): + Input tensor. + dim (`int` or `tuple[int, ...]`, *optional*): + Dimension(s) to reduce. Defaults to all dimensions. + keepdim (`bool`, *optional*, defaults to `False`): + Whether to keep reduced dimensions. + + Returns: + `torch.Tensor`: + Standard deviation of the tensor, ignoring NaNs. + """ + # Compute variance ignoring NaNs + mean = torch.nanmean(tensor, dim=dim, keepdim=True) + variance = torch.nanmean((tensor - mean) ** 2, dim=dim, keepdim=True) + count = torch.sum(~torch.isnan(tensor), dim=dim, keepdim=True) # count of non-NaN values + correction = count / (count - 1) + correction = torch.where(count > 1, correction, torch.full_like(correction, float("nan"))) + variance *= correction # Bessel's correction + std = torch.sqrt(variance) + if keepdim: + return std + if dim is None: + return std.squeeze() + if isinstance(dim, int): + return std.squeeze(dim) + dims = [(d if d >= 0 else d + std.ndim) for d in dim] + for d in sorted(dims, reverse=True): + std = std.squeeze(d) + return std + + +def split_tensor_dict( + tensor_dict: dict[str, torch.Tensor | None], num_chunks: int +) -> list[dict[str, torch.Tensor | None]]: + """ + Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts. + + Example: + ```python + >>> x = torch.arange(12).reshape(6, 2) + >>> y = torch.arange(6).reshape(6, 1) + >>> tensor_dict = {"x": x, "y": y} + >>> split_tensor_dict(tensor_dict, 3) + [ + {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])}, + {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])}, + {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])} + ] + ``` + """ + first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) + chunk_size = first_tensor.shape[0] // num_chunks + chunks = [] + for i in range(num_chunks): + chunk_dict = {} + for key, tensor in tensor_dict.items(): + if tensor is not None and (isinstance(tensor, list) or tensor.ndim > 0): + chunk_dict[key] = tensor[i * chunk_size : (i + 1) * chunk_size] + elif tensor is not None and tensor.ndim == 0: + chunk_dict[key] = tensor + else: + chunk_dict[key] = None + chunks.append(chunk_dict) + return chunks + + +def shuffle_sequence_dict(seq_dict: dict[str, Sequence | None]) -> dict[str, Sequence | None]: + """ + Shuffles all sequence-like values in a dictionary along the first dimension in unison. + + Example: + ```python + >>> x = torch.arange(6).reshape(3, 2) + >>> y = ["a", "b", "c"] + >>> seq_dict = {"x": x, "y": y} + >>> shuffle_sequence_dict(seq_dict) + {'x': tensor([[2, 3], + [0, 1], + [4, 5]]), + 'y': ['b', 'a', 'c']} + ``` + """ + # Determine batch size from the first non-None sequence + batch_size = len(next(v for v in seq_dict.values() if v is not None)) + permutation = torch.randperm(batch_size) + + def permute(v: Sequence | None) -> Sequence | None: + if v is None: + return None + if isinstance(v, torch.Tensor) and v.ndim == 0: + return v + if isinstance(v, torch.Tensor) and v.ndim >= 1: + return v[permutation] + return [v[i] for i in permutation] + + return {key: permute(val) for key, val in seq_dict.items()} + + +def nanmin(tensor: torch.Tensor) -> torch.Tensor: + """ + Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors. + + Args: + tensor (`torch.Tensor`): Input tensor of shape `(N,)`. + + Returns: + `torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. + """ + if torch.isnan(tensor).all(): + return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) + return torch.min(tensor[~torch.isnan(tensor)]) + + +def nanmax(tensor: torch.Tensor) -> torch.Tensor: + """ + Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors. + + Args: + tensor (`torch.Tensor`): Input tensor of shape `(N,)`. + + Returns: + `torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. + """ + if torch.isnan(tensor).all(): + return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) + return torch.max(tensor[~torch.isnan(tensor)]) + + +def identity(x): + """Do we really need docs for this?""" + return x + + +def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor | list[torch.Tensor]]: + """ + Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]` + and batch["num_images"] while keeping other entries unchanged. + """ + if "image_grid_thw" not in batch or "pixel_values" not in batch or "num_images" not in batch: + return batch + + lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images] + pixel_values = batch["pixel_values"] # [total, feature_dim] + + if sum(lengths) != pixel_values.size(0): + raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}") + + boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12] + sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))] + split_values = list(torch.split(batch["pixel_values"], sections, dim=0)) + image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0)) + return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw} + + +def unsplit_pixel_values_by_grid(batch: dict[str, torch.Tensor | list[torch.Tensor]]) -> dict[str, torch.Tensor]: + """ + Opposite of `split_pixel_values_by_grid`. Merges a list of tensors in `batch["pixel_values"]` back into a single + tensor along the first dimension. + """ + pixel_values = batch.get("pixel_values") + if isinstance(pixel_values, list): + merged = torch.cat(pixel_values, dim=0) + batch = {**batch, "pixel_values": merged} + + image_grid_thw = batch.get("image_grid_thw") + if isinstance(image_grid_thw, list): + merged = torch.cat(image_grid_thw, dim=0) + batch = {**batch, "image_grid_thw": merged} + + return batch + + +TListOrMapping = TypeVar("TListOrMapping", list, Mapping) + + +def remove_none_values(example: TListOrMapping) -> TListOrMapping: + """ + Recursively removes entries with `None` values from a nested structure (list or dictionary). + + Args: + example (`list` or `Mapping`): + Input nested structure (list or dictionary) from which to remove `None`. + + Example: + ```python + >>> [ + ... { + ... "a": {"aa": None, "ab": 1}, + ... "b": "my_string", + ... } + ... ] + >>> remove_none_values(example) + [{'a': {'ab': 1}, 'b': 'my_string'}] + ``` + """ + if isinstance(example, list): + return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example] + elif isinstance(example, Mapping): + return { + key: remove_none_values(value) if isinstance(value, (dict, list)) else value + for key, value in example.items() + if value is not None + } + else: + raise TypeError("Input must be a list or a dictionary.") + + +def create_model_from_path( + model_id: str, architecture: _BaseAutoModelClass | None = None, **kwargs +) -> PreTrainedModel: + """ + Create a model from a given path using the specified initialization arguments. + + Args: + model_id (`str`): + Path to the model. Can be either a local directory or a model identifier from the Hugging Face Hub. + architecture (`_BaseAutoModelClass` or `None`, *optional*): + Model architecture class to instantiate. The model is initialized using the `from_pretrained` method of + this class. If `None`, the architecture will be inferred from the model's configuration. + kwargs (`dict`): + Initialization keyword arguments to pass to the model's `from_pretrained` method. When `'dtype'` is + specified, it can be either a `torch.dtype` or one of the strings: `'bfloat16'`, `'float16'`, `'float32'`, + or `'auto'`. If not explicitly set, `dtype` defaults to `'float32'`. + + Returns: + [`~transformers.PreTrainedModel`]: + The instantiated model. + """ + dtype = kwargs.get("dtype", "float32") + if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None: + pass # dtype is already a torch.dtype or "auto" or None + elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]: + kwargs["dtype"] = getattr(torch, dtype) + else: + raise ValueError( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing " + f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}." + ) + kwargs["device_map"] = kwargs.get("device_map", "auto") + if architecture is None: + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + model = architecture.from_pretrained(model_id, **kwargs) + return model + + +def get_config_model_id(config: PretrainedConfig) -> str: + """ + Retrieve the model identifier from a given model configuration. + + Args: + config ([`~transformers.PreTrainedConfig`]): + Configuration from which to extract the model identifier. + + Returns: + `str`: + The model identifier associated with the model configuration. + """ + return getattr(config, "_name_or_path", "") + + +@dataclass +class CausalLMOutputWithPastAndFlatLogits(CausalLMOutputWithPast): + flat_logits: torch.Tensor | None = None + + +def forward_masked_logits( + model: PreTrainedModel, logits_mask: torch.LongTensor, **kwargs +) -> CausalLMOutputWithPastAndFlatLogits: + """ + Run a Causal LM forward pass while computing logits only for masked positions to reduce memory usage. + + These are always equal: + + ```python + full_outputs = model(input_ids=input_ids) + masked_outputs = forward_masked_logits(model, mask, input_ids=input_ids) + + assert torch.equal( + masked_outputs.flat_logits, + full_outputs.logits[mask.bool()], + ) + ``` + + Args: + model ([`~transformers.PreTrainedModel`]): + A causal language model. + logits_mask (`torch.LongTensor`): + Boolean-like tensor indicating which token positions should have logits computed. Shape should match the + input sequence shape in `kwargs` (typically `[batch, seq_len]`). + **kwargs: + Keyword arguments forwarded to the inner decoder (e.g., `input_ids`, `attention_mask`, `past_key_values`). + + Returns: + `CausalLMOutputWithPastAndFlatLogits`: Output containing logits only for the unmasked positions. + + Raises: + ValueError: If `logits_to_keep` or `labels` are provided in `kwargs`. + """ + if kwargs.get("logits_to_keep") is not None: + raise ValueError("`logits_to_keep` is not supported by this forward helper.") + if kwargs.get("labels") is not None: + raise ValueError("`labels` is not yet supported by this forward helper.") + + outputs: BaseModelOutputWithPast = model.get_decoder()(**kwargs) + hidden_states = outputs.last_hidden_state + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + flat_logits = model.lm_head(hidden_states[logits_mask.bool()]) + if hasattr(model, "logit_scale"): # CohereForCausalLM has this attribute + flat_logits = flat_logits * model.logit_scale + + return CausalLMOutputWithPastAndFlatLogits( + flat_logits=flat_logits, + # We use .get(...) because some models like FalconMambaForCausalLM don't return past_key_values or attentions + past_key_values=outputs.get("past_key_values"), + hidden_states=outputs.hidden_states, + attentions=outputs.get("attentions"), + ) + + +@contextmanager +def use_adapter(model: "PeftModel", adapter_name: str | None): + """ + Context manager to temporarily set and reset the active adapter in a PEFT model. + + Args: + model ([`~peft.PeftModel`]): + PEFT model to manage. + adapter_name (`str` or `None`): + Name of the adapter to set as active. If `None`, the context manager will disable all adapters. + + Example: + ```python + >>> from trl.trainer.utils import use_adapter + >>> from peft import AutoPeftModelForCausalLM + >>> import torch + + >>> model = AutoPeftModelForCausalLM.from_pretrained("path/to/model") + >>> input_ids = torch.tensor([[1, 2, 3]]) + >>> with use_adapter(model, "adapter_name"): + ... outputs = model(input_ids) + ``` + """ + + if not is_peft_available(): + raise ImportError( + "You're trying to use a PEFT adapter but PEFT is not installed. Please install it with `pip install peft`." + ) + if adapter_name is None: + with model.disable_adapter(): + yield + else: + previous_adapter = model.active_adapter + model.set_adapter(adapter_name) + try: + yield + finally: + model.set_adapter(previous_adapter) + + +def start_event_loop_in_daemon( + name: str | None = None, +) -> tuple[threading.Thread, asyncio.AbstractEventLoop, threading.Event]: + """ + This function creates a new daemon thread that runs the provided event loop. + + Args: + name (`str`, *optional*): + Name of the thread. If `None`, the default thread naming will be used. + + Returns: + `threading.Thread`: + The thread running the event loop. + `asyncio.AbstractEventLoop`: + The event loop being run in the thread. + `threading.Event`: + An event that is set when the loop is ready. + """ + loop = asyncio.new_event_loop() + loop_ready_event = threading.Event() + + def run_loop(): + asyncio.set_event_loop(loop) + loop_ready_event.set() + loop.run_forever() + + thread = threading.Thread(target=run_loop, name=name, daemon=True) + thread.start() + return thread, loop, loop_ready_event + + +def shutdown_event_loop_in_daemon( + thread: threading.Thread | None, + loop: asyncio.AbstractEventLoop | None, +) -> None: + """ + Shutdown an asyncio event loop running in a separate thread. + + This function stops the event loop and waits for the associated thread to finish execution. + + Args: + thread (`threading.Thread`): + The thread running the event loop. + loop (`asyncio.AbstractEventLoop`): + The asyncio event loop to shut down. + """ + if loop is None or thread is None: + return + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=5)