Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- ICL/RL/trl_source/.github/workflows/tests-experimental.yml +70 -0
- ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml +121 -0
- ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py +108 -0
- ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml +22 -0
- ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh +112 -0
- ICL/RL/trl_source/examples/scripts/online_dpo.py +159 -0
- ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py +506 -0
- ICL/RL/trl_source/examples/scripts/openenv/echo.py +248 -0
- ICL/RL/trl_source/examples/scripts/openenv/wordle.py +607 -0
- ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt +105 -0
- ICL/RL/trl_source/examples/scripts/ppo/ppo.py +180 -0
- ICL/RL/trl_source/examples/scripts/reward_modeling.py +136 -0
- ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py +194 -0
- ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc +0 -0
- ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc +0 -0
- ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc +0 -0
- ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc +0 -0
- ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc +0 -0
- ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml +28 -0
- ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml +25 -0
- ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml +16 -0
- ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml +16 -0
- ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml +20 -0
- ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml +21 -0
- ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml +22 -0
- ICL/RL/trl_source/trl/experimental/__init__.py +36 -0
- ICL/RL/trl_source/trl/experimental/bco/__init__.py +16 -0
- ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py +16 -0
- ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py +30 -0
- ICL/RL/trl_source/trl/experimental/cpo/__init__.py +19 -0
- ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py +207 -0
- ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py +1057 -0
- ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py +35 -0
- ICL/RL/trl_source/trl/experimental/gkd/__init__.py +19 -0
- ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py +112 -0
- ICL/RL/trl_source/trl/experimental/gold/__init__.py +19 -0
- ICL/RL/trl_source/trl/experimental/gold/gold.py +155 -0
- ICL/RL/trl_source/trl/experimental/gold/gold_config.py +419 -0
- ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py +16 -0
- ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py +34 -0
- ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +731 -0
- ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py +15 -0
- ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py +157 -0
- ICL/RL/trl_source/trl/experimental/judges/__init__.py +36 -0
- ICL/RL/trl_source/trl/experimental/judges/judges.py +482 -0
- ICL/RL/trl_source/trl/experimental/kto/__init__.py +19 -0
- ICL/RL/trl_source/trl/experimental/kto/kto_config.py +171 -0
- ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py +1511 -0
- ICL/RL/trl_source/trl/experimental/merge_model_callback.py +352 -0
- ICL/RL/trl_source/trl/experimental/minillm/__init__.py +19 -0
ICL/RL/trl_source/.github/workflows/tests-experimental.yml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Tests (experimental)
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
pull_request:
|
| 5 |
+
paths:
|
| 6 |
+
# Run only when relevant files are modified
|
| 7 |
+
- "trl/experimental/**"
|
| 8 |
+
- "tests/experimental/**"
|
| 9 |
+
|
| 10 |
+
env:
|
| 11 |
+
TQDM_DISABLE: 1
|
| 12 |
+
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
| 13 |
+
TRL_EXPERIMENTAL_SILENCE: 1
|
| 14 |
+
|
| 15 |
+
jobs:
|
| 16 |
+
check_code_quality:
|
| 17 |
+
name: Check code quality
|
| 18 |
+
runs-on: ubuntu-latest
|
| 19 |
+
if: github.event.pull_request.draft == false
|
| 20 |
+
steps:
|
| 21 |
+
- uses: actions/checkout@v6
|
| 22 |
+
- name: Set up Python 3.13
|
| 23 |
+
uses: actions/setup-python@v6
|
| 24 |
+
with:
|
| 25 |
+
python-version: 3.13
|
| 26 |
+
- uses: pre-commit/action@v3.0.1
|
| 27 |
+
with:
|
| 28 |
+
extra_args: --all-files
|
| 29 |
+
|
| 30 |
+
tests:
|
| 31 |
+
name: Tests (experimental)
|
| 32 |
+
runs-on:
|
| 33 |
+
group: aws-g4dn-2xlarge
|
| 34 |
+
container:
|
| 35 |
+
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
| 36 |
+
options: --gpus all
|
| 37 |
+
defaults:
|
| 38 |
+
run:
|
| 39 |
+
shell: bash
|
| 40 |
+
steps:
|
| 41 |
+
- name: Git checkout
|
| 42 |
+
uses: actions/checkout@v6
|
| 43 |
+
|
| 44 |
+
- name: Set up Python 3.13
|
| 45 |
+
uses: actions/setup-python@v6
|
| 46 |
+
with:
|
| 47 |
+
python-version: 3.13
|
| 48 |
+
|
| 49 |
+
- name: Install Make and Git
|
| 50 |
+
run: |
|
| 51 |
+
apt-get update && apt-get install -y make git curl
|
| 52 |
+
|
| 53 |
+
- name: Install uv
|
| 54 |
+
run: |
|
| 55 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 56 |
+
|
| 57 |
+
- name: Create Python virtual environment
|
| 58 |
+
run: |
|
| 59 |
+
uv venv
|
| 60 |
+
uv pip install --upgrade setuptools wheel
|
| 61 |
+
|
| 62 |
+
- name: Install dependencies
|
| 63 |
+
run: |
|
| 64 |
+
source .venv/bin/activate
|
| 65 |
+
uv pip install ".[dev]"
|
| 66 |
+
|
| 67 |
+
- name: Test with pytest
|
| 68 |
+
run: |
|
| 69 |
+
source .venv/bin/activate
|
| 70 |
+
make test_experimental
|
ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Tests against Transformers branch
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
workflow_dispatch:
|
| 5 |
+
inputs:
|
| 6 |
+
transformers_ref:
|
| 7 |
+
description: "Transformers git ref (branch, tag, or commit SHA)"
|
| 8 |
+
required: true
|
| 9 |
+
default: "main"
|
| 10 |
+
|
| 11 |
+
env:
|
| 12 |
+
TQDM_DISABLE: 1
|
| 13 |
+
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
| 14 |
+
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
| 15 |
+
|
| 16 |
+
jobs:
|
| 17 |
+
tests_transformers_branch:
|
| 18 |
+
name: Tests with Transformers ${{ inputs.transformers_ref }}
|
| 19 |
+
runs-on:
|
| 20 |
+
group: aws-g4dn-2xlarge
|
| 21 |
+
container:
|
| 22 |
+
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
| 23 |
+
options: --gpus all
|
| 24 |
+
defaults:
|
| 25 |
+
run:
|
| 26 |
+
shell: bash
|
| 27 |
+
steps:
|
| 28 |
+
- name: Git checkout
|
| 29 |
+
uses: actions/checkout@v6
|
| 30 |
+
|
| 31 |
+
- name: Set up Python 3.12
|
| 32 |
+
uses: actions/setup-python@v6
|
| 33 |
+
with:
|
| 34 |
+
python-version: '3.12'
|
| 35 |
+
|
| 36 |
+
- name: Install Make and Git
|
| 37 |
+
run: |
|
| 38 |
+
apt-get update && apt-get install -y make git curl
|
| 39 |
+
|
| 40 |
+
- name: Install uv
|
| 41 |
+
run: |
|
| 42 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 43 |
+
|
| 44 |
+
- name: Create Python virtual environment
|
| 45 |
+
run: |
|
| 46 |
+
uv venv
|
| 47 |
+
uv pip install --upgrade setuptools wheel
|
| 48 |
+
|
| 49 |
+
- name: Install dependencies
|
| 50 |
+
run: |
|
| 51 |
+
source .venv/bin/activate
|
| 52 |
+
uv pip install ".[dev]"
|
| 53 |
+
uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }}
|
| 54 |
+
|
| 55 |
+
- name: Test with pytest
|
| 56 |
+
run: |
|
| 57 |
+
source .venv/bin/activate
|
| 58 |
+
make test
|
| 59 |
+
|
| 60 |
+
- name: Post to Slack
|
| 61 |
+
if: github.ref == 'refs/heads/main' && always()
|
| 62 |
+
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
| 63 |
+
with:
|
| 64 |
+
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
| 65 |
+
title: Results with Transformers ${{ inputs.transformers_ref }}
|
| 66 |
+
status: ${{ job.status }}
|
| 67 |
+
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
| 68 |
+
|
| 69 |
+
distributed_smoke:
|
| 70 |
+
name: Distributed smoke tests with Transformers ${{ inputs.transformers_ref }}
|
| 71 |
+
runs-on:
|
| 72 |
+
group: aws-g5-12xlarge-cache
|
| 73 |
+
container:
|
| 74 |
+
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
| 75 |
+
options: --gpus all
|
| 76 |
+
defaults:
|
| 77 |
+
run:
|
| 78 |
+
shell: bash
|
| 79 |
+
env:
|
| 80 |
+
CUDA_VISIBLE_DEVICES: "0,1"
|
| 81 |
+
steps:
|
| 82 |
+
- name: Git checkout
|
| 83 |
+
uses: actions/checkout@v6
|
| 84 |
+
|
| 85 |
+
- name: Set up Python 3.12
|
| 86 |
+
uses: actions/setup-python@v6
|
| 87 |
+
with:
|
| 88 |
+
python-version: '3.12'
|
| 89 |
+
|
| 90 |
+
- name: Install Make and Git
|
| 91 |
+
run: |
|
| 92 |
+
apt-get update && apt-get install -y make git curl
|
| 93 |
+
|
| 94 |
+
- name: Install uv
|
| 95 |
+
run: |
|
| 96 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 97 |
+
|
| 98 |
+
- name: Create Python virtual environment
|
| 99 |
+
run: |
|
| 100 |
+
uv venv
|
| 101 |
+
uv pip install --upgrade setuptools wheel
|
| 102 |
+
|
| 103 |
+
- name: Install dependencies
|
| 104 |
+
run: |
|
| 105 |
+
source .venv/bin/activate
|
| 106 |
+
uv pip install ".[dev]"
|
| 107 |
+
uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }}
|
| 108 |
+
|
| 109 |
+
- name: Run distributed smoke tests
|
| 110 |
+
run: |
|
| 111 |
+
source .venv/bin/activate
|
| 112 |
+
pytest -v tests/distributed/test_distributed.py
|
| 113 |
+
|
| 114 |
+
- name: Post to Slack
|
| 115 |
+
if: github.ref == 'refs/heads/main' && always()
|
| 116 |
+
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
| 117 |
+
with:
|
| 118 |
+
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
| 119 |
+
title: Results of distributed smoke tests with Transformers ${{ inputs.transformers_ref }}
|
| 120 |
+
status: ${{ job.status }}
|
| 121 |
+
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl[vllm]",
|
| 18 |
+
# ]
|
| 19 |
+
# ///
|
| 20 |
+
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
|
| 23 |
+
from datasets import load_dataset
|
| 24 |
+
from transformers import HfArgumentParser
|
| 25 |
+
from vllm import LLM, SamplingParams
|
| 26 |
+
|
| 27 |
+
from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
Examples:
|
| 32 |
+
|
| 33 |
+
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --num_examples 1000
|
| 34 |
+
Model win rate: 31.40%
|
| 35 |
+
|
| 36 |
+
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000
|
| 37 |
+
Model win rate: 51.60%
|
| 38 |
+
|
| 39 |
+
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
|
| 40 |
+
Model win rate: 51.20%
|
| 41 |
+
|
| 42 |
+
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --num_examples 1000
|
| 43 |
+
Model win rate: 46.30%
|
| 44 |
+
|
| 45 |
+
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000
|
| 46 |
+
Model win rate: 52.50%
|
| 47 |
+
|
| 48 |
+
python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
|
| 49 |
+
Model win rate: 63.00%
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class ScriptArguments:
|
| 55 |
+
r"""
|
| 56 |
+
Arguments for the script.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
model_name_or_path (`str`):
|
| 60 |
+
Model name or path to the model to evaluate.
|
| 61 |
+
judge_model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`):
|
| 62 |
+
Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or
|
| 63 |
+
'meta-llama/Meta-Llama-3-70B-Instruct'.
|
| 64 |
+
num_examples (`int`, *optional*):
|
| 65 |
+
Number of examples to evaluate.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
model_name_or_path: str = field(metadata={"help": "Model name or path to the model to evaluate."})
|
| 69 |
+
judge_model: str = field(
|
| 70 |
+
default="meta-llama/Meta-Llama-3-70B-Instruct",
|
| 71 |
+
metadata={
|
| 72 |
+
"help": "Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or "
|
| 73 |
+
"'meta-llama/Meta-Llama-3-70B-Instruct'."
|
| 74 |
+
},
|
| 75 |
+
)
|
| 76 |
+
num_examples: int | None = field(default=None, metadata={"help": "Number of examples to evaluate."})
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
if __name__ == "__main__":
|
| 80 |
+
# Parse the arguments
|
| 81 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 82 |
+
script_args = parser.parse_args_into_dataclasses()[0]
|
| 83 |
+
|
| 84 |
+
# Load the dataset
|
| 85 |
+
dataset = load_dataset("trl-lib/tldr", split="validation")
|
| 86 |
+
if script_args.num_examples is not None:
|
| 87 |
+
dataset = dataset.select(range(script_args.num_examples))
|
| 88 |
+
|
| 89 |
+
# Extract the prompts and reference completions
|
| 90 |
+
prompts = dataset["prompt"]
|
| 91 |
+
reference_completions = dataset["completion"]
|
| 92 |
+
|
| 93 |
+
# Generate the model completions
|
| 94 |
+
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200) # very generous max token length
|
| 95 |
+
llm = LLM(model=script_args.model_name_or_path, tensor_parallel_size=1)
|
| 96 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 97 |
+
model_completions = [output.outputs[0].text.strip() for output in outputs]
|
| 98 |
+
|
| 99 |
+
# Judge the outputs
|
| 100 |
+
if "gpt" in script_args.judge_model:
|
| 101 |
+
judge = OpenAIPairwiseJudge(script_args.judge_model)
|
| 102 |
+
else:
|
| 103 |
+
judge = HfPairwiseJudge(script_args.judge_model)
|
| 104 |
+
|
| 105 |
+
completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions, strict=True)]
|
| 106 |
+
best_idxs = judge.judge(prompts, completions)
|
| 107 |
+
model_win_rate = best_idxs.count(1) / len(best_idxs)
|
| 108 |
+
print(f"Model win rate: {model_win_rate * 100:.2f}%")
|
ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
offload_optimizer_device: none
|
| 6 |
+
offload_param_device: none
|
| 7 |
+
zero3_init_flag: true
|
| 8 |
+
zero3_save_16bit_model: true
|
| 9 |
+
zero_stage: 3
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
machine_rank: 0
|
| 13 |
+
main_training_function: main
|
| 14 |
+
mixed_precision: bf16
|
| 15 |
+
num_machines: 4
|
| 16 |
+
num_processes: 32
|
| 17 |
+
rdzv_backend: static
|
| 18 |
+
same_network: true
|
| 19 |
+
tpu_env: []
|
| 20 |
+
tpu_use_cluster: false
|
| 21 |
+
tpu_use_sudo: false
|
| 22 |
+
use_cpu: false
|
ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH -A account
|
| 3 |
+
#SBATCH -p partition
|
| 4 |
+
#SBATCH -N 5
|
| 5 |
+
#SBATCH --gres gpu:8
|
| 6 |
+
#SBATCH --ntasks-per-node=1
|
| 7 |
+
#SBATCH --cpus-per-task=16
|
| 8 |
+
#SBATCH --time=4:00:00
|
| 9 |
+
#SBATCH --job-name=trl_nemo_gym
|
| 10 |
+
#SBATCH --output=logs/%j/slurm.out
|
| 11 |
+
#SBATCH --error=logs/%j/slurm.err
|
| 12 |
+
|
| 13 |
+
CONTAINER_IMAGE="nvcr.io/nvidia/pytorch:25.12-py3"
|
| 14 |
+
MOUNTS="/path/to/mounts:/path/to/mounts"
|
| 15 |
+
|
| 16 |
+
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
|
| 17 |
+
|
| 18 |
+
TRAIN_NODE_0="${NODELIST[0]}"
|
| 19 |
+
TRAIN_NODE_1="${NODELIST[1]}"
|
| 20 |
+
TRAIN_NODE_2="${NODELIST[2]}"
|
| 21 |
+
TRAIN_NODE_3="${NODELIST[3]}"
|
| 22 |
+
VLLM_NODE="${NODELIST[4]}"
|
| 23 |
+
|
| 24 |
+
echo "Training Nodes: $TRAIN_NODE_0, $TRAIN_NODE_1, $TRAIN_NODE_2, $TRAIN_NODE_3"
|
| 25 |
+
echo "vLLM Node: $VLLM_NODE"
|
| 26 |
+
echo "Main process IP: $TRAIN_NODE_0"
|
| 27 |
+
|
| 28 |
+
LOG_DIR="logs/${SLURM_JOB_ID}"
|
| 29 |
+
mkdir -p ${LOG_DIR}
|
| 30 |
+
|
| 31 |
+
echo "Starting ng_run and vLLM on ${VLLM_NODE}..."
|
| 32 |
+
echo "Logs will be saved to: ${LOG_DIR}"
|
| 33 |
+
|
| 34 |
+
# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below!
|
| 35 |
+
|
| 36 |
+
srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \
|
| 37 |
+
--container-image="${CONTAINER_IMAGE}" \
|
| 38 |
+
--container-mounts="${MOUNTS}" \
|
| 39 |
+
--container-mount-home \
|
| 40 |
+
bash -c "
|
| 41 |
+
LOG_DIR=/path/to/logs
|
| 42 |
+
mkdir -p \${LOG_DIR}
|
| 43 |
+
|
| 44 |
+
# Install uv if not already installed
|
| 45 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 46 |
+
source \$HOME/.local/bin/env
|
| 47 |
+
|
| 48 |
+
# Start nemo gym servers
|
| 49 |
+
(set -x && \
|
| 50 |
+
export HOME=/path/to/user && \
|
| 51 |
+
export PATH=\$HOME/.local/bin:\$PATH && \
|
| 52 |
+
cd /path/to/user/Gym && \
|
| 53 |
+
uv venv --python 3.12 && \
|
| 54 |
+
source .venv/bin/activate && \
|
| 55 |
+
uv sync && \
|
| 56 |
+
ray stop --force && \
|
| 57 |
+
ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 &
|
| 58 |
+
|
| 59 |
+
sleep 10
|
| 60 |
+
|
| 61 |
+
# Start trl vllm server
|
| 62 |
+
(set -x && \
|
| 63 |
+
export HOME=/path/to/user && \
|
| 64 |
+
export HF_HOME=/path/to/user/hf_home && \
|
| 65 |
+
cd /path/to/user/trl && \
|
| 66 |
+
rm -rf .venv && uv venv && source .venv/bin/activate && uv sync && uv pip install -e .[vllm] && uv pip install fastapi uvicorn && \
|
| 67 |
+
python -m trl.scripts.vllm_serve \
|
| 68 |
+
--model Qwen/Qwen3-4B-Instruct-2507 \
|
| 69 |
+
--host 0.0.0.0 \
|
| 70 |
+
--tensor-parallel-size 8 \
|
| 71 |
+
--data-parallel-size 1 \
|
| 72 |
+
--max-model-len 16384 \
|
| 73 |
+
--gpu-memory-utilization 0.7 \
|
| 74 |
+
--port 8000) > \${LOG_DIR}/vllm_serve.log 2>&1 &
|
| 75 |
+
|
| 76 |
+
wait
|
| 77 |
+
" &
|
| 78 |
+
|
| 79 |
+
echo "Waiting for nemo gym and vllm to start..."
|
| 80 |
+
sleep 120
|
| 81 |
+
|
| 82 |
+
echo "Launching training on 4 nodes..."
|
| 83 |
+
|
| 84 |
+
TRAIN_NODES_LIST="${TRAIN_NODE_0},${TRAIN_NODE_1},${TRAIN_NODE_2},${TRAIN_NODE_3}"
|
| 85 |
+
|
| 86 |
+
srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \
|
| 87 |
+
--container-image="${CONTAINER_IMAGE}" \
|
| 88 |
+
--container-mounts="${MOUNTS}" \
|
| 89 |
+
--container-mount-home \
|
| 90 |
+
bash -c "
|
| 91 |
+
set -x && \
|
| 92 |
+
export HOME=/path/to/user && \
|
| 93 |
+
export HF_HOME=/path/to/user/hf_home && \
|
| 94 |
+
cd /path/to/user/trl && \
|
| 95 |
+
source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \
|
| 96 |
+
cd examples/scripts/nemo_gym && \
|
| 97 |
+
export WANDB_API_KEY=<your wandb api key> && \
|
| 98 |
+
accelerate launch \
|
| 99 |
+
--config_file deepspeed_zero3.yaml \
|
| 100 |
+
--num_processes 32 \
|
| 101 |
+
--num_machines 4 \
|
| 102 |
+
--machine_rank \$SLURM_PROCID \
|
| 103 |
+
--main_process_ip ${TRAIN_NODE_0} \
|
| 104 |
+
--main_process_port 29500 \
|
| 105 |
+
--rdzv_backend c10d \
|
| 106 |
+
train_multi_environment.py \
|
| 107 |
+
--config config.yaml \
|
| 108 |
+
--vllm_server_host ${VLLM_NODE} \
|
| 109 |
+
--head_server_host ${VLLM_NODE}" &
|
| 110 |
+
|
| 111 |
+
wait
|
| 112 |
+
|
ICL/RL/trl_source/examples/scripts/online_dpo.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl",
|
| 18 |
+
# "peft",
|
| 19 |
+
# "trackio",
|
| 20 |
+
# "kernels",
|
| 21 |
+
# ]
|
| 22 |
+
# ///
|
| 23 |
+
|
| 24 |
+
"""
|
| 25 |
+
Usage:
|
| 26 |
+
|
| 27 |
+
python examples/scripts/online_dpo.py \
|
| 28 |
+
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
|
| 29 |
+
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
|
| 30 |
+
--dataset_name trl-lib/tldr \
|
| 31 |
+
--learning_rate 5.0e-7 \
|
| 32 |
+
--output_dir pythia-1b-tldr-online-dpo \
|
| 33 |
+
--per_device_train_batch_size 8 \
|
| 34 |
+
--gradient_accumulation_steps 16 \
|
| 35 |
+
--warmup_steps 0.1 \
|
| 36 |
+
--missing_eos_penalty 1.0
|
| 37 |
+
|
| 38 |
+
With LoRA:
|
| 39 |
+
python examples/scripts/online_dpo.py \
|
| 40 |
+
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
|
| 41 |
+
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
|
| 42 |
+
--dataset_name trl-lib/tldr \
|
| 43 |
+
--learning_rate 5.0e-6 \
|
| 44 |
+
--output_dir pythia-1b-tldr-online-dpo \
|
| 45 |
+
--per_device_train_batch_size 16 \
|
| 46 |
+
--gradient_accumulation_steps 8 \
|
| 47 |
+
--warmup_steps 0.1 \
|
| 48 |
+
--missing_eos_penalty 1.0 \
|
| 49 |
+
--use_peft
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
import os
|
| 53 |
+
|
| 54 |
+
import torch
|
| 55 |
+
from datasets import load_dataset
|
| 56 |
+
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig
|
| 57 |
+
|
| 58 |
+
from trl import (
|
| 59 |
+
LogCompletionsCallback,
|
| 60 |
+
ModelConfig,
|
| 61 |
+
ScriptArguments,
|
| 62 |
+
TrlParser,
|
| 63 |
+
get_kbit_device_map,
|
| 64 |
+
get_peft_config,
|
| 65 |
+
get_quantization_config,
|
| 66 |
+
)
|
| 67 |
+
from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge
|
| 68 |
+
from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Enable logging in a Hugging Face Space
|
| 72 |
+
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge}
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
|
| 79 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 80 |
+
training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}
|
| 81 |
+
|
| 82 |
+
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
| 83 |
+
model_kwargs = dict(
|
| 84 |
+
revision=model_args.model_revision,
|
| 85 |
+
attn_implementation=model_args.attn_implementation,
|
| 86 |
+
dtype=dtype,
|
| 87 |
+
use_cache=False if training_args.gradient_checkpointing else True,
|
| 88 |
+
)
|
| 89 |
+
quantization_config = get_quantization_config(model_args)
|
| 90 |
+
if quantization_config is not None:
|
| 91 |
+
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
| 92 |
+
model_kwargs["device_map"] = get_kbit_device_map()
|
| 93 |
+
model_kwargs["quantization_config"] = quantization_config
|
| 94 |
+
|
| 95 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 96 |
+
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if training_args.reward_model_path is not None:
|
| 100 |
+
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
| 101 |
+
training_args.reward_model_path,
|
| 102 |
+
num_labels=1,
|
| 103 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 104 |
+
**model_kwargs,
|
| 105 |
+
)
|
| 106 |
+
reward_tokenizer = AutoTokenizer.from_pretrained(
|
| 107 |
+
training_args.reward_model_path,
|
| 108 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 109 |
+
truncation=True,
|
| 110 |
+
truncation_side="left", # since we judge the completion, truncating left is more appropriate
|
| 111 |
+
)
|
| 112 |
+
if reward_tokenizer.pad_token_id is None:
|
| 113 |
+
reward_tokenizer.pad_token = reward_tokenizer.eos_token
|
| 114 |
+
else:
|
| 115 |
+
reward_model = None
|
| 116 |
+
reward_tokenizer = None
|
| 117 |
+
|
| 118 |
+
if training_args.judge is not None:
|
| 119 |
+
judge_cls = JUDGES[training_args.judge]
|
| 120 |
+
judge = judge_cls()
|
| 121 |
+
else:
|
| 122 |
+
judge = None
|
| 123 |
+
|
| 124 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 125 |
+
model_args.model_name_or_path,
|
| 126 |
+
padding_side="left",
|
| 127 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 128 |
+
**model_kwargs,
|
| 129 |
+
)
|
| 130 |
+
if tokenizer.pad_token_id is None:
|
| 131 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 132 |
+
|
| 133 |
+
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
| 134 |
+
|
| 135 |
+
trainer = OnlineDPOTrainer(
|
| 136 |
+
model=model,
|
| 137 |
+
reward_funcs=reward_model,
|
| 138 |
+
judge=judge,
|
| 139 |
+
args=training_args,
|
| 140 |
+
train_dataset=dataset[script_args.dataset_train_split],
|
| 141 |
+
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
| 142 |
+
processing_class=tokenizer,
|
| 143 |
+
reward_processing_classes=reward_tokenizer,
|
| 144 |
+
peft_config=get_peft_config(model_args),
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
if training_args.eval_strategy != "no":
|
| 148 |
+
generation_config = GenerationConfig(
|
| 149 |
+
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
|
| 150 |
+
)
|
| 151 |
+
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
|
| 152 |
+
trainer.add_callback(completions_callback)
|
| 153 |
+
|
| 154 |
+
trainer.train()
|
| 155 |
+
|
| 156 |
+
# Save and push to hub
|
| 157 |
+
trainer.save_model(training_args.output_dir)
|
| 158 |
+
if training_args.push_to_hub:
|
| 159 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl[vllm]",
|
| 18 |
+
# "peft",
|
| 19 |
+
# "trackio",
|
| 20 |
+
# "kernels",
|
| 21 |
+
# "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env",
|
| 22 |
+
# ]
|
| 23 |
+
# ///
|
| 24 |
+
|
| 25 |
+
"""
|
| 26 |
+
Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM for LLMs.
|
| 27 |
+
|
| 28 |
+
This script is optimized for text-only Language Models (LLMs). It uses the accessibility
|
| 29 |
+
tree text from BrowserGym, making it memory-efficient.
|
| 30 |
+
|
| 31 |
+
The environment runs on a Hugging Face Space by default.
|
| 32 |
+
|
| 33 |
+
Setup (Option A - Install from HF Space, recommended):
|
| 34 |
+
|
| 35 |
+
```sh
|
| 36 |
+
uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
Setup (Option B - Clone OpenEnv repo, for development):
|
| 40 |
+
|
| 41 |
+
```sh
|
| 42 |
+
git clone https://github.com/meta-pytorch/OpenEnv.git
|
| 43 |
+
cd OpenEnv/envs/browsergym_env
|
| 44 |
+
uv pip install -e .
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
|
| 48 |
+
```sh
|
| 49 |
+
python examples/scripts/openenv/browsergym_llm.py --vllm-mode colocate
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)
|
| 53 |
+
|
| 54 |
+
# Spin up vLLM server (Terminal 1)
|
| 55 |
+
```sh
|
| 56 |
+
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8001
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
# Run training (Terminal 2)
|
| 60 |
+
```sh
|
| 61 |
+
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym_llm.py --vllm-mode server --vllm-server-url http://localhost:8001
|
| 62 |
+
```
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
from __future__ import annotations
|
| 66 |
+
|
| 67 |
+
import argparse
|
| 68 |
+
from datetime import datetime
|
| 69 |
+
from pathlib import Path
|
| 70 |
+
|
| 71 |
+
from browsergym_env import BrowserGymAction, BrowserGymEnv
|
| 72 |
+
from datasets import Dataset
|
| 73 |
+
from transformers import AutoTokenizer
|
| 74 |
+
|
| 75 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 76 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def parse_args() -> argparse.Namespace:
|
| 80 |
+
parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.")
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--model-id",
|
| 83 |
+
default="Qwen/Qwen3-0.6B",
|
| 84 |
+
help="Model identifier passed to GRPOTrainer for fine-tuning.",
|
| 85 |
+
)
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"--space-url",
|
| 88 |
+
type=str,
|
| 89 |
+
default="https://openenv-browsergym-env.hf.space",
|
| 90 |
+
help="URL for the Hugging Face Space running the BrowserGym environment.",
|
| 91 |
+
)
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--benchmark",
|
| 94 |
+
default="miniwob",
|
| 95 |
+
help="BrowserGym benchmark to use (miniwob, webarena, etc.).",
|
| 96 |
+
)
|
| 97 |
+
parser.add_argument(
|
| 98 |
+
"--task-name",
|
| 99 |
+
default="click-test",
|
| 100 |
+
help="Specific task within the benchmark (e.g., click-test, click-button).",
|
| 101 |
+
)
|
| 102 |
+
parser.add_argument(
|
| 103 |
+
"--dataset-prompt",
|
| 104 |
+
default="Complete the web task successfully.",
|
| 105 |
+
help="Prompt text used to seed the training dataset.",
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--dataset-size",
|
| 109 |
+
type=int,
|
| 110 |
+
default=1000,
|
| 111 |
+
help="Number of entries to include in the synthetic training dataset.",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--max-steps",
|
| 115 |
+
type=int,
|
| 116 |
+
default=10,
|
| 117 |
+
help="Maximum number of steps per episode.",
|
| 118 |
+
)
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"--max-new-tokens",
|
| 121 |
+
type=int,
|
| 122 |
+
default=32,
|
| 123 |
+
help="Maximum number of new tokens to request from vLLM for each action.",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--temperature",
|
| 127 |
+
type=float,
|
| 128 |
+
default=0.7,
|
| 129 |
+
help="Sampling temperature used during rollout generation.",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--top-k",
|
| 133 |
+
type=int,
|
| 134 |
+
default=50,
|
| 135 |
+
help="Top-k sampling parameter forwarded to vLLM.",
|
| 136 |
+
)
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--top-p",
|
| 139 |
+
type=float,
|
| 140 |
+
default=None,
|
| 141 |
+
help="Optional top-p sampling parameter forwarded to vLLM.",
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"--learning-rate",
|
| 145 |
+
type=float,
|
| 146 |
+
default=5e-6,
|
| 147 |
+
help="Learning rate for GRPO training.",
|
| 148 |
+
)
|
| 149 |
+
parser.add_argument(
|
| 150 |
+
"--weight-decay",
|
| 151 |
+
type=float,
|
| 152 |
+
default=0.0,
|
| 153 |
+
help="Weight decay applied during optimization.",
|
| 154 |
+
)
|
| 155 |
+
parser.add_argument(
|
| 156 |
+
"--gradient-accumulation-steps",
|
| 157 |
+
type=int,
|
| 158 |
+
default=32,
|
| 159 |
+
help="Gradient accumulation steps for GRPO training.",
|
| 160 |
+
)
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--warmup-steps",
|
| 163 |
+
type=int,
|
| 164 |
+
default=10,
|
| 165 |
+
help="Warmup steps for the scheduler.",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument(
|
| 168 |
+
"--per-device-batch-size",
|
| 169 |
+
type=int,
|
| 170 |
+
default=1,
|
| 171 |
+
help="Per-device train batch size.",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--num-generations",
|
| 175 |
+
type=int,
|
| 176 |
+
default=4,
|
| 177 |
+
help="Number of rollout generations per dataset prompt.",
|
| 178 |
+
)
|
| 179 |
+
parser.add_argument(
|
| 180 |
+
"--num-epochs",
|
| 181 |
+
type=int,
|
| 182 |
+
default=1,
|
| 183 |
+
help="Number of training epochs.",
|
| 184 |
+
)
|
| 185 |
+
parser.add_argument(
|
| 186 |
+
"--save-interval",
|
| 187 |
+
type=int,
|
| 188 |
+
default=50,
|
| 189 |
+
help="Interval (in steps) between checkpoint saves.",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--save-total-limit",
|
| 193 |
+
type=int,
|
| 194 |
+
default=None,
|
| 195 |
+
help="Maximum number of checkpoints to keep.",
|
| 196 |
+
)
|
| 197 |
+
parser.add_argument(
|
| 198 |
+
"--output-dir",
|
| 199 |
+
default=None,
|
| 200 |
+
help="Directory where training outputs and checkpoints are stored.",
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--run-name",
|
| 204 |
+
default=None,
|
| 205 |
+
help="Optional run name for logging systems.",
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--project",
|
| 209 |
+
default=None,
|
| 210 |
+
help="Optional project identifier for logging systems.",
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--vllm-mode",
|
| 214 |
+
choices=("colocate", "server"),
|
| 215 |
+
default="colocate",
|
| 216 |
+
help="vLLM execution mode: 'colocate' or 'server'.",
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--vllm-server-url",
|
| 220 |
+
type=str,
|
| 221 |
+
default="http://localhost:8001",
|
| 222 |
+
help="URL for the vLLM server (only used when --vllm-mode=server).",
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--logging-steps",
|
| 226 |
+
type=int,
|
| 227 |
+
default=1,
|
| 228 |
+
help="Frequency of logging steps for GRPO training.",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--debug",
|
| 232 |
+
action="store_true",
|
| 233 |
+
default=False,
|
| 234 |
+
help="Enable verbose debugging output during rollouts.",
|
| 235 |
+
)
|
| 236 |
+
return parser.parse_args()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def sanitize_name(name: str) -> str:
|
| 240 |
+
return name.replace("/", "-")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
# System Prompt
|
| 245 |
+
# ---------------------------------------------------------------------------
|
| 246 |
+
|
| 247 |
+
SYSTEM_PROMPT = """You control a web browser through BrowserGym actions.
|
| 248 |
+
You must complete the given web task by interacting with the page.
|
| 249 |
+
|
| 250 |
+
Available actions:
|
| 251 |
+
- noop() - Do nothing
|
| 252 |
+
- click(bid) - Click element with BrowserGym ID (the number in brackets)
|
| 253 |
+
- fill(bid, text) - Fill input field with text
|
| 254 |
+
- send_keys(text) - Send keyboard input
|
| 255 |
+
- scroll(direction) - Scroll up/down
|
| 256 |
+
|
| 257 |
+
The page structure shows elements as: [bid] element_type 'element_text'
|
| 258 |
+
For example: [13] button 'Click Me!' means bid='13'
|
| 259 |
+
|
| 260 |
+
Reply with exactly ONE action on a single line, e.g.:
|
| 261 |
+
click('13')
|
| 262 |
+
fill('42', 'hello world')
|
| 263 |
+
noop()
|
| 264 |
+
|
| 265 |
+
Do not include explanations or multiple actions."""
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# ---------------------------------------------------------------------------
|
| 269 |
+
# Helpers
|
| 270 |
+
# ---------------------------------------------------------------------------
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = "") -> str:
|
| 274 |
+
"""Create user prompt from observation."""
|
| 275 |
+
prompt_parts = [f"Step {step_num + 1}"]
|
| 276 |
+
|
| 277 |
+
if goal:
|
| 278 |
+
prompt_parts.append(f"Goal: {goal}")
|
| 279 |
+
|
| 280 |
+
if error:
|
| 281 |
+
prompt_parts.append(f"Previous action error: {error}")
|
| 282 |
+
|
| 283 |
+
# Include accessibility tree (truncated for context)
|
| 284 |
+
if axtree:
|
| 285 |
+
max_len = 2000
|
| 286 |
+
axtree_truncated = axtree[:max_len] + "..." if len(axtree) > max_len else axtree
|
| 287 |
+
prompt_parts.append(f"Page structure:\n{axtree_truncated}")
|
| 288 |
+
|
| 289 |
+
prompt_parts.append("What action do you take?")
|
| 290 |
+
|
| 291 |
+
return "\n\n".join(prompt_parts)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def parse_action(response_text: str) -> str:
|
| 295 |
+
"""Parse BrowserGym action from model response."""
|
| 296 |
+
# Extract first line that looks like an action
|
| 297 |
+
for line in response_text.strip().split("\n"):
|
| 298 |
+
line = line.strip()
|
| 299 |
+
if "(" in line and ")" in line:
|
| 300 |
+
return line
|
| 301 |
+
|
| 302 |
+
# Fallback to noop if no valid action found
|
| 303 |
+
return "noop()"
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def rollout_once(
|
| 307 |
+
trainer: GRPOTrainer,
|
| 308 |
+
env: BrowserGymEnv,
|
| 309 |
+
tokenizer: AutoTokenizer,
|
| 310 |
+
dataset_prompt: str,
|
| 311 |
+
max_steps: int,
|
| 312 |
+
debug: bool = False,
|
| 313 |
+
) -> dict[str, list]:
|
| 314 |
+
"""Run one episode and collect training data (text-only, no screenshots)."""
|
| 315 |
+
result = env.reset()
|
| 316 |
+
observation = result.observation
|
| 317 |
+
|
| 318 |
+
prompt_ids: list[int] = []
|
| 319 |
+
completion_ids: list[int] = []
|
| 320 |
+
logprobs: list[float] = []
|
| 321 |
+
step_rewards: list[float] = []
|
| 322 |
+
completion_rewards: list[float] = []
|
| 323 |
+
|
| 324 |
+
for step_num in range(max_steps):
|
| 325 |
+
if result.done:
|
| 326 |
+
break
|
| 327 |
+
|
| 328 |
+
# Create prompt from observation (text-only using accessibility tree)
|
| 329 |
+
goal = observation.goal or dataset_prompt
|
| 330 |
+
axtree = observation.axtree_txt or ""
|
| 331 |
+
error = observation.error if observation.last_action_error else ""
|
| 332 |
+
|
| 333 |
+
user_prompt = make_user_prompt(goal, step_num, axtree, error)
|
| 334 |
+
messages = [
|
| 335 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 336 |
+
{"role": "user", "content": user_prompt},
|
| 337 |
+
]
|
| 338 |
+
prompt_text = tokenizer.apply_chat_template(
|
| 339 |
+
messages,
|
| 340 |
+
add_generation_prompt=True,
|
| 341 |
+
tokenize=False,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
# Generate action with vLLM
|
| 345 |
+
rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
|
| 346 |
+
prompt_ids.extend(rollout_outputs["prompt_ids"])
|
| 347 |
+
completion_ids.extend(rollout_outputs["completion_ids"])
|
| 348 |
+
logprobs.extend(rollout_outputs["logprobs"])
|
| 349 |
+
|
| 350 |
+
completion_text = rollout_outputs.get("text") or tokenizer.decode(
|
| 351 |
+
rollout_outputs["completion_ids"], skip_special_tokens=True
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# Parse and execute action
|
| 355 |
+
action_str = parse_action(completion_text)
|
| 356 |
+
|
| 357 |
+
if debug:
|
| 358 |
+
print(f"Step {step_num + 1}: {action_str}")
|
| 359 |
+
|
| 360 |
+
# Take action in environment
|
| 361 |
+
result = env.step(BrowserGymAction(action_str=action_str))
|
| 362 |
+
observation = result.observation
|
| 363 |
+
|
| 364 |
+
# Track rewards
|
| 365 |
+
step_reward = float(result.reward or 0.0)
|
| 366 |
+
step_rewards.append(step_reward)
|
| 367 |
+
|
| 368 |
+
# Reward shaping: success is most important
|
| 369 |
+
if result.done and step_reward > 0:
|
| 370 |
+
completion_rewards.append(1.0) # Task completed successfully
|
| 371 |
+
elif result.done and step_reward == 0:
|
| 372 |
+
completion_rewards.append(0.0) # Task failed
|
| 373 |
+
else:
|
| 374 |
+
completion_rewards.append(step_reward) # Intermediate reward
|
| 375 |
+
|
| 376 |
+
# Final reward is based on task completion
|
| 377 |
+
final_reward = completion_rewards[-1] if completion_rewards else 0.0
|
| 378 |
+
|
| 379 |
+
return {
|
| 380 |
+
"prompt_ids": prompt_ids,
|
| 381 |
+
"completion_ids": completion_ids,
|
| 382 |
+
"logprobs": logprobs,
|
| 383 |
+
"step_rewards": step_rewards,
|
| 384 |
+
"completion_reward": final_reward,
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# ---------------------------------------------------------------------------
|
| 389 |
+
# Rewards
|
| 390 |
+
# ---------------------------------------------------------------------------
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def reward_completion(completions: list[str], **kwargs) -> list[float]:
|
| 394 |
+
"""Reward for task completion."""
|
| 395 |
+
rewards = kwargs.get("completion_reward") if kwargs else None
|
| 396 |
+
if rewards is None:
|
| 397 |
+
return [0.0 for _ in completions]
|
| 398 |
+
return [float(r) for r in rewards]
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# ---------------------------------------------------------------------------
|
| 402 |
+
# Main entrypoint
|
| 403 |
+
# ---------------------------------------------------------------------------
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def main() -> None:
|
| 407 |
+
args = parse_args()
|
| 408 |
+
|
| 409 |
+
# Connect to BrowserGym environment via Hugging Face Space
|
| 410 |
+
client = BrowserGymEnv(base_url=args.space_url)
|
| 411 |
+
print(f"🌍 Using Hugging Face Space environment at: {args.space_url}")
|
| 412 |
+
|
| 413 |
+
dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})
|
| 414 |
+
|
| 415 |
+
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 416 |
+
default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}"
|
| 417 |
+
output_dir = Path(args.output_dir or default_output_dir)
|
| 418 |
+
|
| 419 |
+
grpo_config = GRPOConfig(
|
| 420 |
+
use_vllm=True,
|
| 421 |
+
vllm_mode=args.vllm_mode,
|
| 422 |
+
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
|
| 423 |
+
vllm_gpu_memory_utilization=0.4,
|
| 424 |
+
output_dir=str(output_dir),
|
| 425 |
+
num_train_epochs=args.num_epochs,
|
| 426 |
+
learning_rate=args.learning_rate,
|
| 427 |
+
weight_decay=args.weight_decay,
|
| 428 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 429 |
+
per_device_train_batch_size=args.per_device_batch_size,
|
| 430 |
+
warmup_steps=args.warmup_steps,
|
| 431 |
+
num_generations=args.num_generations,
|
| 432 |
+
generation_batch_size=args.num_generations, # Must be divisible by num_generations
|
| 433 |
+
max_completion_length=args.max_new_tokens,
|
| 434 |
+
logging_steps=args.logging_steps,
|
| 435 |
+
report_to="trackio",
|
| 436 |
+
trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}",
|
| 437 |
+
save_strategy="steps",
|
| 438 |
+
save_steps=args.save_interval,
|
| 439 |
+
save_total_limit=args.save_total_limit,
|
| 440 |
+
temperature=args.temperature,
|
| 441 |
+
top_k=args.top_k,
|
| 442 |
+
top_p=args.top_p,
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
grpo_config.run_name = args.run_name or f"run-{timestamp}"
|
| 446 |
+
grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"
|
| 447 |
+
|
| 448 |
+
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
|
| 449 |
+
episode_prompt_ids: list[list[int]] = []
|
| 450 |
+
episode_completion_ids: list[list[int]] = []
|
| 451 |
+
episode_logprobs: list[list[float]] = []
|
| 452 |
+
completion_rewards: list[float] = []
|
| 453 |
+
|
| 454 |
+
if args.debug:
|
| 455 |
+
print(f"\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)")
|
| 456 |
+
|
| 457 |
+
for i, prompt_text in enumerate(prompts):
|
| 458 |
+
if args.debug:
|
| 459 |
+
print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}")
|
| 460 |
+
episode = rollout_once(
|
| 461 |
+
trainer=trainer,
|
| 462 |
+
env=client,
|
| 463 |
+
tokenizer=trainer.processing_class,
|
| 464 |
+
dataset_prompt=prompt_text,
|
| 465 |
+
max_steps=args.max_steps,
|
| 466 |
+
debug=args.debug,
|
| 467 |
+
)
|
| 468 |
+
episode_prompt_ids.append(episode["prompt_ids"])
|
| 469 |
+
episode_completion_ids.append(episode["completion_ids"])
|
| 470 |
+
episode_logprobs.append(episode["logprobs"])
|
| 471 |
+
completion_rewards.append(episode["completion_reward"])
|
| 472 |
+
|
| 473 |
+
return {
|
| 474 |
+
"prompt_ids": episode_prompt_ids,
|
| 475 |
+
"completion_ids": episode_completion_ids,
|
| 476 |
+
"logprobs": episode_logprobs,
|
| 477 |
+
"completion_reward": completion_rewards,
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
trainer = GRPOTrainer(
|
| 481 |
+
model=args.model_id,
|
| 482 |
+
reward_funcs=[reward_completion],
|
| 483 |
+
train_dataset=dataset,
|
| 484 |
+
args=grpo_config,
|
| 485 |
+
rollout_func=rollout_func,
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
print("=" * 80)
|
| 489 |
+
print("Starting GRPO training with BrowserGym environment (LLM mode)")
|
| 490 |
+
print(f"Benchmark: {args.benchmark}")
|
| 491 |
+
print(f"Task: {args.task_name}")
|
| 492 |
+
print(f"Model: {args.model_id}")
|
| 493 |
+
print("Mode: LLM (text-only, using accessibility tree)")
|
| 494 |
+
print(f"Using {args.num_generations} rollouts per dataset prompt")
|
| 495 |
+
print(f"Output directory: {output_dir}")
|
| 496 |
+
print("=" * 80)
|
| 497 |
+
|
| 498 |
+
try:
|
| 499 |
+
trainer.train()
|
| 500 |
+
print("\nTraining completed successfully!")
|
| 501 |
+
finally:
|
| 502 |
+
client.close()
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
if __name__ == "__main__":
|
| 506 |
+
main()
|
ICL/RL/trl_source/examples/scripts/openenv/echo.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl[vllm]",
|
| 18 |
+
# "peft",
|
| 19 |
+
# "trackio",
|
| 20 |
+
# "kernels",
|
| 21 |
+
# "openenv-echo-env @ git+https://huggingface.co/spaces/openenv/echo_env",
|
| 22 |
+
# ]
|
| 23 |
+
# ///
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
"""
|
| 27 |
+
Simple script to run GRPO training with OpenEnv's Echo environment and vLLM. The reward function encourages
|
| 28 |
+
longer completions.
|
| 29 |
+
|
| 30 |
+
Setup (Option A - Install from HF Space, recommended):
|
| 31 |
+
|
| 32 |
+
```sh
|
| 33 |
+
uv pip install git+https://huggingface.co/spaces/openenv/echo_env
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
Setup (Option B - Clone OpenEnv repo, for development):
|
| 37 |
+
|
| 38 |
+
```sh
|
| 39 |
+
git clone https://github.com/meta-pytorch/OpenEnv.git
|
| 40 |
+
cd OpenEnv/envs/echo_env
|
| 41 |
+
uv pip install -e .
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
|
| 45 |
+
```sh
|
| 46 |
+
python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode colocate
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)
|
| 50 |
+
|
| 51 |
+
# Spin up vLLM server (Terminal 1)
|
| 52 |
+
```sh
|
| 53 |
+
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
# Run training (Terminal 2)
|
| 57 |
+
```sh
|
| 58 |
+
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode server --vllm-server-url http://localhost:8000
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
# Option 3: Local + Colocated vLLM (1 GPU required)
|
| 62 |
+
|
| 63 |
+
# Start the environment only if using --env-mode docker-local
|
| 64 |
+
```sh
|
| 65 |
+
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
```sh
|
| 69 |
+
python examples/scripts/openenv/echo.py --env-mode docker-local --vllm-mode colocate
|
| 70 |
+
```
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
# ruff: noqa: T201
|
| 74 |
+
import argparse
|
| 75 |
+
import os
|
| 76 |
+
import subprocess
|
| 77 |
+
import sys
|
| 78 |
+
import time
|
| 79 |
+
from pathlib import Path
|
| 80 |
+
|
| 81 |
+
import requests
|
| 82 |
+
from datasets import load_dataset
|
| 83 |
+
from echo_env import EchoEnv
|
| 84 |
+
from echo_env.models import EchoAction
|
| 85 |
+
|
| 86 |
+
from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
|
| 87 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def parse_args():
|
| 91 |
+
parser = argparse.ArgumentParser(description="Run GRPO training with Echo environment and vLLM.")
|
| 92 |
+
|
| 93 |
+
parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the Echo environment.")
|
| 94 |
+
parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--env-mode",
|
| 97 |
+
choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
|
| 98 |
+
default="docker-image",
|
| 99 |
+
help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
|
| 100 |
+
)
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--model",
|
| 103 |
+
type=str,
|
| 104 |
+
default="Qwen/Qwen2.5-0.5B-Instruct",
|
| 105 |
+
help="Model to use for training.",
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--dataset",
|
| 109 |
+
type=str,
|
| 110 |
+
default="trl-lib/ultrafeedback-prompt",
|
| 111 |
+
help="Dataset to use for training.",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment."
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"--vllm-mode",
|
| 118 |
+
choices=["colocate", "server"],
|
| 119 |
+
default="colocate",
|
| 120 |
+
help="vLLM execution mode: 'colocate' or 'server'.",
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--vllm-server-url",
|
| 124 |
+
type=str,
|
| 125 |
+
default="http://localhost:8000",
|
| 126 |
+
help="URL for the vLLM server (only used when --vllm-mode=server).",
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return parser.parse_args()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def start_env_server(env_host: str, env_port: int):
|
| 133 |
+
"""Launch the Echo environment server locally."""
|
| 134 |
+
env_url = f"http://{env_host}:{env_port}"
|
| 135 |
+
print(f"⚡ Starting FastAPI server for Echo Environment on {env_url}...")
|
| 136 |
+
|
| 137 |
+
work_dir = str(Path.cwd().parent.absolute())
|
| 138 |
+
process = subprocess.Popen(
|
| 139 |
+
[sys.executable, "-m", "uvicorn", "echo_env.server.app:app", "--host", env_host, "--port", str(env_port)],
|
| 140 |
+
env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
|
| 141 |
+
stdout=subprocess.PIPE,
|
| 142 |
+
stderr=subprocess.PIPE,
|
| 143 |
+
text=True,
|
| 144 |
+
cwd=work_dir,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
print("⏳ Waiting for server to start...")
|
| 148 |
+
time.sleep(5)
|
| 149 |
+
|
| 150 |
+
try:
|
| 151 |
+
requests.get(f"{env_url}/health", timeout=2)
|
| 152 |
+
print("\n✅ Echo Environment server is running!")
|
| 153 |
+
except Exception as e:
|
| 154 |
+
print(f"\n��� Server failed to start: {e}")
|
| 155 |
+
if process.stderr:
|
| 156 |
+
print(process.stderr.read())
|
| 157 |
+
raise
|
| 158 |
+
|
| 159 |
+
return process
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def reward_from_env(completions, **kwargs):
|
| 163 |
+
"""Extract environment rewards for training."""
|
| 164 |
+
env_rewards = kwargs.get("env_reward", [])
|
| 165 |
+
return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def main():
|
| 169 |
+
args = parse_args()
|
| 170 |
+
|
| 171 |
+
# Select environment mode
|
| 172 |
+
if args.env_mode == "local":
|
| 173 |
+
env_url = f"http://{args.env_host}:{args.env_port}"
|
| 174 |
+
server_process = start_env_server(args.env_host, args.env_port)
|
| 175 |
+
elif args.env_mode == "docker-local":
|
| 176 |
+
env_url = f"http://{args.env_host}:{args.env_port}"
|
| 177 |
+
server_process = None
|
| 178 |
+
print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
|
| 179 |
+
elif args.env_mode == "docker-image":
|
| 180 |
+
client = EchoEnv.from_docker_image(args.env_image)
|
| 181 |
+
server_process = None
|
| 182 |
+
print("🌍 Using Echo Environment (Docker) from local Image")
|
| 183 |
+
elif args.env_mode == "docker-hub":
|
| 184 |
+
client = EchoEnv.from_hub(args.env_image)
|
| 185 |
+
server_process = None
|
| 186 |
+
print("🌍 Using existing Echo Environment (Docker) from Hub Image")
|
| 187 |
+
elif args.env_mode == "space":
|
| 188 |
+
env_url = args.env_host
|
| 189 |
+
server_process = None
|
| 190 |
+
print(f"🌍 Using Hugging Face Space environment at: {env_url}")
|
| 191 |
+
else:
|
| 192 |
+
raise ValueError(f"Unknown environment mode: {args.env_mode}")
|
| 193 |
+
|
| 194 |
+
if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
|
| 195 |
+
client = EchoEnv(base_url=env_url)
|
| 196 |
+
dataset = load_dataset(args.dataset, split="train[:1000]")
|
| 197 |
+
|
| 198 |
+
training_args = GRPOConfig(
|
| 199 |
+
output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
|
| 200 |
+
use_vllm=True,
|
| 201 |
+
vllm_mode=args.vllm_mode,
|
| 202 |
+
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
|
| 203 |
+
logging_steps=1,
|
| 204 |
+
report_to="trackio",
|
| 205 |
+
trackio_space_id=f"{args.model.split('/')[-1]}-GRPO-Rollout",
|
| 206 |
+
num_train_epochs=1,
|
| 207 |
+
max_completion_length=2048,
|
| 208 |
+
gradient_accumulation_steps=4,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
|
| 212 |
+
outputs = generate_rollout_completions(trainer, prompts)
|
| 213 |
+
tokenizer = trainer.processing_class
|
| 214 |
+
|
| 215 |
+
completions_text = [tokenizer.decode(output["completion_ids"], skip_special_tokens=True) for output in outputs]
|
| 216 |
+
|
| 217 |
+
env_result = client.reset()
|
| 218 |
+
env_rewards: list[float] = []
|
| 219 |
+
for message in completions_text:
|
| 220 |
+
env_result = client.step(EchoAction(message=message))
|
| 221 |
+
env_rewards.append(env_result.reward)
|
| 222 |
+
|
| 223 |
+
return {
|
| 224 |
+
"prompt_ids": [output["prompt_ids"] for output in outputs],
|
| 225 |
+
"completion_ids": [output["completion_ids"] for output in outputs],
|
| 226 |
+
"logprobs": [output["logprobs"] for output in outputs],
|
| 227 |
+
"env_reward": env_rewards,
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
trainer = GRPOTrainer(
|
| 231 |
+
model=args.model,
|
| 232 |
+
reward_funcs=reward_from_env,
|
| 233 |
+
args=training_args,
|
| 234 |
+
train_dataset=dataset,
|
| 235 |
+
rollout_func=rollout_func,
|
| 236 |
+
callbacks=[RichProgressCallback()],
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
trainer.train()
|
| 240 |
+
time.sleep(5)
|
| 241 |
+
|
| 242 |
+
if server_process:
|
| 243 |
+
print("🛑 Terminating Echo Environment server...")
|
| 244 |
+
server_process.terminate()
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
main()
|
ICL/RL/trl_source/examples/scripts/openenv/wordle.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl[vllm]",
|
| 18 |
+
# "peft",
|
| 19 |
+
# "trackio",
|
| 20 |
+
# "kernels",
|
| 21 |
+
# "openenv-textarena @ git+https://huggingface.co/spaces/sergiopaniego/wordle",
|
| 22 |
+
# ]
|
| 23 |
+
# ///
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
"""
|
| 27 |
+
Simple script to run GRPO training with OpenEnv's Wordle environment and vLLM.
|
| 28 |
+
|
| 29 |
+
Setup (Option A - Install from HF Space, recommended):
|
| 30 |
+
|
| 31 |
+
```sh
|
| 32 |
+
uv pip install git+https://huggingface.co/spaces/sergiopaniego/wordle
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
|
| 36 |
+
```sh
|
| 37 |
+
python examples/scripts/openenv/wordle.py --vllm-mode colocate
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)
|
| 41 |
+
|
| 42 |
+
# Spin up vLLM server (Terminal 1)
|
| 43 |
+
```sh
|
| 44 |
+
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
# Run training (Terminal 2)
|
| 48 |
+
```sh
|
| 49 |
+
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
# Option 3: Local Environment + Colocated vLLM (1 GPU required)
|
| 53 |
+
|
| 54 |
+
To run the Wordle environment locally, you have several options:
|
| 55 |
+
|
| 56 |
+
## Option 3a: Using Docker Image (Recommended)
|
| 57 |
+
|
| 58 |
+
First, build the Docker image from the textarena_env directory:
|
| 59 |
+
```sh
|
| 60 |
+
cd 3rd_party/OpenEnv/envs/textarena_env
|
| 61 |
+
docker build -t textarena-env:latest -f server/Dockerfile .
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
Then run the environment server:
|
| 65 |
+
```sh
|
| 66 |
+
docker run -d -p 8001:8001 textarena-env:latest
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Finally, run training pointing to local server:
|
| 70 |
+
```sh
|
| 71 |
+
python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
## Option 3b: Running Server Directly
|
| 75 |
+
|
| 76 |
+
From the textarena_env directory:
|
| 77 |
+
```sh
|
| 78 |
+
cd 3rd_party/OpenEnv/envs/textarena_env
|
| 79 |
+
uv venv && source .venv/bin/activate
|
| 80 |
+
uv pip install -e .
|
| 81 |
+
python -m uvicorn server.app:app --reload --port 8001
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
Then in another terminal, run training:
|
| 85 |
+
```sh
|
| 86 |
+
python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
## Option 3c: Using Pre-built HF Space Image
|
| 90 |
+
|
| 91 |
+
```sh
|
| 92 |
+
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-wordle:latest
|
| 93 |
+
python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001
|
| 94 |
+
```
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
import argparse
|
| 98 |
+
import re
|
| 99 |
+
import sys
|
| 100 |
+
from collections.abc import Iterable
|
| 101 |
+
from datetime import datetime
|
| 102 |
+
from pathlib import Path
|
| 103 |
+
|
| 104 |
+
from datasets import Dataset
|
| 105 |
+
from transformers import AutoTokenizer
|
| 106 |
+
|
| 107 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 108 |
+
from trl.experimental.openenv import generate_rollout_completions
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Ensure src/ is on the path
|
| 112 |
+
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
| 113 |
+
|
| 114 |
+
from textarena_env import TextArenaAction, TextArenaEnv
|
| 115 |
+
from textarena_env.models import TextArenaMessage
|
| 116 |
+
from textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def parse_args() -> argparse.Namespace:
|
| 120 |
+
parser = argparse.ArgumentParser(
|
| 121 |
+
description="Run GRPO training for Wordle using the TextArena OpenEnv environment."
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--tokenizer-id",
|
| 125 |
+
default="Qwen/Qwen3-1.7B",
|
| 126 |
+
help="Model identifier used to load the tokenizer.",
|
| 127 |
+
)
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"--model-id",
|
| 130 |
+
default="Qwen/Qwen3-1.7B",
|
| 131 |
+
help="Model identifier passed to GRPOTrainer for fine-tuning.",
|
| 132 |
+
)
|
| 133 |
+
parser.add_argument(
|
| 134 |
+
"--env-url", type=str, default="https://sergiopaniego-wordle.hf.space", help="URL for the environment server."
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--system-prompt-path",
|
| 138 |
+
default="wordle_prompt.txt",
|
| 139 |
+
help="Path to the file containing the system prompt.",
|
| 140 |
+
)
|
| 141 |
+
parser.add_argument(
|
| 142 |
+
"--dataset-prompt",
|
| 143 |
+
default="Play Wordle like an expert.",
|
| 144 |
+
help="Prompt text used to seed the training dataset.",
|
| 145 |
+
)
|
| 146 |
+
parser.add_argument(
|
| 147 |
+
"--dataset-size",
|
| 148 |
+
type=int,
|
| 149 |
+
default=3000,
|
| 150 |
+
help="Number of entries to include in the synthetic training dataset.",
|
| 151 |
+
)
|
| 152 |
+
parser.add_argument(
|
| 153 |
+
"--max-turns",
|
| 154 |
+
type=int,
|
| 155 |
+
default=6,
|
| 156 |
+
help="Maximum number of turns to play in the Wordle environment per episode.",
|
| 157 |
+
)
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--max-new-tokens",
|
| 160 |
+
type=int,
|
| 161 |
+
default=8,
|
| 162 |
+
help="Maximum number of new tokens to request from vLLM for each guess.",
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"--temperature",
|
| 166 |
+
type=float,
|
| 167 |
+
default=0.8,
|
| 168 |
+
help="Sampling temperature used during rollout generation.",
|
| 169 |
+
)
|
| 170 |
+
parser.add_argument(
|
| 171 |
+
"--top-k",
|
| 172 |
+
type=int,
|
| 173 |
+
default=10,
|
| 174 |
+
help="Top-k sampling parameter forwarded to vLLM.",
|
| 175 |
+
)
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--top-p",
|
| 178 |
+
type=float,
|
| 179 |
+
default=None,
|
| 180 |
+
help="Optional top-p sampling parameter forwarded to vLLM.",
|
| 181 |
+
)
|
| 182 |
+
parser.add_argument(
|
| 183 |
+
"--learning-rate",
|
| 184 |
+
type=float,
|
| 185 |
+
default=1e-6,
|
| 186 |
+
help="Learning rate for GRPO training.",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--weight-decay",
|
| 190 |
+
type=float,
|
| 191 |
+
default=0.0,
|
| 192 |
+
help="Weight decay applied during optimization.",
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--gradient-accumulation-steps",
|
| 196 |
+
type=int,
|
| 197 |
+
default=64,
|
| 198 |
+
help="Gradient accumulation steps for GRPO training.",
|
| 199 |
+
)
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--warmup-steps",
|
| 202 |
+
type=int,
|
| 203 |
+
default=10,
|
| 204 |
+
help="Warmup steps for the scheduler.",
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--per-device-batch-size",
|
| 208 |
+
type=int,
|
| 209 |
+
default=1,
|
| 210 |
+
help="Per-device train batch size.",
|
| 211 |
+
)
|
| 212 |
+
parser.add_argument(
|
| 213 |
+
"--num-generations",
|
| 214 |
+
type=int,
|
| 215 |
+
default=4,
|
| 216 |
+
help="Number of rollout generations per dataset prompt.",
|
| 217 |
+
)
|
| 218 |
+
parser.add_argument(
|
| 219 |
+
"--num-epochs",
|
| 220 |
+
type=int,
|
| 221 |
+
default=1,
|
| 222 |
+
help="Number of training epochs.",
|
| 223 |
+
)
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"--save-interval",
|
| 226 |
+
type=int,
|
| 227 |
+
default=10,
|
| 228 |
+
help="Interval (in steps) between checkpoint saves.",
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"--save-total-limit",
|
| 232 |
+
type=int,
|
| 233 |
+
default=None,
|
| 234 |
+
help="Maximum number of checkpoints to keep.",
|
| 235 |
+
)
|
| 236 |
+
parser.add_argument(
|
| 237 |
+
"--output-dir",
|
| 238 |
+
default=None,
|
| 239 |
+
help="Directory where training outputs and checkpoints are stored.",
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument(
|
| 242 |
+
"--run-name",
|
| 243 |
+
default=None,
|
| 244 |
+
help="Optional run name for logging systems.",
|
| 245 |
+
)
|
| 246 |
+
parser.add_argument(
|
| 247 |
+
"--project",
|
| 248 |
+
default=None,
|
| 249 |
+
help="Optional project identifier for logging systems.",
|
| 250 |
+
)
|
| 251 |
+
parser.add_argument(
|
| 252 |
+
"--trackio-space-id",
|
| 253 |
+
default="Wordle-GRPO",
|
| 254 |
+
help="TrackIO space identifier.",
|
| 255 |
+
)
|
| 256 |
+
parser.add_argument(
|
| 257 |
+
"--vllm-mode",
|
| 258 |
+
choices=("colocate", "server"),
|
| 259 |
+
default="colocate",
|
| 260 |
+
help="vLLM execution mode: 'colocate' or 'server'.",
|
| 261 |
+
)
|
| 262 |
+
parser.add_argument(
|
| 263 |
+
"--vllm-server-url",
|
| 264 |
+
type=str,
|
| 265 |
+
default="http://localhost:8000",
|
| 266 |
+
help="URL for the vLLM server (only used when --vllm-mode=server).",
|
| 267 |
+
)
|
| 268 |
+
parser.add_argument(
|
| 269 |
+
"--logging-steps",
|
| 270 |
+
type=int,
|
| 271 |
+
default=1,
|
| 272 |
+
help="Frequency of logging steps for GRPO training.",
|
| 273 |
+
)
|
| 274 |
+
return parser.parse_args()
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def resolve_system_prompt(path: str) -> str:
|
| 278 |
+
prompt_path = Path(path)
|
| 279 |
+
if not prompt_path.is_file():
|
| 280 |
+
prompt_path = Path(__file__).parent / path
|
| 281 |
+
return prompt_path.read_text()
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def sanitize_name(name: str) -> str:
|
| 285 |
+
return name.replace("/", "-")
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ---------------------------------------------------------------------------
|
| 289 |
+
# Helpers
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def format_history(messages: Iterable[TextArenaMessage]) -> str:
|
| 294 |
+
lines: list[str] = []
|
| 295 |
+
for message in messages:
|
| 296 |
+
tag = message.category or "MESSAGE"
|
| 297 |
+
content = message.content.strip()
|
| 298 |
+
if not content:
|
| 299 |
+
continue
|
| 300 |
+
lines.append(f"[{tag}] {content}")
|
| 301 |
+
return "\n".join(lines)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str:
|
| 305 |
+
history = format_history(messages)
|
| 306 |
+
# Only use messages for conversation history - the prompt is already included as the first message
|
| 307 |
+
history_section = history if history else "[PROMPT] Awaiting first feedback."
|
| 308 |
+
return f"Conversation so far:\n{history_section}\n\nReply with your next guess enclosed in square brackets."
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def rollout_once(
|
| 312 |
+
trainer: GRPOTrainer,
|
| 313 |
+
env: TextArenaEnv,
|
| 314 |
+
tokenizer: AutoTokenizer,
|
| 315 |
+
dataset_prompt: str,
|
| 316 |
+
system_prompt: str,
|
| 317 |
+
max_turns: int,
|
| 318 |
+
max_new_tokens: int = 16,
|
| 319 |
+
) -> dict[str, list]:
|
| 320 |
+
result = env.reset()
|
| 321 |
+
observation = result.observation
|
| 322 |
+
|
| 323 |
+
prompt_ids: list[int] = []
|
| 324 |
+
completion_ids: list[int] = []
|
| 325 |
+
logprobs: list[float] = []
|
| 326 |
+
env_mask: list[int] = [] # 1 for model-generated tokens, 0 for environment tokens
|
| 327 |
+
model_outputs: list[str] = []
|
| 328 |
+
raw_rewards: list[float] = []
|
| 329 |
+
position_scores: list[float] = []
|
| 330 |
+
correct_scores: list[float] = []
|
| 331 |
+
prev_env_output_len: int = 0 # Track length to only add NEW portion each turn
|
| 332 |
+
|
| 333 |
+
accumulated_messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}]
|
| 334 |
+
# Build initial prompt (only once, at the start)
|
| 335 |
+
# The initial env messages are included in the prompt, not completion
|
| 336 |
+
base_prompt = observation.prompt or dataset_prompt
|
| 337 |
+
initial_user_prompt = make_user_prompt(base_prompt, observation.messages)
|
| 338 |
+
# Track initial env output length so we don't add it again
|
| 339 |
+
initial_env_output = format_history(observation.messages) if observation.messages else ""
|
| 340 |
+
prev_env_output_len = len(initial_env_output)
|
| 341 |
+
initial_messages = accumulated_messages + [{"role": "user", "content": initial_user_prompt}]
|
| 342 |
+
initial_prompt_text = tokenizer.apply_chat_template(
|
| 343 |
+
initial_messages,
|
| 344 |
+
add_generation_prompt=True,
|
| 345 |
+
tokenize=False,
|
| 346 |
+
enable_thinking=False,
|
| 347 |
+
)
|
| 348 |
+
# Tokenize initial prompt once - this is the base prompt for the entire episode.
|
| 349 |
+
# GRPO expects one prompt-completion pair per episode, where:
|
| 350 |
+
# - prompt_ids = the initial/base prompt (what the model sees at episode start)
|
| 351 |
+
# - completion_ids = all model responses + env feedback from all turns concatenated
|
| 352 |
+
# Note: The actual prompts used for generation in each turn are longer (include conversation history),
|
| 353 |
+
# but we only count the initial prompt tokens here.
|
| 354 |
+
initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False)
|
| 355 |
+
prompt_ids.extend(initial_prompt_ids)
|
| 356 |
+
|
| 357 |
+
for _turn in range(max_turns):
|
| 358 |
+
if result.done:
|
| 359 |
+
break
|
| 360 |
+
|
| 361 |
+
base_prompt = observation.prompt or dataset_prompt
|
| 362 |
+
user_prompt = make_user_prompt(base_prompt, observation.messages)
|
| 363 |
+
messages = accumulated_messages + [{"role": "user", "content": user_prompt}]
|
| 364 |
+
prompt_text = tokenizer.apply_chat_template(
|
| 365 |
+
messages,
|
| 366 |
+
add_generation_prompt=True,
|
| 367 |
+
tokenize=False,
|
| 368 |
+
enable_thinking=False,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
rollout_outputs = generate_rollout_completions(
|
| 372 |
+
trainer, [prompt_text], generation_overrides={"max_tokens": max_new_tokens}
|
| 373 |
+
)[0]
|
| 374 |
+
# Add model-generated completion tokens and logprobs with newlines for readability
|
| 375 |
+
newline_tokens = tokenizer.encode("\n", add_special_tokens=False)
|
| 376 |
+
completion_ids.extend(newline_tokens) # newline before guess
|
| 377 |
+
logprobs.extend([0.0] * len(newline_tokens))
|
| 378 |
+
env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format
|
| 379 |
+
|
| 380 |
+
completion_ids.extend(rollout_outputs["completion_ids"])
|
| 381 |
+
logprobs.extend(rollout_outputs["logprobs"])
|
| 382 |
+
env_mask.extend([1] * len(rollout_outputs["completion_ids"])) # model-generated tokens
|
| 383 |
+
|
| 384 |
+
completion_ids.extend(newline_tokens) # newline after guess
|
| 385 |
+
logprobs.extend([0.0] * len(newline_tokens))
|
| 386 |
+
env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format
|
| 387 |
+
completion_text = rollout_outputs.get("text") or tokenizer.decode(
|
| 388 |
+
rollout_outputs["completion_ids"], skip_special_tokens=True
|
| 389 |
+
)
|
| 390 |
+
guess = extract_guess(completion_text)
|
| 391 |
+
model_outputs.append(completion_text.strip()) # Store raw model output for format reward
|
| 392 |
+
|
| 393 |
+
result = env.step(TextArenaAction(message=guess))
|
| 394 |
+
|
| 395 |
+
raw_rewards.append(float(result.reward or 0.0))
|
| 396 |
+
observation = result.observation
|
| 397 |
+
correct_score = float(result.reward or 0.0)
|
| 398 |
+
feedback = extract_wordle_feedback(observation)
|
| 399 |
+
|
| 400 |
+
full_env_output = format_history(observation.messages) if observation.messages else ""
|
| 401 |
+
new_env_output = full_env_output[prev_env_output_len:].lstrip("\n")
|
| 402 |
+
prev_env_output_len = len(full_env_output)
|
| 403 |
+
|
| 404 |
+
if new_env_output:
|
| 405 |
+
env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False)
|
| 406 |
+
completion_ids.extend(env_output_tokens) # Add to completion_ids
|
| 407 |
+
logprobs.extend([0.0] * len(env_output_tokens)) # Placeholder (ignored via env_mask=0)
|
| 408 |
+
env_mask.extend([0] * len(env_output_tokens)) # Environment tokens - mask out from loss
|
| 409 |
+
completion_with_env = completion_text + "\n" + new_env_output
|
| 410 |
+
else:
|
| 411 |
+
completion_with_env = completion_text
|
| 412 |
+
|
| 413 |
+
accumulated_messages.append({"role": "user", "content": user_prompt})
|
| 414 |
+
accumulated_messages.append({"role": "assistant", "content": completion_with_env})
|
| 415 |
+
|
| 416 |
+
if not feedback:
|
| 417 |
+
position_score = 0.0
|
| 418 |
+
else:
|
| 419 |
+
green_count, yellow_count = extract_feedback_counts(feedback)
|
| 420 |
+
position_score = (green_count + 0.5 * yellow_count) / 5.0
|
| 421 |
+
|
| 422 |
+
position_scores.append(position_score)
|
| 423 |
+
correct_scores.append(correct_score)
|
| 424 |
+
|
| 425 |
+
# Use the final correct reward (win/lose is binary at end)
|
| 426 |
+
correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)
|
| 427 |
+
|
| 428 |
+
# Position reward as shaping signal:
|
| 429 |
+
# - If model WINS: position_reward = 1.0 (no penalty for winning fast)
|
| 430 |
+
# - If model LOSES: position_reward = last attempt (where it ended up)
|
| 431 |
+
if correct_reward_value >= 1.0:
|
| 432 |
+
final_position_reward = 1.0
|
| 433 |
+
else:
|
| 434 |
+
final_position_reward = position_scores[-1] if position_scores else 0.0
|
| 435 |
+
|
| 436 |
+
return {
|
| 437 |
+
"prompt_ids": prompt_ids,
|
| 438 |
+
"completion_ids": completion_ids,
|
| 439 |
+
"logprobs": logprobs,
|
| 440 |
+
"env_mask": env_mask,
|
| 441 |
+
"raw_rewards": raw_rewards,
|
| 442 |
+
"correct_reward": correct_reward_value,
|
| 443 |
+
"position_reward": final_position_reward,
|
| 444 |
+
"model_outputs": model_outputs,
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
# ---------------------------------------------------------------------------
|
| 449 |
+
# Rewards
|
| 450 |
+
# ---------------------------------------------------------------------------
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def reward_correct(completions: list[str], **kwargs) -> list[float]:
|
| 454 |
+
"""Reward from environment (correct answer)."""
|
| 455 |
+
rewards = kwargs.get("correct_reward") if kwargs else None
|
| 456 |
+
if rewards is None:
|
| 457 |
+
return [0.0 for _ in completions]
|
| 458 |
+
return [float(r) for r in rewards]
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def reward_position(completions: list[str], **kwargs) -> list[float]:
|
| 462 |
+
"""Position reward: green worth 1.0, yellow worth 0.5, normalized by 5."""
|
| 463 |
+
rewards = kwargs.get("position_reward") if kwargs else None
|
| 464 |
+
if rewards is None:
|
| 465 |
+
return [0.0 for _ in completions]
|
| 466 |
+
return [float(r) for r in rewards]
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def compute_format_reward(model_outputs: list[str]) -> float:
|
| 470 |
+
"""Compute format reward from a list of model outputs (one per turn).
|
| 471 |
+
|
| 472 |
+
Each output should be exactly [5 letters] with optional whitespace.
|
| 473 |
+
Returns proportion of correctly formatted outputs.
|
| 474 |
+
"""
|
| 475 |
+
if not model_outputs:
|
| 476 |
+
return 0.0
|
| 477 |
+
|
| 478 |
+
exact_pattern = re.compile(r"^\s*\[[A-Za-z]{5}\]\s*$")
|
| 479 |
+
correct_count = sum(1 for output in model_outputs if exact_pattern.match(output))
|
| 480 |
+
|
| 481 |
+
return correct_count / len(model_outputs)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def reward_format_strict(completions: list[str], **kwargs) -> list[float]:
|
| 485 |
+
"""Format reward - pre-computed in rollout_func."""
|
| 486 |
+
rewards = kwargs.get("format_reward") if kwargs else None
|
| 487 |
+
if rewards is None:
|
| 488 |
+
return [0.0 for _ in completions]
|
| 489 |
+
return [float(r) for r in rewards]
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
# ---------------------------------------------------------------------------
|
| 493 |
+
# Main entrypoint
|
| 494 |
+
# ---------------------------------------------------------------------------
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def main() -> None:
|
| 498 |
+
args = parse_args()
|
| 499 |
+
|
| 500 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
| 501 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 502 |
+
|
| 503 |
+
client = TextArenaEnv(base_url=args.env_url)
|
| 504 |
+
|
| 505 |
+
system_prompt = resolve_system_prompt(args.system_prompt_path)
|
| 506 |
+
|
| 507 |
+
dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})
|
| 508 |
+
|
| 509 |
+
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 510 |
+
default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
|
| 511 |
+
output_dir = Path(args.output_dir or default_output_dir)
|
| 512 |
+
|
| 513 |
+
grpo_config = GRPOConfig(
|
| 514 |
+
use_vllm=True,
|
| 515 |
+
vllm_mode=args.vllm_mode,
|
| 516 |
+
vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
|
| 517 |
+
output_dir=str(output_dir),
|
| 518 |
+
num_train_epochs=args.num_epochs,
|
| 519 |
+
learning_rate=args.learning_rate,
|
| 520 |
+
weight_decay=args.weight_decay,
|
| 521 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 522 |
+
per_device_train_batch_size=args.per_device_batch_size,
|
| 523 |
+
warmup_steps=args.warmup_steps,
|
| 524 |
+
num_generations=args.num_generations,
|
| 525 |
+
max_completion_length=1024, # Full episode length, not per-turn
|
| 526 |
+
logging_steps=args.logging_steps,
|
| 527 |
+
log_completions=True,
|
| 528 |
+
report_to="trackio",
|
| 529 |
+
trackio_space_id=f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}",
|
| 530 |
+
save_strategy="steps",
|
| 531 |
+
save_steps=args.save_interval,
|
| 532 |
+
save_total_limit=args.save_total_limit,
|
| 533 |
+
temperature=args.temperature,
|
| 534 |
+
top_k=args.top_k,
|
| 535 |
+
top_p=args.top_p,
|
| 536 |
+
vllm_gpu_memory_utilization=0.25,
|
| 537 |
+
vllm_max_model_length=8192,
|
| 538 |
+
vllm_importance_sampling_mode="token_truncate", # Less aggressive than default sequence_mask
|
| 539 |
+
optim="adamw_torch",
|
| 540 |
+
max_grad_norm=1.0, # Clip gradients to prevent explosion
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
grpo_config.run_name = args.run_name or f"run-{timestamp}"
|
| 544 |
+
grpo_config.project = args.project or f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
|
| 545 |
+
grpo_config.trackio_space_id = args.trackio_space_id
|
| 546 |
+
|
| 547 |
+
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
|
| 548 |
+
episode_prompt_ids: list[list[int]] = []
|
| 549 |
+
episode_completion_ids: list[list[int]] = []
|
| 550 |
+
episode_logprobs: list[list[float]] = []
|
| 551 |
+
episode_env_masks: list[list[int]] = []
|
| 552 |
+
correctness_rewards: list[float] = []
|
| 553 |
+
position_rewards: list[float] = []
|
| 554 |
+
format_rewards: list[float] = []
|
| 555 |
+
|
| 556 |
+
for prompt_text in prompts:
|
| 557 |
+
episode = rollout_once(
|
| 558 |
+
trainer=trainer,
|
| 559 |
+
env=client,
|
| 560 |
+
tokenizer=tokenizer,
|
| 561 |
+
dataset_prompt=prompt_text,
|
| 562 |
+
system_prompt=system_prompt,
|
| 563 |
+
max_turns=args.max_turns,
|
| 564 |
+
max_new_tokens=args.max_new_tokens,
|
| 565 |
+
)
|
| 566 |
+
episode_prompt_ids.append(episode["prompt_ids"])
|
| 567 |
+
episode_completion_ids.append(episode["completion_ids"])
|
| 568 |
+
episode_logprobs.append(episode["logprobs"])
|
| 569 |
+
episode_env_masks.append(episode["env_mask"])
|
| 570 |
+
correctness_rewards.append(episode["correct_reward"])
|
| 571 |
+
position_rewards.append(episode["position_reward"])
|
| 572 |
+
format_rewards.append(compute_format_reward(episode["model_outputs"]))
|
| 573 |
+
|
| 574 |
+
return {
|
| 575 |
+
"prompt_ids": episode_prompt_ids,
|
| 576 |
+
"completion_ids": episode_completion_ids,
|
| 577 |
+
"logprobs": episode_logprobs,
|
| 578 |
+
"env_mask": episode_env_masks,
|
| 579 |
+
"correct_reward": correctness_rewards,
|
| 580 |
+
"position_reward": position_rewards,
|
| 581 |
+
"format_reward": format_rewards,
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
trainer = GRPOTrainer(
|
| 585 |
+
model=args.model_id,
|
| 586 |
+
processing_class=tokenizer,
|
| 587 |
+
reward_funcs=[
|
| 588 |
+
reward_correct,
|
| 589 |
+
reward_position,
|
| 590 |
+
reward_format_strict,
|
| 591 |
+
],
|
| 592 |
+
train_dataset=dataset,
|
| 593 |
+
args=grpo_config,
|
| 594 |
+
rollout_func=rollout_func,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
print("Starting GRPO training with Wordle environment...")
|
| 598 |
+
print(f"Using {args.num_generations} rollouts per dataset prompt")
|
| 599 |
+
|
| 600 |
+
try:
|
| 601 |
+
trainer.train()
|
| 602 |
+
finally:
|
| 603 |
+
client.close()
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
if __name__ == "__main__":
|
| 607 |
+
main()
|
ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies.
|
| 2 |
+
|
| 3 |
+
## GAME RULES
|
| 4 |
+
|
| 5 |
+
1. The target is a 5-letter English word
|
| 6 |
+
2. You have 6 attempts to guess the correct word
|
| 7 |
+
3. After each guess, you receive color-coded feedback:
|
| 8 |
+
- GREEN: Letter is correct and in the correct position
|
| 9 |
+
- YELLOW: Letter is in the word but in the wrong position
|
| 10 |
+
- GRAY: Letter is not in the word at all
|
| 11 |
+
4. All guesses must be valid 5-letter English words
|
| 12 |
+
5. You cannot reuse a word you've already guessed
|
| 13 |
+
|
| 14 |
+
## RESPONSE FORMAT
|
| 15 |
+
|
| 16 |
+
Only respond with your next guess in square brackets, e.g., [crane].
|
| 17 |
+
|
| 18 |
+
Format:
|
| 19 |
+
```
|
| 20 |
+
[guess]
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## STRATEGIC APPROACH
|
| 25 |
+
|
| 26 |
+
Do not repeat the same guess twice.
|
| 27 |
+
|
| 28 |
+
### Opening Strategy
|
| 29 |
+
- Start with words rich in common vowels (A, E, I, O, U) and consonants (R, S, T, L, N)
|
| 30 |
+
- Optimal starters: CRANE, SLATE, STARE, AROSE, IRATE
|
| 31 |
+
- Prioritize words that test the most common letters in different positions
|
| 32 |
+
|
| 33 |
+
### Mid-Game Strategy
|
| 34 |
+
- Use confirmed GREEN letters in their correct positions
|
| 35 |
+
- Place YELLOW letters in different positions than where they appeared
|
| 36 |
+
- Eliminate GRAY letters entirely from consideration
|
| 37 |
+
- If multiple letters are unknown, prioritize common letter combinations (TH, CH, ST, ER, etc.)
|
| 38 |
+
- Consider letter frequency: E is most common, followed by A, R, I, O, T, N, S
|
| 39 |
+
|
| 40 |
+
### Vowel Placement
|
| 41 |
+
- Most 5-letter words have 2 vowels
|
| 42 |
+
- Common patterns: vowel-consonant-vowel (like CRANE) or consonant-vowel-vowel-consonant-vowel (like QUEUE)
|
| 43 |
+
- If you have 1-2 vowels confirmed, consider where the others might be
|
| 44 |
+
|
| 45 |
+
### Advanced Tactics
|
| 46 |
+
- Use "sacrificial" guesses to test multiple new letters if you have attempts to spare
|
| 47 |
+
- Avoid repeating letter patterns unless you're certain (e.g., SPEED has two E's)
|
| 48 |
+
- Think about word endings: -ER, -LY, -ED, -ING are common but may not fit the 5-letter constraint
|
| 49 |
+
- Consider less common letters (Q, X, Z, J) only when you've eliminated most common options
|
| 50 |
+
|
| 51 |
+
### Common Pitfalls to Avoid
|
| 52 |
+
- Don't reuse X letters
|
| 53 |
+
- Don't place Y letters in the same position they appeared
|
| 54 |
+
- Don't ignore confirmed G letters
|
| 55 |
+
- Don't guess words that contradict known information
|
| 56 |
+
|
| 57 |
+
## EXAMPLES
|
| 58 |
+
|
| 59 |
+
### Example 1: Opening Guess
|
| 60 |
+
"Starting with a word that tests common vowels and consonants in varied positions."
|
| 61 |
+
[crane]
|
| 62 |
+
|
| 63 |
+
### Example 2: After Receiving Feedback
|
| 64 |
+
Previous guess: CRANE
|
| 65 |
+
Feedback: C=gray, R=yellow, A=green, N=gray, E=yellow
|
| 66 |
+
|
| 67 |
+
"A is confirmed in position 2. R and E are in the word but need different positions. C and N are eliminated. I'll try a word with A in position 2, and test R and E in new positions along with common letters like S and T."
|
| 68 |
+
[spare]
|
| 69 |
+
|
| 70 |
+
### Example 3: Narrowing Down
|
| 71 |
+
Previous guesses: CRANE (C=gray, R=yellow, A=green, N=gray, E=yellow), SPARE (S=gray, P=gray, A=green, R=green, E=green)
|
| 72 |
+
Feedback summary: _ARE_ with R in position 4, A in position 2, E in position 5
|
| 73 |
+
|
| 74 |
+
"I have _AR E_ confirmed. Position 1 and 3 are unknown. Common letters to try: T, L, D, B, F, G. Testing with TARED."
|
| 75 |
+
[tared]
|
| 76 |
+
|
| 77 |
+
### Example 4: Final Deduction
|
| 78 |
+
Previous feedback shows: _ARED with position 1 unknown and all common consonants tested
|
| 79 |
+
|
| 80 |
+
"Only position 1 remains. I've eliminated S, P, C, N. Common starting consonants left are B, F, G, H. BARED is a common word."
|
| 81 |
+
[bared]
|
| 82 |
+
|
| 83 |
+
## LETTER FREQUENCY REFERENCE
|
| 84 |
+
|
| 85 |
+
Most common letters in 5-letter words (in order):
|
| 86 |
+
S, E, A, O, R, I, L, T, N, U, D, Y, C, P, M, H, G, B, K, F
|
| 87 |
+
|
| 88 |
+
Most common starting letters:
|
| 89 |
+
S, C, B, T, P, A, F, G, D, M
|
| 90 |
+
|
| 91 |
+
Most common ending letters:
|
| 92 |
+
E, Y, T, S, R, L, N, D
|
| 93 |
+
|
| 94 |
+
## IMPORTANT CONSTRAINTS
|
| 95 |
+
|
| 96 |
+
- Use lowercase only
|
| 97 |
+
- One guess per response
|
| 98 |
+
- Must be exactly 5 letters
|
| 99 |
+
- Must be a real English word from standard dictionaries
|
| 100 |
+
- Never repeat a previous guess
|
| 101 |
+
- Always include brief reasoning before your guess
|
| 102 |
+
|
| 103 |
+
## YOUR GOAL
|
| 104 |
+
|
| 105 |
+
Solve the Wordle in as few guesses as possible by strategically using feedback to eliminate impossible words and narrow down the solution space efficiently.
|
ICL/RL/trl_source/examples/scripts/ppo/ppo.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl",
|
| 18 |
+
# "peft",
|
| 19 |
+
# "trackio",
|
| 20 |
+
# "kernels",
|
| 21 |
+
# ]
|
| 22 |
+
# ///
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import shutil
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from accelerate import PartialState
|
| 29 |
+
from datasets import load_dataset
|
| 30 |
+
from transformers import (
|
| 31 |
+
AutoModelForCausalLM,
|
| 32 |
+
AutoModelForSequenceClassification,
|
| 33 |
+
AutoTokenizer,
|
| 34 |
+
HfArgumentParser,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
|
| 38 |
+
from trl.experimental.ppo import PPOConfig, PPOTrainer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Enable logging in a Hugging Face Space
|
| 42 |
+
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
"""
|
| 46 |
+
python -i examples/scripts/ppo/ppo.py \
|
| 47 |
+
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
| 48 |
+
--dataset_train_split descriptiveness \
|
| 49 |
+
--learning_rate 3e-6 \
|
| 50 |
+
--output_dir pythia-1b-deduped-descriptiveness-sentiment-trl-style-ppo \
|
| 51 |
+
--per_device_train_batch_size 64 \
|
| 52 |
+
--gradient_accumulation_steps 1 \
|
| 53 |
+
--total_episodes 10000 \
|
| 54 |
+
--model_name_or_path EleutherAI/pythia-1b-deduped \
|
| 55 |
+
--missing_eos_penalty 1.0
|
| 56 |
+
|
| 57 |
+
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
| 58 |
+
examples/scripts/ppo/ppo.py \
|
| 59 |
+
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
|
| 60 |
+
--dataset_train_split descriptiveness \
|
| 61 |
+
--output_dir pythia-1b-deduped-descriptiveness-sentiment-trl-style-ppo \
|
| 62 |
+
--num_ppo_epochs 1 \
|
| 63 |
+
--num_mini_batches 1 \
|
| 64 |
+
--learning_rate 3e-6 \
|
| 65 |
+
--per_device_train_batch_size 1 \
|
| 66 |
+
--gradient_accumulation_steps 16 \
|
| 67 |
+
--total_episodes 10000 \
|
| 68 |
+
--model_name_or_path EleutherAI/pythia-1b-deduped \
|
| 69 |
+
--sft_model_path EleutherAI/pythia-1b-deduped \
|
| 70 |
+
--reward_model_path EleutherAI/pythia-1b-deduped \
|
| 71 |
+
--local_rollout_forward_batch_size 1 \
|
| 72 |
+
--missing_eos_penalty 1.0
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
|
| 78 |
+
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
|
| 79 |
+
# remove output_dir if exists
|
| 80 |
+
shutil.rmtree(training_args.output_dir, ignore_errors=True)
|
| 81 |
+
|
| 82 |
+
################
|
| 83 |
+
# Model & Tokenizer
|
| 84 |
+
################
|
| 85 |
+
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
| 86 |
+
model_kwargs = dict(
|
| 87 |
+
revision=model_args.model_revision,
|
| 88 |
+
attn_implementation=model_args.attn_implementation,
|
| 89 |
+
dtype=dtype,
|
| 90 |
+
)
|
| 91 |
+
quantization_config = get_quantization_config(model_args)
|
| 92 |
+
if quantization_config is not None:
|
| 93 |
+
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
| 94 |
+
model_kwargs["device_map"] = get_kbit_device_map()
|
| 95 |
+
model_kwargs["quantization_config"] = quantization_config
|
| 96 |
+
|
| 97 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 98 |
+
model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
|
| 99 |
+
)
|
| 100 |
+
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
| 101 |
+
value_model = AutoModelForSequenceClassification.from_pretrained(
|
| 102 |
+
training_args.reward_model_path,
|
| 103 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 104 |
+
num_labels=1,
|
| 105 |
+
**model_kwargs,
|
| 106 |
+
)
|
| 107 |
+
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
| 108 |
+
training_args.reward_model_path,
|
| 109 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 110 |
+
num_labels=1,
|
| 111 |
+
**model_kwargs,
|
| 112 |
+
)
|
| 113 |
+
policy = AutoModelForCausalLM.from_pretrained(
|
| 114 |
+
training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
peft_config = get_peft_config(model_args)
|
| 118 |
+
if peft_config is None:
|
| 119 |
+
ref_policy = AutoModelForCausalLM.from_pretrained(
|
| 120 |
+
training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
ref_policy = None
|
| 124 |
+
|
| 125 |
+
################
|
| 126 |
+
# Dataset
|
| 127 |
+
################
|
| 128 |
+
dataset = load_dataset(
|
| 129 |
+
script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split
|
| 130 |
+
)
|
| 131 |
+
eval_samples = 100
|
| 132 |
+
train_dataset = dataset.select(range(len(dataset) - eval_samples))
|
| 133 |
+
eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
|
| 134 |
+
dataset_text_field = "prompt"
|
| 135 |
+
|
| 136 |
+
def prepare_dataset(dataset, tokenizer):
|
| 137 |
+
"""pre-tokenize the dataset before training; only collate during training"""
|
| 138 |
+
|
| 139 |
+
def tokenize(element):
|
| 140 |
+
outputs = tokenizer(
|
| 141 |
+
element[dataset_text_field],
|
| 142 |
+
padding=False,
|
| 143 |
+
)
|
| 144 |
+
return {"input_ids": outputs["input_ids"]}
|
| 145 |
+
|
| 146 |
+
return dataset.map(
|
| 147 |
+
tokenize,
|
| 148 |
+
batched=True,
|
| 149 |
+
remove_columns=dataset.column_names,
|
| 150 |
+
num_proc=training_args.dataset_num_proc,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Compute that only on the main process for faster data processing.
|
| 154 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
| 155 |
+
with PartialState().local_main_process_first():
|
| 156 |
+
train_dataset = prepare_dataset(train_dataset, tokenizer)
|
| 157 |
+
eval_dataset = prepare_dataset(eval_dataset, tokenizer)
|
| 158 |
+
|
| 159 |
+
################
|
| 160 |
+
# Training
|
| 161 |
+
################
|
| 162 |
+
trainer = PPOTrainer(
|
| 163 |
+
args=training_args,
|
| 164 |
+
processing_class=tokenizer,
|
| 165 |
+
model=policy,
|
| 166 |
+
ref_model=ref_policy,
|
| 167 |
+
reward_model=reward_model,
|
| 168 |
+
value_model=value_model,
|
| 169 |
+
train_dataset=train_dataset,
|
| 170 |
+
eval_dataset=eval_dataset,
|
| 171 |
+
peft_config=peft_config,
|
| 172 |
+
)
|
| 173 |
+
trainer.train()
|
| 174 |
+
|
| 175 |
+
# Save and push to hub
|
| 176 |
+
trainer.save_model(training_args.output_dir)
|
| 177 |
+
if training_args.push_to_hub:
|
| 178 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 179 |
+
|
| 180 |
+
trainer.generate_completions()
|
ICL/RL/trl_source/examples/scripts/reward_modeling.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl",
|
| 18 |
+
# "trackio",
|
| 19 |
+
# "kernels",
|
| 20 |
+
# ]
|
| 21 |
+
# ///
|
| 22 |
+
|
| 23 |
+
"""
|
| 24 |
+
Full training:
|
| 25 |
+
python examples/scripts/reward_modeling.py \
|
| 26 |
+
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
| 27 |
+
--dataset_name trl-lib/ultrafeedback_binarized \
|
| 28 |
+
--output_dir Qwen2-0.5B-Reward \
|
| 29 |
+
--per_device_train_batch_size 8 \
|
| 30 |
+
--num_train_epochs 1 \
|
| 31 |
+
--learning_rate 1.0e-5 \
|
| 32 |
+
--eval_strategy steps \
|
| 33 |
+
--eval_steps 50 \
|
| 34 |
+
--max_length 2048
|
| 35 |
+
|
| 36 |
+
LoRA:
|
| 37 |
+
python examples/scripts/reward_modeling.py \
|
| 38 |
+
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
| 39 |
+
--dataset_name trl-lib/ultrafeedback_binarized \
|
| 40 |
+
--output_dir Qwen2-0.5B-Reward-LoRA \
|
| 41 |
+
--per_device_train_batch_size 8 \
|
| 42 |
+
--num_train_epochs 1 \
|
| 43 |
+
--learning_rate 1.0e-4 \
|
| 44 |
+
--eval_strategy steps \
|
| 45 |
+
--eval_steps 50 \
|
| 46 |
+
--max_length 2048 \
|
| 47 |
+
--use_peft \
|
| 48 |
+
--lora_task_type SEQ_CLS \
|
| 49 |
+
--lora_r 32 \
|
| 50 |
+
--lora_alpha 16
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
import os
|
| 54 |
+
|
| 55 |
+
import torch
|
| 56 |
+
from accelerate import logging
|
| 57 |
+
from datasets import load_dataset
|
| 58 |
+
from transformers import AutoModelForSequenceClassification, HfArgumentParser
|
| 59 |
+
|
| 60 |
+
from trl import (
|
| 61 |
+
ModelConfig,
|
| 62 |
+
RewardConfig,
|
| 63 |
+
RewardTrainer,
|
| 64 |
+
ScriptArguments,
|
| 65 |
+
get_kbit_device_map,
|
| 66 |
+
get_peft_config,
|
| 67 |
+
get_quantization_config,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
logger = logging.get_logger(__name__)
|
| 72 |
+
|
| 73 |
+
# Enable logging in a Hugging Face Space
|
| 74 |
+
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig))
|
| 79 |
+
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
|
| 80 |
+
|
| 81 |
+
################
|
| 82 |
+
# Model & Tokenizer
|
| 83 |
+
################
|
| 84 |
+
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
| 85 |
+
model_kwargs = dict(
|
| 86 |
+
revision=model_args.model_revision,
|
| 87 |
+
use_cache=False if training_args.gradient_checkpointing else True,
|
| 88 |
+
dtype=dtype,
|
| 89 |
+
)
|
| 90 |
+
quantization_config = get_quantization_config(model_args)
|
| 91 |
+
if quantization_config is not None:
|
| 92 |
+
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
| 93 |
+
model_kwargs["device_map"] = get_kbit_device_map()
|
| 94 |
+
model_kwargs["quantization_config"] = quantization_config
|
| 95 |
+
|
| 96 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 97 |
+
model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS":
|
| 101 |
+
logger.warning(
|
| 102 |
+
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
|
| 103 |
+
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
##############
|
| 107 |
+
# Load dataset
|
| 108 |
+
##############
|
| 109 |
+
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
| 110 |
+
|
| 111 |
+
##########
|
| 112 |
+
# Training
|
| 113 |
+
##########
|
| 114 |
+
trainer = RewardTrainer(
|
| 115 |
+
model=model,
|
| 116 |
+
args=training_args,
|
| 117 |
+
train_dataset=dataset[script_args.dataset_train_split],
|
| 118 |
+
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
| 119 |
+
peft_config=get_peft_config(model_args),
|
| 120 |
+
)
|
| 121 |
+
trainer.train()
|
| 122 |
+
|
| 123 |
+
############################
|
| 124 |
+
# Save model and push to Hub
|
| 125 |
+
############################
|
| 126 |
+
trainer.save_model(training_args.output_dir)
|
| 127 |
+
|
| 128 |
+
if training_args.eval_strategy != "no":
|
| 129 |
+
metrics = trainer.evaluate()
|
| 130 |
+
trainer.log_metrics("eval", metrics)
|
| 131 |
+
trainer.save_metrics("eval", metrics)
|
| 132 |
+
|
| 133 |
+
# Save and push to hub
|
| 134 |
+
trainer.save_model(training_args.output_dir)
|
| 135 |
+
if training_args.push_to_hub:
|
| 136 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl",
|
| 18 |
+
# "Pillow>=9.4.0",
|
| 19 |
+
# "peft",
|
| 20 |
+
# "trackio",
|
| 21 |
+
# "kernels",
|
| 22 |
+
# ]
|
| 23 |
+
# ///
|
| 24 |
+
|
| 25 |
+
"""
|
| 26 |
+
Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
|
| 27 |
+
|
| 28 |
+
accelerate launch \
|
| 29 |
+
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
| 30 |
+
examples/scripts/sft_vlm_gemma3.py \
|
| 31 |
+
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
| 32 |
+
--model_name_or_path google/gemma-3-4b-it \
|
| 33 |
+
--per_device_train_batch_size 1 \
|
| 34 |
+
--output_dir Gemma-3-4B-SFT-MMIU \
|
| 35 |
+
--dtype bfloat16 \
|
| 36 |
+
--use_peft \
|
| 37 |
+
--lora_target_modules all-linear \
|
| 38 |
+
--attn_implementation eager
|
| 39 |
+
|
| 40 |
+
Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
|
| 41 |
+
|
| 42 |
+
accelerate launch \
|
| 43 |
+
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
| 44 |
+
examples/scripts/sft_vlm_gemma3.py \
|
| 45 |
+
--dataset_name FanqingM/MMIU-Benchmark \
|
| 46 |
+
--dataset_train_split test \
|
| 47 |
+
--model_name_or_path google/gemma-3-4b-it \
|
| 48 |
+
--per_device_train_batch_size 1 \
|
| 49 |
+
--output_dir Gemma-3-4B-SFT-MMIU \
|
| 50 |
+
--dtype bfloat16 \
|
| 51 |
+
--use_peft \
|
| 52 |
+
--lora_target_modules all-linear \
|
| 53 |
+
--attn_implementation eager
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
import io
|
| 57 |
+
import os
|
| 58 |
+
import zipfile
|
| 59 |
+
|
| 60 |
+
import torch
|
| 61 |
+
from datasets import DatasetDict, load_dataset
|
| 62 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
| 63 |
+
from PIL import Image
|
| 64 |
+
from transformers import AutoModelForImageTextToText
|
| 65 |
+
|
| 66 |
+
from trl import (
|
| 67 |
+
ModelConfig,
|
| 68 |
+
ScriptArguments,
|
| 69 |
+
SFTConfig,
|
| 70 |
+
SFTTrainer,
|
| 71 |
+
TrlParser,
|
| 72 |
+
get_kbit_device_map,
|
| 73 |
+
get_peft_config,
|
| 74 |
+
get_quantization_config,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Enable logging in a Hugging Face Space
|
| 79 |
+
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# For multi-image example
|
| 83 |
+
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
| 84 |
+
image_inputs = []
|
| 85 |
+
for msg in messages:
|
| 86 |
+
content = msg.get("content", [])
|
| 87 |
+
if not isinstance(content, list):
|
| 88 |
+
content = [content]
|
| 89 |
+
|
| 90 |
+
for element in content:
|
| 91 |
+
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
|
| 92 |
+
if "image" in element:
|
| 93 |
+
image = element["image"]
|
| 94 |
+
else:
|
| 95 |
+
image = element
|
| 96 |
+
if image is not None:
|
| 97 |
+
image = Image.open(io.BytesIO(image["bytes"]))
|
| 98 |
+
image_inputs.append(image.convert("RGB"))
|
| 99 |
+
return image_inputs
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def format_data(samples: dict[str, any]) -> dict[str, list]:
|
| 103 |
+
formatted_samples = {"messages": []}
|
| 104 |
+
for cont in range(len(samples["question"])):
|
| 105 |
+
images = []
|
| 106 |
+
for img_path in samples["input_image_path"][cont]:
|
| 107 |
+
try:
|
| 108 |
+
with open(img_path, "rb") as f:
|
| 109 |
+
img_bytes = f.read()
|
| 110 |
+
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
| 111 |
+
images.append({"type": "image", "image": image})
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(f"Error processing image {img_path}: {e}")
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
formatted_samples["messages"].append(
|
| 117 |
+
[
|
| 118 |
+
{"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
|
| 119 |
+
{"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
|
| 120 |
+
{"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
|
| 121 |
+
]
|
| 122 |
+
)
|
| 123 |
+
return formatted_samples
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# For multi-image example
|
| 127 |
+
def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict:
|
| 128 |
+
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
| 129 |
+
zip_files = [f for f in all_files if f.endswith(".zip")]
|
| 130 |
+
|
| 131 |
+
for zip_filename in zip_files:
|
| 132 |
+
zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
|
| 133 |
+
extract_folder = zip_filename.replace(".zip", "")
|
| 134 |
+
os.makedirs(extract_folder, exist_ok=True)
|
| 135 |
+
|
| 136 |
+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
| 137 |
+
zip_ref.extractall(extract_folder)
|
| 138 |
+
|
| 139 |
+
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
| 140 |
+
return dataset
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def main():
|
| 144 |
+
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
| 145 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 146 |
+
training_args.max_length = None
|
| 147 |
+
|
| 148 |
+
################
|
| 149 |
+
# Model
|
| 150 |
+
################
|
| 151 |
+
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
|
| 152 |
+
model_kwargs = dict(
|
| 153 |
+
revision=model_args.model_revision,
|
| 154 |
+
attn_implementation=model_args.attn_implementation,
|
| 155 |
+
dtype=dtype,
|
| 156 |
+
)
|
| 157 |
+
quantization_config = get_quantization_config(model_args)
|
| 158 |
+
if quantization_config is not None:
|
| 159 |
+
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
|
| 160 |
+
model_kwargs["device_map"] = get_kbit_device_map()
|
| 161 |
+
model_kwargs["quantization_config"] = quantization_config
|
| 162 |
+
|
| 163 |
+
model = AutoModelForImageTextToText.from_pretrained(
|
| 164 |
+
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
################
|
| 168 |
+
# Dataset
|
| 169 |
+
################
|
| 170 |
+
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
| 171 |
+
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
|
| 172 |
+
dataset = prepare_dataset(dataset, script_args.dataset_name)
|
| 173 |
+
|
| 174 |
+
################
|
| 175 |
+
# Training
|
| 176 |
+
################
|
| 177 |
+
trainer = SFTTrainer(
|
| 178 |
+
model=model,
|
| 179 |
+
args=training_args,
|
| 180 |
+
train_dataset=dataset[script_args.dataset_train_split],
|
| 181 |
+
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
| 182 |
+
peft_config=get_peft_config(model_args),
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
trainer.train()
|
| 186 |
+
|
| 187 |
+
# Save and push to hub
|
| 188 |
+
trainer.save_model(training_args.output_dir)
|
| 189 |
+
if training_args.push_to_hub:
|
| 190 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
main()
|
ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (2.3 kB). View file
|
|
|
ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc
ADDED
|
Binary file (43.1 kB). View file
|
|
|
ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc
ADDED
|
Binary file (8.1 kB). View file
|
|
|
ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: FSDP
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
enable_cpu_affinity: false
|
| 6 |
+
fsdp_config:
|
| 7 |
+
fsdp_activation_checkpointing: false
|
| 8 |
+
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
| 9 |
+
fsdp_backward_prefetch: BACKWARD_PRE
|
| 10 |
+
fsdp_cpu_ram_efficient_loading: true
|
| 11 |
+
fsdp_forward_prefetch: true
|
| 12 |
+
fsdp_offload_params: false
|
| 13 |
+
fsdp_reshard_after_forward: FULL_SHARD
|
| 14 |
+
fsdp_state_dict_type: FULL_STATE_DICT
|
| 15 |
+
fsdp_sync_module_states: true
|
| 16 |
+
fsdp_use_orig_params: true
|
| 17 |
+
fsdp_version: 1
|
| 18 |
+
machine_rank: 0
|
| 19 |
+
main_training_function: main
|
| 20 |
+
mixed_precision: bf16
|
| 21 |
+
num_machines: 1
|
| 22 |
+
num_processes: 8
|
| 23 |
+
rdzv_backend: static
|
| 24 |
+
same_network: true
|
| 25 |
+
tpu_env: []
|
| 26 |
+
tpu_use_cluster: false
|
| 27 |
+
tpu_use_sudo: false
|
| 28 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Requires accelerate 1.7.0 or higher
|
| 2 |
+
compute_environment: LOCAL_MACHINE
|
| 3 |
+
debug: false
|
| 4 |
+
distributed_type: FSDP
|
| 5 |
+
downcast_bf16: 'no'
|
| 6 |
+
enable_cpu_affinity: false
|
| 7 |
+
fsdp_config:
|
| 8 |
+
fsdp_activation_checkpointing: false
|
| 9 |
+
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
| 10 |
+
fsdp_cpu_ram_efficient_loading: true
|
| 11 |
+
fsdp_offload_params: false
|
| 12 |
+
fsdp_reshard_after_forward: true
|
| 13 |
+
fsdp_state_dict_type: FULL_STATE_DICT
|
| 14 |
+
fsdp_version: 2
|
| 15 |
+
machine_rank: 0
|
| 16 |
+
main_training_function: main
|
| 17 |
+
mixed_precision: bf16
|
| 18 |
+
num_machines: 1
|
| 19 |
+
num_processes: 8
|
| 20 |
+
rdzv_backend: static
|
| 21 |
+
same_network: true
|
| 22 |
+
tpu_env: []
|
| 23 |
+
tpu_use_cluster: false
|
| 24 |
+
tpu_use_sudo: false
|
| 25 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: MULTI_GPU
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
gpu_ids: all
|
| 6 |
+
machine_rank: 0
|
| 7 |
+
main_training_function: main
|
| 8 |
+
mixed_precision: 'bf16'
|
| 9 |
+
num_machines: 1
|
| 10 |
+
num_processes: 8
|
| 11 |
+
rdzv_backend: static
|
| 12 |
+
same_network: true
|
| 13 |
+
tpu_env: []
|
| 14 |
+
tpu_use_cluster: false
|
| 15 |
+
tpu_use_sudo: false
|
| 16 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
distributed_type: "NO"
|
| 4 |
+
downcast_bf16: 'no'
|
| 5 |
+
gpu_ids: all
|
| 6 |
+
machine_rank: 0
|
| 7 |
+
main_training_function: main
|
| 8 |
+
mixed_precision: 'bf16'
|
| 9 |
+
num_machines: 1
|
| 10 |
+
num_processes: 8
|
| 11 |
+
rdzv_backend: static
|
| 12 |
+
same_network: true
|
| 13 |
+
tpu_env: []
|
| 14 |
+
tpu_use_cluster: false
|
| 15 |
+
tpu_use_sudo: false
|
| 16 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
gradient_accumulation_steps: 1
|
| 6 |
+
zero3_init_flag: false
|
| 7 |
+
zero_stage: 1
|
| 8 |
+
distributed_type: DEEPSPEED
|
| 9 |
+
downcast_bf16: 'no'
|
| 10 |
+
machine_rank: 0
|
| 11 |
+
main_training_function: main
|
| 12 |
+
mixed_precision: 'bf16'
|
| 13 |
+
num_machines: 1
|
| 14 |
+
num_processes: 8
|
| 15 |
+
rdzv_backend: static
|
| 16 |
+
same_network: true
|
| 17 |
+
tpu_env: []
|
| 18 |
+
tpu_use_cluster: false
|
| 19 |
+
tpu_use_sudo: false
|
| 20 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
offload_optimizer_device: none
|
| 6 |
+
offload_param_device: none
|
| 7 |
+
zero3_init_flag: false
|
| 8 |
+
zero_stage: 2
|
| 9 |
+
distributed_type: DEEPSPEED
|
| 10 |
+
downcast_bf16: 'no'
|
| 11 |
+
machine_rank: 0
|
| 12 |
+
main_training_function: main
|
| 13 |
+
mixed_precision: 'bf16'
|
| 14 |
+
num_machines: 1
|
| 15 |
+
num_processes: 8
|
| 16 |
+
rdzv_backend: static
|
| 17 |
+
same_network: true
|
| 18 |
+
tpu_env: []
|
| 19 |
+
tpu_use_cluster: false
|
| 20 |
+
tpu_use_sudo: false
|
| 21 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compute_environment: LOCAL_MACHINE
|
| 2 |
+
debug: false
|
| 3 |
+
deepspeed_config:
|
| 4 |
+
deepspeed_multinode_launcher: standard
|
| 5 |
+
offload_optimizer_device: none
|
| 6 |
+
offload_param_device: none
|
| 7 |
+
zero3_init_flag: true
|
| 8 |
+
zero3_save_16bit_model: true
|
| 9 |
+
zero_stage: 3
|
| 10 |
+
distributed_type: DEEPSPEED
|
| 11 |
+
downcast_bf16: 'no'
|
| 12 |
+
machine_rank: 0
|
| 13 |
+
main_training_function: main
|
| 14 |
+
mixed_precision: bf16
|
| 15 |
+
num_machines: 1
|
| 16 |
+
num_processes: 8
|
| 17 |
+
rdzv_backend: static
|
| 18 |
+
same_network: true
|
| 19 |
+
tpu_env: []
|
| 20 |
+
tpu_use_cluster: false
|
| 21 |
+
tpu_use_sudo: false
|
| 22 |
+
use_cpu: false
|
ICL/RL/trl_source/trl/experimental/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Experimental submodule for TRL.
|
| 17 |
+
|
| 18 |
+
This submodule contains unstable or incubating features. Anything here may change (or be removed) in any release
|
| 19 |
+
without deprecation. Use at your own risk.
|
| 20 |
+
|
| 21 |
+
To silence this notice set environment variable TRL_EXPERIMENTAL_SILENCE=1.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import warnings
|
| 26 |
+
|
| 27 |
+
from ..import_utils import TRLExperimentalWarning
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
| 31 |
+
warnings.warn(
|
| 32 |
+
"You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without "
|
| 33 |
+
"notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.",
|
| 34 |
+
TRLExperimentalWarning,
|
| 35 |
+
stacklevel=2,
|
| 36 |
+
)
|
ICL/RL/trl_source/trl/experimental/bco/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .bco_config import BCOConfig
|
| 16 |
+
from .bco_trainer import BCOTrainer
|
ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .callback import BEMACallback
|
| 16 |
+
from .dpo_trainer import DPOTrainer
|
ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from ...trainer.dpo_trainer import DPOTrainer as _DPOTrainer
|
| 16 |
+
from .callback import CallbackHandlerWithRefModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DPOTrainer(_DPOTrainer):
|
| 20 |
+
def __init__(self, *args, **kwargs):
|
| 21 |
+
super().__init__(*args, **kwargs)
|
| 22 |
+
# Replace with a new one that calls the events with the reference model
|
| 23 |
+
self.callback_handler = CallbackHandlerWithRefModel(
|
| 24 |
+
self.callback_handler.callbacks,
|
| 25 |
+
self.model,
|
| 26 |
+
self.ref_model,
|
| 27 |
+
self.processing_class,
|
| 28 |
+
self.optimizer,
|
| 29 |
+
self.lr_scheduler,
|
| 30 |
+
)
|
ICL/RL/trl_source/trl/experimental/cpo/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .cpo_config import CPOConfig
|
| 16 |
+
from .cpo_trainer import CPOTrainer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["CPOConfig", "CPOTrainer"]
|
ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from transformers import TrainingArguments
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class CPOConfig(TrainingArguments):
|
| 23 |
+
r"""
|
| 24 |
+
Configuration class for the [`experimental.cpo.CPOTrainer`].
|
| 25 |
+
|
| 26 |
+
This class includes only the parameters that are specific to CPO training. For a full list of training arguments,
|
| 27 |
+
please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
|
| 28 |
+
differ from those in [`~transformers.TrainingArguments`].
|
| 29 |
+
|
| 30 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 31 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 32 |
+
command line.
|
| 33 |
+
|
| 34 |
+
Parameters:
|
| 35 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 36 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 37 |
+
to use the default data collator.
|
| 38 |
+
max_completion_length (`int`, *optional*):
|
| 39 |
+
Maximum length of the completion. This argument is required if you want to use the default data collator
|
| 40 |
+
and your model is an encoder-decoder.
|
| 41 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
| 42 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 43 |
+
reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
|
| 44 |
+
the [paper](https://huggingface.co/papers/2310.12036).
|
| 45 |
+
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
| 46 |
+
Label smoothing factor. This argument is required if you want to use the default data collator.
|
| 47 |
+
loss_type (`str`, *optional*, defaults to `"sigmoid"`):
|
| 48 |
+
Type of loss to use. Possible values are:
|
| 49 |
+
|
| 50 |
+
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
|
| 51 |
+
- `"hinge"`: hinge loss on the normalized likelihood from the
|
| 52 |
+
[SLiC](https://huggingface.co/papers/2305.10425) paper.
|
| 53 |
+
- `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
|
| 54 |
+
- `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
|
| 55 |
+
- `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This
|
| 56 |
+
automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`.
|
| 57 |
+
|
| 58 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
Whether to disable dropout in the model.
|
| 60 |
+
cpo_alpha (`float`, *optional*, defaults to `1.0`):
|
| 61 |
+
Weight of the BC regularizer in CPO training.
|
| 62 |
+
simpo_gamma (`float`, *optional*, defaults to `0.5`):
|
| 63 |
+
Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
|
| 64 |
+
alpha (`float`, *optional*, defaults to `0.0`):
|
| 65 |
+
Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses
|
| 66 |
+
standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha))
|
| 67 |
+
/ alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all
|
| 68 |
+
loss types.
|
| 69 |
+
truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
|
| 70 |
+
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
|
| 71 |
+
This argument is required if you want to use the default data collator.
|
| 72 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 73 |
+
If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
|
| 74 |
+
is_encoder_decoder (`bool`, *optional*):
|
| 75 |
+
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
| 76 |
+
you need to specify if the model returned by the callable is an encoder-decoder model.
|
| 77 |
+
model_init_kwargs (`dict[str, Any]`, *optional*):
|
| 78 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 79 |
+
string.
|
| 80 |
+
dataset_num_proc (`int`, *optional*):
|
| 81 |
+
Number of processes to use for processing the dataset.
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
|
| 85 |
+
|
| 86 |
+
# Parameters whose default values are overridden from TrainingArguments
|
| 87 |
+
learning_rate: float = field(
|
| 88 |
+
default=1e-6,
|
| 89 |
+
metadata={"help": "The initial learning rate for AdamW."},
|
| 90 |
+
)
|
| 91 |
+
logging_steps: float = field(
|
| 92 |
+
default=10,
|
| 93 |
+
metadata={
|
| 94 |
+
"help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
|
| 95 |
+
"will be interpreted as ratio of total training steps."
|
| 96 |
+
},
|
| 97 |
+
)
|
| 98 |
+
gradient_checkpointing: bool = field(
|
| 99 |
+
default=True,
|
| 100 |
+
metadata={
|
| 101 |
+
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
| 102 |
+
},
|
| 103 |
+
)
|
| 104 |
+
bf16: bool | None = field(
|
| 105 |
+
default=None,
|
| 106 |
+
metadata={
|
| 107 |
+
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
| 108 |
+
"architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
|
| 109 |
+
"`fp16` is not set."
|
| 110 |
+
},
|
| 111 |
+
)
|
| 112 |
+
# Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue
|
| 113 |
+
# was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary
|
| 114 |
+
# workaround here, which can be removed once we drop support for versions older than 4.57.5.
|
| 115 |
+
lr_scheduler_kwargs: dict | str | None = field(
|
| 116 |
+
default=None,
|
| 117 |
+
metadata={
|
| 118 |
+
"help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard "
|
| 119 |
+
"restarts."
|
| 120 |
+
},
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
max_length: int | None = field(
|
| 124 |
+
default=1024,
|
| 125 |
+
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
|
| 126 |
+
)
|
| 127 |
+
max_completion_length: int | None = field(
|
| 128 |
+
default=None,
|
| 129 |
+
metadata={
|
| 130 |
+
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
|
| 131 |
+
"collator and your model is an encoder-decoder."
|
| 132 |
+
},
|
| 133 |
+
)
|
| 134 |
+
beta: float = field(
|
| 135 |
+
default=0.1,
|
| 136 |
+
metadata={
|
| 137 |
+
"help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
|
| 138 |
+
"the reference model."
|
| 139 |
+
},
|
| 140 |
+
)
|
| 141 |
+
label_smoothing: float = field(
|
| 142 |
+
default=0.0,
|
| 143 |
+
metadata={"help": "Label smoothing factor."},
|
| 144 |
+
)
|
| 145 |
+
loss_type: str = field(
|
| 146 |
+
default="sigmoid",
|
| 147 |
+
metadata={
|
| 148 |
+
"help": "Type of loss to use.",
|
| 149 |
+
"choices": ["sigmoid", "hinge", "ipo", "simpo", "alphapo"],
|
| 150 |
+
},
|
| 151 |
+
)
|
| 152 |
+
disable_dropout: bool = field(
|
| 153 |
+
default=True,
|
| 154 |
+
metadata={"help": "Whether to disable dropout in the model."},
|
| 155 |
+
)
|
| 156 |
+
cpo_alpha: float = field(
|
| 157 |
+
default=1.0,
|
| 158 |
+
metadata={"help": "Weight of the BC regularizer in CPO training."},
|
| 159 |
+
)
|
| 160 |
+
simpo_gamma: float = field(
|
| 161 |
+
default=0.5,
|
| 162 |
+
metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."},
|
| 163 |
+
)
|
| 164 |
+
alpha: float = field(
|
| 165 |
+
default=0.0,
|
| 166 |
+
metadata={
|
| 167 |
+
"help": "Alpha parameter that controls reward function shape across all loss types. When alpha=0 "
|
| 168 |
+
"(default), uses standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: "
|
| 169 |
+
"`r = (1 - p^(-alpha)) / alpha` from the AlphaPO paper. This parameter works with all loss types."
|
| 170 |
+
},
|
| 171 |
+
)
|
| 172 |
+
truncation_mode: str = field(
|
| 173 |
+
default="keep_end",
|
| 174 |
+
metadata={
|
| 175 |
+
"help": "Truncation mode to use when the prompt is too long.",
|
| 176 |
+
"choices": ["keep_end", "keep_start"],
|
| 177 |
+
},
|
| 178 |
+
)
|
| 179 |
+
generate_during_eval: bool = field(
|
| 180 |
+
default=False,
|
| 181 |
+
metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
|
| 182 |
+
)
|
| 183 |
+
is_encoder_decoder: bool | None = field(
|
| 184 |
+
default=None,
|
| 185 |
+
metadata={"help": "Whether the model is an encoder-decoder model."},
|
| 186 |
+
)
|
| 187 |
+
model_init_kwargs: dict[str, Any] | None = field(
|
| 188 |
+
default=None,
|
| 189 |
+
metadata={
|
| 190 |
+
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
|
| 191 |
+
"from a string."
|
| 192 |
+
},
|
| 193 |
+
)
|
| 194 |
+
dataset_num_proc: int | None = field(
|
| 195 |
+
default=None,
|
| 196 |
+
metadata={"help": "Number of processes to use for processing the dataset."},
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def __post_init__(self):
|
| 200 |
+
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
|
| 201 |
+
|
| 202 |
+
# Syntactic sugar for AlphaPO: set loss_type to "simpo" and cpo_alpha to 0.0
|
| 203 |
+
if self.loss_type == "alphapo":
|
| 204 |
+
self.loss_type = "simpo"
|
| 205 |
+
self.cpo_alpha = 0.0
|
| 206 |
+
|
| 207 |
+
super().__post_init__()
|
ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py
ADDED
|
@@ -0,0 +1,1057 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
import random
|
| 17 |
+
import textwrap
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
+
from contextlib import nullcontext
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, Literal
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import pandas as pd
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
import transformers
|
| 30 |
+
from accelerate import PartialState, logging
|
| 31 |
+
from datasets import Dataset
|
| 32 |
+
from packaging.version import Version
|
| 33 |
+
from torch import autocast
|
| 34 |
+
from torch.utils.data import DataLoader
|
| 35 |
+
from transformers import (
|
| 36 |
+
AutoModelForCausalLM,
|
| 37 |
+
BaseImageProcessor,
|
| 38 |
+
DataCollator,
|
| 39 |
+
FeatureExtractionMixin,
|
| 40 |
+
PreTrainedModel,
|
| 41 |
+
PreTrainedTokenizerBase,
|
| 42 |
+
ProcessorMixin,
|
| 43 |
+
TrainerCallback,
|
| 44 |
+
is_comet_available,
|
| 45 |
+
is_wandb_available,
|
| 46 |
+
)
|
| 47 |
+
from transformers.trainer_utils import EvalLoopOutput
|
| 48 |
+
from transformers.utils import is_peft_available, is_torch_fx_proxy
|
| 49 |
+
|
| 50 |
+
from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt
|
| 51 |
+
from ...models.utils import peft_module_casting_to_bf16
|
| 52 |
+
from ...trainer.base_trainer import BaseTrainer
|
| 53 |
+
from ...trainer.utils import (
|
| 54 |
+
disable_dropout_in_model,
|
| 55 |
+
log_table_to_comet_experiment,
|
| 56 |
+
pad_to_length,
|
| 57 |
+
selective_log_softmax,
|
| 58 |
+
)
|
| 59 |
+
from ..utils import DPODataCollatorWithPadding, add_bos_token_if_needed, add_eos_token_if_needed
|
| 60 |
+
from .cpo_config import CPOConfig
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
if is_peft_available():
|
| 64 |
+
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if is_wandb_available():
|
| 68 |
+
import wandb
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
logger = logging.get_logger(__name__)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class CPOTrainer(BaseTrainer):
|
| 75 |
+
r"""
|
| 76 |
+
Initialize CPOTrainer.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
model ([`~transformers.PreTrainedModel`]):
|
| 80 |
+
The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
|
| 81 |
+
args ([`experimental.cpo.CPOConfig`]):
|
| 82 |
+
The CPO config arguments to use for training.
|
| 83 |
+
data_collator ([`~transformers.DataCollator`]):
|
| 84 |
+
The data collator to use for training. If None is specified, the default data collator
|
| 85 |
+
([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the
|
| 86 |
+
maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 87 |
+
train_dataset ([`~datasets.Dataset`]):
|
| 88 |
+
The dataset to use for training.
|
| 89 |
+
eval_dataset ([`~datasets.Dataset`]):
|
| 90 |
+
The dataset to use for evaluation.
|
| 91 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
|
| 92 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 93 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 94 |
+
reuse the fine-tuned model.
|
| 95 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 96 |
+
The model initializer to use for training. If None is specified, the default model initializer will be
|
| 97 |
+
used.
|
| 98 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 99 |
+
The callbacks to use for training.
|
| 100 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 101 |
+
The optimizer and scheduler to use for training.
|
| 102 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 103 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 104 |
+
peft_config (`dict`, defaults to `None`):
|
| 105 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
|
| 106 |
+
a PEFT model.
|
| 107 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 108 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
|
| 109 |
+
metric values.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
_tag_names = ["trl", "cpo"]
|
| 113 |
+
_name = "CPO"
|
| 114 |
+
_paper = {
|
| 115 |
+
"title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
|
| 116 |
+
"id": "2401.08417",
|
| 117 |
+
# docstyle-ignore
|
| 118 |
+
"citation": textwrap.dedent("""\
|
| 119 |
+
@inproceedings{xu2024contrastive,
|
| 120 |
+
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
|
| 121 |
+
author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
|
| 122 |
+
year = 2024,
|
| 123 |
+
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
| 124 |
+
publisher = {OpenReview.net},
|
| 125 |
+
url = {https://openreview.net/forum?id=51iwkioZpn}
|
| 126 |
+
}"""),
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
model: PreTrainedModel | nn.Module | str | None = None,
|
| 132 |
+
args: CPOConfig | None = None,
|
| 133 |
+
data_collator: DataCollator | None = None,
|
| 134 |
+
train_dataset: Dataset | None = None,
|
| 135 |
+
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
| 136 |
+
processing_class: PreTrainedTokenizerBase
|
| 137 |
+
| BaseImageProcessor
|
| 138 |
+
| FeatureExtractionMixin
|
| 139 |
+
| ProcessorMixin
|
| 140 |
+
| None = None,
|
| 141 |
+
model_init: Callable[[], PreTrainedModel] | None = None,
|
| 142 |
+
callbacks: list[TrainerCallback] | None = None,
|
| 143 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 144 |
+
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
| 145 |
+
peft_config: dict | None = None,
|
| 146 |
+
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
|
| 147 |
+
):
|
| 148 |
+
if args.model_init_kwargs is None:
|
| 149 |
+
model_init_kwargs = {}
|
| 150 |
+
elif not isinstance(model, str):
|
| 151 |
+
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
|
| 152 |
+
else:
|
| 153 |
+
model_init_kwargs = args.model_init_kwargs
|
| 154 |
+
dtype = model_init_kwargs.get("dtype", "auto")
|
| 155 |
+
if dtype is not None:
|
| 156 |
+
# Convert to `torch.dtype` if an str is passed
|
| 157 |
+
if isinstance(dtype, str) and dtype != "auto":
|
| 158 |
+
dtype = getattr(torch, dtype)
|
| 159 |
+
if dtype != "auto" and not isinstance(dtype, torch.dtype):
|
| 160 |
+
raise ValueError(
|
| 161 |
+
f"Invalid `dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}."
|
| 162 |
+
)
|
| 163 |
+
model_init_kwargs["dtype"] = dtype
|
| 164 |
+
model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
|
| 165 |
+
|
| 166 |
+
if isinstance(model, str):
|
| 167 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 168 |
+
|
| 169 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 170 |
+
# has been called in order to properly call autocast if needed.
|
| 171 |
+
self._peft_has_been_casted_to_bf16 = False
|
| 172 |
+
|
| 173 |
+
if not is_peft_available() and peft_config is not None:
|
| 174 |
+
raise ValueError(
|
| 175 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
| 176 |
+
)
|
| 177 |
+
elif is_peft_available() and peft_config is not None:
|
| 178 |
+
if isinstance(model, PeftModel):
|
| 179 |
+
raise ValueError(
|
| 180 |
+
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
|
| 181 |
+
"merge and unload the existing adapter, save the resulting base model, and then pass that base "
|
| 182 |
+
"model along with the new `peft_config` to the trainer."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 186 |
+
_support_gc_kwargs = hasattr(
|
| 187 |
+
args, "gradient_checkpointing_kwargs"
|
| 188 |
+
) and "gradient_checkpointing_kwargs" in list(
|
| 189 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 193 |
+
|
| 194 |
+
if _support_gc_kwargs:
|
| 195 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 196 |
+
|
| 197 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 198 |
+
elif args.gradient_checkpointing:
|
| 199 |
+
# For backward compatibility with older versions of transformers
|
| 200 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 201 |
+
model.enable_input_require_grads()
|
| 202 |
+
else:
|
| 203 |
+
|
| 204 |
+
def make_inputs_require_grad(module, input, output):
|
| 205 |
+
output.requires_grad_(True)
|
| 206 |
+
|
| 207 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 208 |
+
|
| 209 |
+
# get peft model with the given config
|
| 210 |
+
model = get_peft_model(model, peft_config)
|
| 211 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 212 |
+
peft_module_casting_to_bf16(model)
|
| 213 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 214 |
+
self._peft_has_been_casted_to_bf16 = True
|
| 215 |
+
|
| 216 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 217 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 218 |
+
# fail or completely fail.
|
| 219 |
+
elif args.gradient_checkpointing:
|
| 220 |
+
# For backward compatibility with older versions of transformers
|
| 221 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 222 |
+
model.enable_input_require_grads()
|
| 223 |
+
else:
|
| 224 |
+
|
| 225 |
+
def make_inputs_require_grad(module, input, output):
|
| 226 |
+
output.requires_grad_(True)
|
| 227 |
+
|
| 228 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 229 |
+
|
| 230 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 231 |
+
raise ValueError(
|
| 232 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 233 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if model is not None:
|
| 237 |
+
self.is_encoder_decoder = model.config.is_encoder_decoder
|
| 238 |
+
elif args.is_encoder_decoder is None:
|
| 239 |
+
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
| 240 |
+
else:
|
| 241 |
+
self.is_encoder_decoder = args.is_encoder_decoder
|
| 242 |
+
|
| 243 |
+
if self.is_encoder_decoder:
|
| 244 |
+
self.decoder_start_token_id = model.config.decoder_start_token_id
|
| 245 |
+
self.pad_token_id = model.config.pad_token_id
|
| 246 |
+
|
| 247 |
+
if processing_class is None:
|
| 248 |
+
raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
|
| 249 |
+
if args.max_length is None:
|
| 250 |
+
logger.warning(
|
| 251 |
+
"`max_length` is not set in the CPOConfig's init"
|
| 252 |
+
" it will default to `512` by default, but you should do it yourself in the future.",
|
| 253 |
+
)
|
| 254 |
+
max_length = 512
|
| 255 |
+
else:
|
| 256 |
+
max_length = args.max_length
|
| 257 |
+
|
| 258 |
+
if args.max_completion_length is None and self.is_encoder_decoder:
|
| 259 |
+
logger.warning(
|
| 260 |
+
"When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init"
|
| 261 |
+
" it will default to `128` by default, but you should do it yourself in the future.",
|
| 262 |
+
)
|
| 263 |
+
max_completion_length = 128
|
| 264 |
+
else:
|
| 265 |
+
max_completion_length = args.max_completion_length
|
| 266 |
+
|
| 267 |
+
if data_collator is None:
|
| 268 |
+
data_collator = DPODataCollatorWithPadding(
|
| 269 |
+
pad_token_id=processing_class.pad_token_id,
|
| 270 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
if args.remove_unused_columns:
|
| 274 |
+
args.remove_unused_columns = False
|
| 275 |
+
# warn users
|
| 276 |
+
logger.warning(
|
| 277 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
| 278 |
+
" we have set it for you, but you should do it yourself in the future.",
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
self.use_dpo_data_collator = True
|
| 282 |
+
else:
|
| 283 |
+
self.use_dpo_data_collator = False
|
| 284 |
+
|
| 285 |
+
# Disable dropout in the model
|
| 286 |
+
if args.disable_dropout:
|
| 287 |
+
disable_dropout_in_model(model)
|
| 288 |
+
|
| 289 |
+
self.max_length = max_length
|
| 290 |
+
self.generate_during_eval = args.generate_during_eval
|
| 291 |
+
self.truncation_mode = args.truncation_mode
|
| 292 |
+
self.max_completion_length = max_completion_length
|
| 293 |
+
self.processing_class = processing_class
|
| 294 |
+
|
| 295 |
+
if processing_class.pad_token is None:
|
| 296 |
+
processing_class.pad_token = processing_class.eos_token
|
| 297 |
+
self.pad_token_id = processing_class.pad_token_id
|
| 298 |
+
|
| 299 |
+
if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
|
| 300 |
+
logger.warning(
|
| 301 |
+
f"You are using the {args.loss_type} loss type that does not support label smoothing. The "
|
| 302 |
+
"`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
|
| 303 |
+
)
|
| 304 |
+
if args.loss_type == "kto_pair":
|
| 305 |
+
raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
|
| 306 |
+
|
| 307 |
+
self.beta = args.beta
|
| 308 |
+
self.label_smoothing = args.label_smoothing
|
| 309 |
+
self.loss_type = args.loss_type
|
| 310 |
+
self.cpo_alpha = args.cpo_alpha
|
| 311 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 312 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 313 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 314 |
+
logger.warning(
|
| 315 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 316 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 317 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 318 |
+
"loss.",
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
if args.loss_type == "simpo":
|
| 322 |
+
self.simpo_gamma = args.simpo_gamma
|
| 323 |
+
|
| 324 |
+
# AlphaPO parameter for reward shaping
|
| 325 |
+
self.alpha = args.alpha
|
| 326 |
+
|
| 327 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 328 |
+
|
| 329 |
+
# Compute that only on the main process for faster data processing.
|
| 330 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
| 331 |
+
with PartialState().main_process_first():
|
| 332 |
+
# Extract the prompt if needed, and apply the chat template if needed
|
| 333 |
+
train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 334 |
+
train_dataset = train_dataset.map(
|
| 335 |
+
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
|
| 336 |
+
)
|
| 337 |
+
if eval_dataset is not None:
|
| 338 |
+
eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
|
| 339 |
+
eval_dataset = eval_dataset.map(
|
| 340 |
+
maybe_apply_chat_template,
|
| 341 |
+
fn_kwargs={"tokenizer": processing_class},
|
| 342 |
+
num_proc=args.dataset_num_proc,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
# tokenize the dataset
|
| 346 |
+
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 347 |
+
if eval_dataset is not None:
|
| 348 |
+
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
| 349 |
+
|
| 350 |
+
# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
|
| 351 |
+
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
|
| 352 |
+
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
|
| 353 |
+
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
|
| 354 |
+
if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
|
| 355 |
+
args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 356 |
+
args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
|
| 357 |
+
|
| 358 |
+
super().__init__(
|
| 359 |
+
model=model,
|
| 360 |
+
args=args,
|
| 361 |
+
data_collator=data_collator,
|
| 362 |
+
train_dataset=train_dataset,
|
| 363 |
+
eval_dataset=eval_dataset,
|
| 364 |
+
processing_class=processing_class,
|
| 365 |
+
model_init=model_init,
|
| 366 |
+
compute_metrics=compute_metrics,
|
| 367 |
+
callbacks=callbacks,
|
| 368 |
+
optimizers=optimizers,
|
| 369 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 373 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 374 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 375 |
+
self.model_accepts_loss_kwargs = False
|
| 376 |
+
|
| 377 |
+
# Add tags for models that have been loaded with the correct transformers version
|
| 378 |
+
if hasattr(self.model, "add_model_tags"):
|
| 379 |
+
self.model.add_model_tags(self._tag_names)
|
| 380 |
+
|
| 381 |
+
if not hasattr(self, "accelerator"):
|
| 382 |
+
raise AttributeError(
|
| 383 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
def build_tokenized_answer(self, prompt, answer):
|
| 387 |
+
"""
|
| 388 |
+
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a +
|
| 389 |
+
b)[len(enc(a)):]`. Reference:
|
| 390 |
+
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
|
| 394 |
+
prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
|
| 395 |
+
|
| 396 |
+
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
| 397 |
+
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
| 398 |
+
|
| 399 |
+
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
| 400 |
+
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
| 401 |
+
|
| 402 |
+
# Prepare input tokens for token by token comparison
|
| 403 |
+
full_input_ids = np.array(full_tokenized["input_ids"])
|
| 404 |
+
|
| 405 |
+
if len(full_input_ids) != len(full_concat_input_ids):
|
| 406 |
+
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
| 407 |
+
|
| 408 |
+
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
| 409 |
+
# can be merged together when tokenizing prompt+answer. This could result
|
| 410 |
+
# on the last token from the prompt being different when tokenized on its own
|
| 411 |
+
# vs when done as prompt+answer.
|
| 412 |
+
response_token_ids_start_idx = len(prompt_input_ids)
|
| 413 |
+
|
| 414 |
+
# If tokenized prompt is different than both prompt+answer, then it means the
|
| 415 |
+
# last token has changed due to merging.
|
| 416 |
+
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
| 417 |
+
response_token_ids_start_idx -= 1
|
| 418 |
+
|
| 419 |
+
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
| 420 |
+
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
| 421 |
+
|
| 422 |
+
if len(prompt_input_ids) != len(prompt_attention_mask):
|
| 423 |
+
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
| 424 |
+
|
| 425 |
+
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
| 426 |
+
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
| 427 |
+
|
| 428 |
+
return dict(
|
| 429 |
+
prompt_input_ids=prompt_input_ids,
|
| 430 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 431 |
+
input_ids=answer_input_ids,
|
| 432 |
+
attention_mask=answer_attention_mask,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict:
|
| 436 |
+
"""Tokenize a single row from a CPO specific dataset.
|
| 437 |
+
|
| 438 |
+
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
|
| 439 |
+
chosen or prompt + rejected responses is/are too long. First we truncate the prompt; if we're still too long,
|
| 440 |
+
we truncate the chosen/rejected.
|
| 441 |
+
|
| 442 |
+
We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
|
| 443 |
+
of the prompt and the chosen/rejected response, with `-100` for the prompt tokens.
|
| 444 |
+
"""
|
| 445 |
+
batch = {}
|
| 446 |
+
prompt = feature["prompt"]
|
| 447 |
+
chosen = feature["chosen"]
|
| 448 |
+
rejected = feature["rejected"]
|
| 449 |
+
|
| 450 |
+
if not self.is_encoder_decoder:
|
| 451 |
+
# Check issues below for more details
|
| 452 |
+
# 1. https://github.com/huggingface/trl/issues/907
|
| 453 |
+
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 454 |
+
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
| 455 |
+
|
| 456 |
+
if not isinstance(prompt, str):
|
| 457 |
+
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
| 458 |
+
prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
|
| 459 |
+
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
| 460 |
+
|
| 461 |
+
if not isinstance(chosen, str):
|
| 462 |
+
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
| 463 |
+
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
| 464 |
+
|
| 465 |
+
if not isinstance(rejected, str):
|
| 466 |
+
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
| 467 |
+
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
| 468 |
+
|
| 469 |
+
# Last prompt token might get merged by tokenizer and
|
| 470 |
+
# it should not be included for generation if that happens
|
| 471 |
+
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
| 472 |
+
|
| 473 |
+
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
| 474 |
+
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
| 475 |
+
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
| 476 |
+
|
| 477 |
+
for k, v in prompt_tokens.items():
|
| 478 |
+
prompt_tokens[k] = v[:prompt_len_input_ids]
|
| 479 |
+
|
| 480 |
+
# Make sure prompts only have one different token at most an
|
| 481 |
+
# and length only differs by 1 at most
|
| 482 |
+
num_diff_tokens = sum(
|
| 483 |
+
a != b
|
| 484 |
+
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
|
| 485 |
+
)
|
| 486 |
+
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
| 487 |
+
if num_diff_tokens > 1 or num_diff_len > 1:
|
| 488 |
+
raise ValueError(
|
| 489 |
+
"Chosen and rejected prompt_input_ids might only differ on the "
|
| 490 |
+
"last token due to tokenizer merge ops."
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# add BOS token to head of prompt. Avoid adding if it's already there
|
| 494 |
+
prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
|
| 495 |
+
self.processing_class.bos_token_id,
|
| 496 |
+
prompt_len_input_ids,
|
| 497 |
+
prompt_tokens,
|
| 498 |
+
chosen_prompt_len_input_ids,
|
| 499 |
+
chosen_tokens,
|
| 500 |
+
rejected_prompt_len_input_ids,
|
| 501 |
+
rejected_tokens,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
# add EOS token to end of answer. Avoid adding if it's already there
|
| 505 |
+
chosen_tokens, rejected_tokens = add_eos_token_if_needed(
|
| 506 |
+
self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
| 510 |
+
|
| 511 |
+
# if combined sequence is too long, truncate the response
|
| 512 |
+
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
| 513 |
+
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
| 514 |
+
for k in ["input_ids", "attention_mask"]:
|
| 515 |
+
answer_tokens[k] = answer_tokens[k][: self.max_length - longer_response_length]
|
| 516 |
+
|
| 517 |
+
# Create labels
|
| 518 |
+
chosen_sequence_tokens = {
|
| 519 |
+
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 520 |
+
}
|
| 521 |
+
rejected_sequence_tokens = {
|
| 522 |
+
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
| 523 |
+
}
|
| 524 |
+
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
| 525 |
+
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [-100] * len(
|
| 526 |
+
chosen_tokens["prompt_input_ids"]
|
| 527 |
+
)
|
| 528 |
+
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
| 529 |
+
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [-100] * len(
|
| 530 |
+
rejected_tokens["prompt_input_ids"]
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
for k, toks in {
|
| 534 |
+
"chosen_": chosen_sequence_tokens,
|
| 535 |
+
"rejected_": rejected_sequence_tokens,
|
| 536 |
+
"": prompt_tokens,
|
| 537 |
+
}.items():
|
| 538 |
+
for type_key, tokens in toks.items():
|
| 539 |
+
if type_key == "token_type_ids":
|
| 540 |
+
continue
|
| 541 |
+
batch[f"{k}{type_key}"] = tokens
|
| 542 |
+
|
| 543 |
+
else:
|
| 544 |
+
chosen_tokens = self.processing_class(
|
| 545 |
+
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 546 |
+
)
|
| 547 |
+
rejected_tokens = self.processing_class(
|
| 548 |
+
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
| 549 |
+
)
|
| 550 |
+
prompt_tokens = self.processing_class(prompt, add_special_tokens=True)
|
| 551 |
+
|
| 552 |
+
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
| 553 |
+
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
| 554 |
+
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
| 555 |
+
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
| 556 |
+
|
| 557 |
+
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
| 558 |
+
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 559 |
+
labels=torch.tensor(batch["rejected_labels"])
|
| 560 |
+
)
|
| 561 |
+
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
| 562 |
+
labels=torch.tensor(batch["chosen_labels"])
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
return batch
|
| 566 |
+
|
| 567 |
+
@staticmethod
|
| 568 |
+
def concatenated_inputs(
|
| 569 |
+
batch: dict[str, list | torch.LongTensor],
|
| 570 |
+
is_encoder_decoder: bool = False,
|
| 571 |
+
padding_value: int = 0,
|
| 572 |
+
device: torch.device | None = None,
|
| 573 |
+
) -> dict[str, torch.LongTensor]:
|
| 574 |
+
"""Concatenate the chosen and rejected inputs into a single tensor.
|
| 575 |
+
|
| 576 |
+
Args:
|
| 577 |
+
batch:
|
| 578 |
+
A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
|
| 579 |
+
of shape (batch_size, sequence_length).
|
| 580 |
+
is_encoder_decoder:
|
| 581 |
+
Whether the model is an encoder-decoder model.
|
| 582 |
+
padding_value:
|
| 583 |
+
The padding value to use for the concatenated inputs_ids.
|
| 584 |
+
device:
|
| 585 |
+
The device for the concatenated inputs.
|
| 586 |
+
|
| 587 |
+
Returns:
|
| 588 |
+
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
| 589 |
+
"""
|
| 590 |
+
concatenated_batch = {}
|
| 591 |
+
|
| 592 |
+
if is_encoder_decoder:
|
| 593 |
+
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
| 594 |
+
else:
|
| 595 |
+
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
| 596 |
+
|
| 597 |
+
for k in batch:
|
| 598 |
+
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
| 599 |
+
if "labels" in k or is_encoder_decoder:
|
| 600 |
+
pad_value = -100
|
| 601 |
+
elif k.endswith("_input_ids"):
|
| 602 |
+
pad_value = padding_value
|
| 603 |
+
elif k.endswith("_attention_mask"):
|
| 604 |
+
pad_value = 0
|
| 605 |
+
concatenated_key = k.replace("chosen", "concatenated")
|
| 606 |
+
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
| 607 |
+
for k in batch:
|
| 608 |
+
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
| 609 |
+
if "labels" in k or is_encoder_decoder:
|
| 610 |
+
pad_value = -100
|
| 611 |
+
elif k.endswith("_input_ids"):
|
| 612 |
+
pad_value = padding_value
|
| 613 |
+
elif k.endswith("_attention_mask"):
|
| 614 |
+
pad_value = 0
|
| 615 |
+
concatenated_key = k.replace("rejected", "concatenated")
|
| 616 |
+
concatenated_batch[concatenated_key] = torch.cat(
|
| 617 |
+
(
|
| 618 |
+
concatenated_batch[concatenated_key],
|
| 619 |
+
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
| 620 |
+
),
|
| 621 |
+
dim=0,
|
| 622 |
+
).to(device=device)
|
| 623 |
+
|
| 624 |
+
if is_encoder_decoder:
|
| 625 |
+
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
| 626 |
+
concatenated_batch["concatenated_attention_mask"] = (
|
| 627 |
+
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
return concatenated_batch
|
| 631 |
+
|
| 632 |
+
def cpo_loss(
|
| 633 |
+
self,
|
| 634 |
+
policy_chosen_logps: torch.FloatTensor,
|
| 635 |
+
policy_rejected_logps: torch.FloatTensor,
|
| 636 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 637 |
+
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
|
| 638 |
+
|
| 639 |
+
Args:
|
| 640 |
+
policy_chosen_logps:
|
| 641 |
+
Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
| 642 |
+
policy_rejected_logps:
|
| 643 |
+
Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
| 644 |
+
|
| 645 |
+
Returns:
|
| 646 |
+
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO
|
| 647 |
+
loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
|
| 648 |
+
the chosen and rejected responses, respectively.
|
| 649 |
+
"""
|
| 650 |
+
# Apply AlphaPO reward transformation if alpha != 0
|
| 651 |
+
if self.alpha != 0.0:
|
| 652 |
+
# Compute probabilities
|
| 653 |
+
chosen_probs = torch.exp(policy_chosen_logps)
|
| 654 |
+
rejected_probs = torch.exp(policy_rejected_logps)
|
| 655 |
+
|
| 656 |
+
# Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
|
| 657 |
+
policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha
|
| 658 |
+
policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha
|
| 659 |
+
|
| 660 |
+
logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device)
|
| 661 |
+
else:
|
| 662 |
+
# Standard log probability rewards when alpha = 0
|
| 663 |
+
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
|
| 664 |
+
|
| 665 |
+
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
|
| 666 |
+
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
|
| 667 |
+
# calculates a conservative CPO loss.
|
| 668 |
+
|
| 669 |
+
if self.loss_type == "simpo":
|
| 670 |
+
gamma_logratios = self.simpo_gamma / self.beta
|
| 671 |
+
logits = logits - gamma_logratios
|
| 672 |
+
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
| 673 |
+
losses = (
|
| 674 |
+
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
| 675 |
+
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
| 676 |
+
)
|
| 677 |
+
elif self.loss_type == "sigmoid":
|
| 678 |
+
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
| 679 |
+
losses = (
|
| 680 |
+
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
| 681 |
+
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
| 682 |
+
)
|
| 683 |
+
elif self.loss_type == "hinge":
|
| 684 |
+
losses = torch.relu(1 - self.beta * logits)
|
| 685 |
+
elif self.loss_type == "ipo":
|
| 686 |
+
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
|
| 687 |
+
losses = (logits - 1 / (2 * self.beta)) ** 2
|
| 688 |
+
else:
|
| 689 |
+
raise ValueError(
|
| 690 |
+
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
# Calculate rewards for logging
|
| 694 |
+
if self.alpha != 0.0:
|
| 695 |
+
# When using AlphaPO transformation, use the transformed rewards
|
| 696 |
+
chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach()
|
| 697 |
+
rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach()
|
| 698 |
+
else:
|
| 699 |
+
# Standard log probability rewards
|
| 700 |
+
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
| 701 |
+
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
| 702 |
+
|
| 703 |
+
return losses, chosen_rewards, rejected_rewards
|
| 704 |
+
|
| 705 |
+
@staticmethod
|
| 706 |
+
def get_batch_logps(
|
| 707 |
+
logits: torch.FloatTensor,
|
| 708 |
+
labels: torch.LongTensor,
|
| 709 |
+
average_log_prob: bool = False,
|
| 710 |
+
is_encoder_decoder: bool = False,
|
| 711 |
+
) -> torch.FloatTensor:
|
| 712 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 716 |
+
labels:
|
| 717 |
+
Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored.
|
| 718 |
+
Shape: (batch_size, sequence_length)
|
| 719 |
+
average_log_prob:
|
| 720 |
+
If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
|
| 721 |
+
log probabilities of the (non-masked) tokens.
|
| 722 |
+
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
| 723 |
+
|
| 724 |
+
Returns:
|
| 725 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
|
| 726 |
+
given logits.
|
| 727 |
+
"""
|
| 728 |
+
if logits.shape[:-1] != labels.shape:
|
| 729 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 730 |
+
|
| 731 |
+
if not is_encoder_decoder:
|
| 732 |
+
labels = labels[:, 1:].clone()
|
| 733 |
+
logits = logits[:, :-1, :]
|
| 734 |
+
loss_mask = labels != -100
|
| 735 |
+
|
| 736 |
+
# dummy token; we'll ignore the losses on these tokens later
|
| 737 |
+
labels[labels == -100] = 0
|
| 738 |
+
|
| 739 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
| 740 |
+
|
| 741 |
+
if average_log_prob:
|
| 742 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 743 |
+
else:
|
| 744 |
+
return (per_token_logps * loss_mask).sum(-1)
|
| 745 |
+
|
| 746 |
+
def concatenated_forward(
|
| 747 |
+
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
|
| 748 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 749 |
+
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
| 750 |
+
|
| 751 |
+
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
| 752 |
+
"""
|
| 753 |
+
concatenated_batch = self.concatenated_inputs(
|
| 754 |
+
batch,
|
| 755 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 756 |
+
padding_value=self.pad_token_id,
|
| 757 |
+
device=self.accelerator.device,
|
| 758 |
+
)
|
| 759 |
+
len_chosen = batch["chosen_labels"].shape[0]
|
| 760 |
+
|
| 761 |
+
model_kwargs = (
|
| 762 |
+
{
|
| 763 |
+
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
| 764 |
+
}
|
| 765 |
+
if self.is_encoder_decoder
|
| 766 |
+
else {}
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
if self.aux_loss_enabled:
|
| 770 |
+
model_kwargs["output_router_logits"] = True
|
| 771 |
+
|
| 772 |
+
outputs = model(
|
| 773 |
+
concatenated_batch["concatenated_input_ids"],
|
| 774 |
+
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
| 775 |
+
use_cache=False,
|
| 776 |
+
**model_kwargs,
|
| 777 |
+
)
|
| 778 |
+
all_logits = outputs.logits
|
| 779 |
+
|
| 780 |
+
def cross_entropy_loss(logits, labels):
|
| 781 |
+
if not self.is_encoder_decoder:
|
| 782 |
+
# Shift so that tokens < n predict n
|
| 783 |
+
logits = logits[..., :-1, :].contiguous()
|
| 784 |
+
labels = labels[..., 1:].contiguous()
|
| 785 |
+
# Flatten the tokens
|
| 786 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 787 |
+
logits = logits.view(-1, logits.shape[-1])
|
| 788 |
+
labels = labels.view(-1)
|
| 789 |
+
# Enable model parallelism
|
| 790 |
+
labels = labels.to(logits.device)
|
| 791 |
+
loss = loss_fct(logits, labels)
|
| 792 |
+
return loss
|
| 793 |
+
|
| 794 |
+
labels = concatenated_batch["concatenated_labels"].clone()
|
| 795 |
+
|
| 796 |
+
if self.cpo_alpha == 0:
|
| 797 |
+
nll_loss = torch.tensor(0.0).to(self.accelerator.device)
|
| 798 |
+
else:
|
| 799 |
+
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
| 800 |
+
|
| 801 |
+
all_logps = self.get_batch_logps(
|
| 802 |
+
all_logits,
|
| 803 |
+
concatenated_batch["concatenated_labels"],
|
| 804 |
+
average_log_prob=self.loss_type in ["ipo", "simpo"],
|
| 805 |
+
is_encoder_decoder=self.is_encoder_decoder,
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
chosen_logps = all_logps[:len_chosen]
|
| 809 |
+
rejected_logps = all_logps[len_chosen:]
|
| 810 |
+
|
| 811 |
+
chosen_logits = all_logits[:len_chosen]
|
| 812 |
+
rejected_logits = all_logits[len_chosen:]
|
| 813 |
+
|
| 814 |
+
if self.aux_loss_enabled:
|
| 815 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
|
| 816 |
+
|
| 817 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
| 818 |
+
|
| 819 |
+
def get_batch_loss_metrics(
|
| 820 |
+
self,
|
| 821 |
+
model,
|
| 822 |
+
batch: dict[str, list | torch.LongTensor],
|
| 823 |
+
train_eval: Literal["train", "eval"] = "train",
|
| 824 |
+
):
|
| 825 |
+
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
| 826 |
+
metrics = {}
|
| 827 |
+
|
| 828 |
+
forward_output = self.concatenated_forward(model, batch)
|
| 829 |
+
(
|
| 830 |
+
policy_chosen_logps,
|
| 831 |
+
policy_rejected_logps,
|
| 832 |
+
policy_chosen_logits,
|
| 833 |
+
policy_rejected_logits,
|
| 834 |
+
policy_nll_loss,
|
| 835 |
+
) = forward_output[:5]
|
| 836 |
+
if self.aux_loss_enabled:
|
| 837 |
+
aux_loss = forward_output[5]
|
| 838 |
+
|
| 839 |
+
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
| 840 |
+
policy_chosen_logps,
|
| 841 |
+
policy_rejected_logps,
|
| 842 |
+
)
|
| 843 |
+
|
| 844 |
+
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
| 845 |
+
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
| 846 |
+
|
| 847 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
| 848 |
+
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
| 849 |
+
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
| 850 |
+
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
| 851 |
+
metrics[f"{prefix}rewards/margins"] = (
|
| 852 |
+
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
|
| 853 |
+
)
|
| 854 |
+
metrics[f"{prefix}logps/rejected"] = (
|
| 855 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
|
| 856 |
+
)
|
| 857 |
+
metrics[f"{prefix}logps/chosen"] = (
|
| 858 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
|
| 859 |
+
)
|
| 860 |
+
metrics[f"{prefix}logits/rejected"] = (
|
| 861 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item()
|
| 862 |
+
)
|
| 863 |
+
metrics[f"{prefix}logits/chosen"] = (
|
| 864 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item()
|
| 865 |
+
)
|
| 866 |
+
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
| 867 |
+
|
| 868 |
+
if self.aux_loss_enabled:
|
| 869 |
+
loss += self.aux_loss_coef * aux_loss
|
| 870 |
+
|
| 871 |
+
return loss, metrics
|
| 872 |
+
|
| 873 |
+
def compute_loss(
|
| 874 |
+
self,
|
| 875 |
+
model: PreTrainedModel | nn.Module,
|
| 876 |
+
inputs: dict[str, torch.Tensor | Any],
|
| 877 |
+
return_outputs=False,
|
| 878 |
+
num_items_in_batch=None,
|
| 879 |
+
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
| 880 |
+
compute_loss_context_manager = (
|
| 881 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
with compute_loss_context_manager:
|
| 885 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
| 886 |
+
|
| 887 |
+
# force log the metrics
|
| 888 |
+
self.store_metrics(metrics, train_eval="train")
|
| 889 |
+
|
| 890 |
+
if return_outputs:
|
| 891 |
+
return (loss, metrics)
|
| 892 |
+
return loss
|
| 893 |
+
|
| 894 |
+
def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
|
| 895 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 896 |
+
|
| 897 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 898 |
+
# the torch amp context manager as some hidden states are silently casted to full precision.
|
| 899 |
+
generate_context_manager = (
|
| 900 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
with generate_context_manager:
|
| 904 |
+
policy_output = model.generate(
|
| 905 |
+
input_ids=batch["prompt_input_ids"],
|
| 906 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 907 |
+
max_length=self.max_length,
|
| 908 |
+
do_sample=True,
|
| 909 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 913 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 914 |
+
|
| 915 |
+
return policy_output_decoded
|
| 916 |
+
|
| 917 |
+
def prediction_step(
|
| 918 |
+
self,
|
| 919 |
+
model: PreTrainedModel | nn.Module,
|
| 920 |
+
inputs: dict[str, torch.Tensor | Any],
|
| 921 |
+
prediction_loss_only: bool,
|
| 922 |
+
ignore_keys: list[str] | None = None,
|
| 923 |
+
):
|
| 924 |
+
if ignore_keys is None:
|
| 925 |
+
if hasattr(model, "config"):
|
| 926 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 927 |
+
else:
|
| 928 |
+
ignore_keys = []
|
| 929 |
+
|
| 930 |
+
prediction_context_manager = (
|
| 931 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
with torch.no_grad(), prediction_context_manager:
|
| 935 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
| 936 |
+
|
| 937 |
+
# force log the metrics
|
| 938 |
+
self.store_metrics(metrics, train_eval="eval")
|
| 939 |
+
|
| 940 |
+
if prediction_loss_only:
|
| 941 |
+
return (loss.detach(), None, None)
|
| 942 |
+
|
| 943 |
+
# logits for the chosen and rejected samples from model
|
| 944 |
+
logits_dict = {
|
| 945 |
+
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
| 946 |
+
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
| 947 |
+
}
|
| 948 |
+
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
|
| 949 |
+
logits = torch.tensor(logits, device=self.accelerator.device)
|
| 950 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 951 |
+
|
| 952 |
+
return (loss.detach(), logits, labels)
|
| 953 |
+
|
| 954 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 955 |
+
for key, value in metrics.items():
|
| 956 |
+
self._stored_metrics[train_eval][key].append(value)
|
| 957 |
+
|
| 958 |
+
def evaluation_loop(
|
| 959 |
+
self,
|
| 960 |
+
dataloader: DataLoader,
|
| 961 |
+
description: str,
|
| 962 |
+
prediction_loss_only: bool | None = None,
|
| 963 |
+
ignore_keys: list[str] | None = None,
|
| 964 |
+
metric_key_prefix: str = "eval",
|
| 965 |
+
) -> EvalLoopOutput:
|
| 966 |
+
"""
|
| 967 |
+
Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
|
| 968 |
+
`Trainer.evaluate()` and `Trainer.predict()`.
|
| 969 |
+
|
| 970 |
+
Works both with or without labels.
|
| 971 |
+
"""
|
| 972 |
+
|
| 973 |
+
# Sample and save to game log if requested (for one batch to save time)
|
| 974 |
+
if self.generate_during_eval:
|
| 975 |
+
# Generate random indices within the range of the total number of samples
|
| 976 |
+
num_samples = len(dataloader.dataset)
|
| 977 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 978 |
+
|
| 979 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 980 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 981 |
+
random_batch = self.data_collator(random_batch_dataset)
|
| 982 |
+
random_batch = self._prepare_inputs(random_batch)
|
| 983 |
+
|
| 984 |
+
policy_output_decoded = self.generate_from_model(self.model, random_batch)
|
| 985 |
+
|
| 986 |
+
table = pd.DataFrame(
|
| 987 |
+
columns=["Prompt", "Policy"],
|
| 988 |
+
data=[
|
| 989 |
+
[prompt, pol[len(prompt) :]]
|
| 990 |
+
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True)
|
| 991 |
+
],
|
| 992 |
+
)
|
| 993 |
+
if "wandb" in self.args.report_to:
|
| 994 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
| 995 |
+
|
| 996 |
+
if "comet_ml" in self.args.report_to:
|
| 997 |
+
log_table_to_comet_experiment(
|
| 998 |
+
name="game_log.csv",
|
| 999 |
+
table=table,
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
# Base evaluation
|
| 1003 |
+
initial_output = super().evaluation_loop(
|
| 1004 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
return initial_output
|
| 1008 |
+
|
| 1009 |
+
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
| 1010 |
+
"""
|
| 1011 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
| 1012 |
+
|
| 1013 |
+
Args:
|
| 1014 |
+
logs (`dict[str, float]`):
|
| 1015 |
+
The values to log.
|
| 1016 |
+
start_time (`float`, *optional*):
|
| 1017 |
+
Start time of the training.
|
| 1018 |
+
"""
|
| 1019 |
+
# logs either has 'loss' or 'eval_loss'
|
| 1020 |
+
train_eval = "train" if "loss" in logs else "eval"
|
| 1021 |
+
# Add averaged stored metrics to logs
|
| 1022 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1023 |
+
logs[key] = torch.tensor(metrics).mean().item()
|
| 1024 |
+
del self._stored_metrics[train_eval]
|
| 1025 |
+
return super().log(logs, start_time)
|
| 1026 |
+
|
| 1027 |
+
def _shift_right(self, input_ids):
|
| 1028 |
+
if self.decoder_start_token_id is None:
|
| 1029 |
+
raise ValueError(
|
| 1030 |
+
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
| 1031 |
+
)
|
| 1032 |
+
|
| 1033 |
+
# shift inputs to the right
|
| 1034 |
+
if is_torch_fx_proxy(input_ids):
|
| 1035 |
+
# Item assignment is not supported natively for proxies.
|
| 1036 |
+
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
| 1037 |
+
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
| 1038 |
+
else:
|
| 1039 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
| 1040 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
| 1041 |
+
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
| 1042 |
+
|
| 1043 |
+
if self.pad_token_id is None:
|
| 1044 |
+
raise ValueError("model.config.pad_token_id has to be defined.")
|
| 1045 |
+
# replace possible -100 values in labels by `pad_token_id`
|
| 1046 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
| 1047 |
+
|
| 1048 |
+
return shifted_input_ids
|
| 1049 |
+
|
| 1050 |
+
# Ensure the model card is saved along with the checkpoint
|
| 1051 |
+
def _save_checkpoint(self, model, trial):
|
| 1052 |
+
if self.args.hub_model_id is None:
|
| 1053 |
+
model_name = Path(self.args.output_dir).name
|
| 1054 |
+
else:
|
| 1055 |
+
model_name = self.args.hub_model_id.split("/")[-1]
|
| 1056 |
+
self.create_model_card(model_name=model_name)
|
| 1057 |
+
super()._save_checkpoint(model, trial)
|
ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
|
| 17 |
+
from ...trainer.grpo_config import GRPOConfig as _GRPOConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class GFPOConfig(_GRPOConfig):
|
| 22 |
+
num_remains_in_group: int | None = field(
|
| 23 |
+
default=None,
|
| 24 |
+
metadata={
|
| 25 |
+
"help": "number inputs remains after group filter function, `'num_remains_in_group'` must be >=2 if given."
|
| 26 |
+
},
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
def __post_init__(self):
|
| 30 |
+
super().__post_init__()
|
| 31 |
+
|
| 32 |
+
if self.num_remains_in_group is not None and self.num_remains_in_group >= self.num_generations:
|
| 33 |
+
raise ValueError(
|
| 34 |
+
f"Number remains in Group {self.num_remains_in_group} must be less than num_generations : {self.num_generations}."
|
| 35 |
+
)
|
ICL/RL/trl_source/trl/experimental/gkd/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .gkd_config import GKDConfig
|
| 16 |
+
from .gkd_trainer import GKDTrainer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["GKDConfig", "GKDTrainer"]
|
ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from transformers import TrainingArguments
|
| 19 |
+
|
| 20 |
+
from ...trainer.sft_config import SFTConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class GKDConfig(SFTConfig):
|
| 25 |
+
"""
|
| 26 |
+
Configuration class for [`experimental.gkd.GKDTrainer`].
|
| 27 |
+
|
| 28 |
+
This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
|
| 29 |
+
please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
| 33 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 34 |
+
lmbda (`float`, *optional*, defaults to `0.5`):
|
| 35 |
+
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
| 36 |
+
student-generated outputs).
|
| 37 |
+
beta (`float`, *optional*, defaults to `0.5`):
|
| 38 |
+
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
| 39 |
+
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
| 40 |
+
max_new_tokens (`int`, *optional*, defaults to `128`):
|
| 41 |
+
Maximum number of tokens to generate per completion.
|
| 42 |
+
teacher_model_name_or_path (`str`, *optional*):
|
| 43 |
+
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
|
| 44 |
+
trained.
|
| 45 |
+
teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
|
| 46 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
| 47 |
+
from a string.
|
| 48 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 49 |
+
Whether to disable dropout in the model.
|
| 50 |
+
seq_kd (`bool`, *optional*, defaults to `False`):
|
| 51 |
+
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
|
| 52 |
+
teacher-generated output).
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
|
| 56 |
+
|
| 57 |
+
temperature: float = field(
|
| 58 |
+
default=0.9,
|
| 59 |
+
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
| 60 |
+
)
|
| 61 |
+
lmbda: float = field(
|
| 62 |
+
default=0.5,
|
| 63 |
+
metadata={
|
| 64 |
+
"help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
|
| 65 |
+
"student-generated outputs)."
|
| 66 |
+
},
|
| 67 |
+
)
|
| 68 |
+
beta: float = field(
|
| 69 |
+
default=0.5,
|
| 70 |
+
metadata={
|
| 71 |
+
"help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
|
| 72 |
+
"loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
|
| 73 |
+
"Divergence."
|
| 74 |
+
},
|
| 75 |
+
)
|
| 76 |
+
max_new_tokens: int = field(
|
| 77 |
+
default=128,
|
| 78 |
+
metadata={"help": "Maximum number of tokens to generate per completion."},
|
| 79 |
+
)
|
| 80 |
+
teacher_model_name_or_path: str | None = field(
|
| 81 |
+
default=None,
|
| 82 |
+
metadata={
|
| 83 |
+
"help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
|
| 84 |
+
"model being trained."
|
| 85 |
+
},
|
| 86 |
+
)
|
| 87 |
+
teacher_model_init_kwargs: dict[str, Any] | None = field(
|
| 88 |
+
default=None,
|
| 89 |
+
metadata={
|
| 90 |
+
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
|
| 91 |
+
"teacher model from a string."
|
| 92 |
+
},
|
| 93 |
+
)
|
| 94 |
+
disable_dropout: bool = field(
|
| 95 |
+
default=True,
|
| 96 |
+
metadata={"help": "Whether to disable dropouts in `model`."},
|
| 97 |
+
)
|
| 98 |
+
seq_kd: bool = field(
|
| 99 |
+
default=False,
|
| 100 |
+
metadata={
|
| 101 |
+
"help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
|
| 102 |
+
"FT on teacher-generated output)."
|
| 103 |
+
},
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def __post_init__(self):
|
| 107 |
+
super().__post_init__()
|
| 108 |
+
# check lmbda and beta are in the range [0, 1]
|
| 109 |
+
if self.lmbda < 0.0 or self.lmbda > 1.0:
|
| 110 |
+
raise ValueError("lmbda must be in the range [0.0, 1.0].")
|
| 111 |
+
if self.beta < 0.0 or self.beta > 1.0:
|
| 112 |
+
raise ValueError("beta must be in the range [0.0, 1.0].")
|
ICL/RL/trl_source/trl/experimental/gold/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .gold_config import GOLDConfig
|
| 16 |
+
from .gold_trainer import GOLDTrainer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["GOLDConfig", "GOLDTrainer"]
|
ICL/RL/trl_source/trl/experimental/gold/gold.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# /// script
|
| 16 |
+
# dependencies = [
|
| 17 |
+
# "trl @ git+https://github.com/huggingface/trl.git",
|
| 18 |
+
# "peft",
|
| 19 |
+
# "trackio",
|
| 20 |
+
# ]
|
| 21 |
+
# ///
|
| 22 |
+
|
| 23 |
+
# docstyle-ignore
|
| 24 |
+
"""
|
| 25 |
+
# Full training:
|
| 26 |
+
python trl/experimental/gold/gold.py \
|
| 27 |
+
--model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
|
| 28 |
+
--teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
|
| 29 |
+
--dataset_name trl-lib/chatbot_arena_completions \
|
| 30 |
+
--learning_rate 2e-5 \
|
| 31 |
+
--per_device_train_batch_size 4 \
|
| 32 |
+
--gradient_accumulation_steps 8 \
|
| 33 |
+
--output_dir gold-model \
|
| 34 |
+
--num_train_epochs 1 \
|
| 35 |
+
--push_to_hub
|
| 36 |
+
|
| 37 |
+
# LoRA:
|
| 38 |
+
python trl/experimental/gold/gold.py \
|
| 39 |
+
--model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
|
| 40 |
+
--teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
|
| 41 |
+
--dataset_name trl-lib/chatbot_arena_completions \
|
| 42 |
+
--learning_rate 2e-4 \
|
| 43 |
+
--per_device_train_batch_size 4 \
|
| 44 |
+
--gradient_accumulation_steps 8 \
|
| 45 |
+
--output_dir gold-model \
|
| 46 |
+
--num_train_epochs 1 \
|
| 47 |
+
--push_to_hub \
|
| 48 |
+
--use_peft \
|
| 49 |
+
--lora_r 64 \
|
| 50 |
+
--lora_alpha 16
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
import logging
|
| 54 |
+
|
| 55 |
+
from datasets import load_dataset
|
| 56 |
+
from transformers import AutoTokenizer, GenerationConfig
|
| 57 |
+
|
| 58 |
+
from trl import (
|
| 59 |
+
LogCompletionsCallback,
|
| 60 |
+
ModelConfig,
|
| 61 |
+
ScriptArguments,
|
| 62 |
+
TrlParser,
|
| 63 |
+
get_kbit_device_map,
|
| 64 |
+
get_peft_config,
|
| 65 |
+
get_quantization_config,
|
| 66 |
+
)
|
| 67 |
+
from trl.experimental.gold.gold_config import GOLDConfig
|
| 68 |
+
from trl.experimental.gold.gold_trainer import GOLDTrainer
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
logger = logging.getLogger(__name__)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
parser = TrlParser((ScriptArguments, GOLDConfig, ModelConfig))
|
| 76 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 77 |
+
|
| 78 |
+
################
|
| 79 |
+
# Model & Tokenizer
|
| 80 |
+
################
|
| 81 |
+
quantization_config = get_quantization_config(model_args)
|
| 82 |
+
model_kwargs = dict(
|
| 83 |
+
revision=training_args.student_model_revision,
|
| 84 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 85 |
+
attn_implementation=model_args.attn_implementation,
|
| 86 |
+
torch_dtype=model_args.dtype,
|
| 87 |
+
use_cache=False if training_args.gradient_checkpointing else True,
|
| 88 |
+
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
| 89 |
+
quantization_config=quantization_config,
|
| 90 |
+
)
|
| 91 |
+
training_args.model_init_kwargs = model_kwargs
|
| 92 |
+
|
| 93 |
+
if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss:
|
| 94 |
+
training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path
|
| 95 |
+
teacher_model_kwargs = dict(
|
| 96 |
+
revision=model_args.model_revision,
|
| 97 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 98 |
+
attn_implementation=model_args.attn_implementation,
|
| 99 |
+
torch_dtype=model_args.dtype,
|
| 100 |
+
use_cache=True,
|
| 101 |
+
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
| 102 |
+
quantization_config=quantization_config,
|
| 103 |
+
)
|
| 104 |
+
training_args.teacher_model_init_kwargs = teacher_model_kwargs
|
| 105 |
+
|
| 106 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 107 |
+
model_args.model_name_or_path,
|
| 108 |
+
revision=model_args.model_revision,
|
| 109 |
+
trust_remote_code=model_args.trust_remote_code,
|
| 110 |
+
padding_side="left",
|
| 111 |
+
)
|
| 112 |
+
if tokenizer.pad_token is None:
|
| 113 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 114 |
+
|
| 115 |
+
################
|
| 116 |
+
# Dataset
|
| 117 |
+
################
|
| 118 |
+
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
| 119 |
+
|
| 120 |
+
################
|
| 121 |
+
# Training
|
| 122 |
+
################
|
| 123 |
+
# Handle eval dataset - check if test split exists, fallback to validation or None
|
| 124 |
+
eval_dataset = None
|
| 125 |
+
if training_args.eval_strategy != "no":
|
| 126 |
+
if script_args.dataset_test_split in dataset:
|
| 127 |
+
eval_dataset = dataset[script_args.dataset_test_split]
|
| 128 |
+
elif "validation" in dataset:
|
| 129 |
+
eval_dataset = dataset["validation"]
|
| 130 |
+
elif "dev" in dataset:
|
| 131 |
+
eval_dataset = dataset["dev"]
|
| 132 |
+
|
| 133 |
+
trainer = GOLDTrainer(
|
| 134 |
+
model=model_args.model_name_or_path,
|
| 135 |
+
teacher_model=training_args.teacher_model_name_or_path,
|
| 136 |
+
args=training_args,
|
| 137 |
+
train_dataset=dataset[script_args.dataset_train_split],
|
| 138 |
+
eval_dataset=eval_dataset,
|
| 139 |
+
processing_class=tokenizer,
|
| 140 |
+
peft_config=get_peft_config(model_args),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if training_args.eval_strategy != "no":
|
| 144 |
+
generation_config = GenerationConfig(
|
| 145 |
+
max_new_tokens=training_args.max_completion_length, do_sample=True, temperature=training_args.temperature
|
| 146 |
+
)
|
| 147 |
+
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
|
| 148 |
+
trainer.add_callback(completions_callback)
|
| 149 |
+
|
| 150 |
+
trainer.train()
|
| 151 |
+
|
| 152 |
+
# Save and push to hub
|
| 153 |
+
trainer.save_model(training_args.output_dir)
|
| 154 |
+
if training_args.push_to_hub:
|
| 155 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
ICL/RL/trl_source/trl/experimental/gold/gold_config.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from transformers import TrainingArguments
|
| 19 |
+
|
| 20 |
+
from ...trainer.sft_config import SFTConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class GOLDConfig(SFTConfig):
|
| 25 |
+
r"""
|
| 26 |
+
Configuration class for [`GOLDTrainer`].
|
| 27 |
+
|
| 28 |
+
This class includes only the parameters that are specific to GOLD training. For a full list of training arguments,
|
| 29 |
+
please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
temperature (`float`, *optional*, defaults to `0.9`):
|
| 33 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 34 |
+
lmbda (`float`, *optional*, defaults to `0.5`):
|
| 35 |
+
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
| 36 |
+
student-generated outputs).
|
| 37 |
+
beta (`float`, *optional*, defaults to `0.5`):
|
| 38 |
+
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
| 39 |
+
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
| 40 |
+
max_completion_length (`int`, *optional*, defaults to `128`):
|
| 41 |
+
Maximum number of tokens to generate per completion.
|
| 42 |
+
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
| 43 |
+
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
|
| 44 |
+
trained.
|
| 45 |
+
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
| 46 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
| 47 |
+
from a string.
|
| 48 |
+
teacher_tokenizer_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
| 49 |
+
Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same tokenizer as
|
| 50 |
+
the student model (not recommended for cross-tokenizer distillation).
|
| 51 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 52 |
+
Whether to disable dropout in the model.
|
| 53 |
+
seq_kd (`bool`, *optional*, defaults to `False`):
|
| 54 |
+
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
|
| 55 |
+
teacher-generated output).
|
| 56 |
+
use_uld_loss (`bool`, *optional*, defaults to `False`):
|
| 57 |
+
Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence
|
| 58 |
+
loss.
|
| 59 |
+
uld_crossentropy_weight (`float`, *optional*, defaults to `0.0`):
|
| 60 |
+
Weight for the cross-entropy loss component in ULD loss. If 0, only ULD distillation loss is used.
|
| 61 |
+
uld_distillation_weight (`float`, *optional*, defaults to `1.0`):
|
| 62 |
+
Weight for the distillation loss component in ULD loss.
|
| 63 |
+
uld_student_temperature (`float`, *optional*, defaults to `1.0`):
|
| 64 |
+
Temperature for student logits in ULD loss computation.
|
| 65 |
+
uld_teacher_temperature (`float`, *optional*, defaults to `1.0`):
|
| 66 |
+
Temperature for teacher logits in ULD loss computation.
|
| 67 |
+
uld_skip_student_eos (`bool`, *optional*, defaults to `True`):
|
| 68 |
+
Whether to skip EOS token for student in ULD loss computation.
|
| 69 |
+
uld_skip_teacher_eos (`bool`, *optional*, defaults to `True`):
|
| 70 |
+
Whether to skip EOS token for teacher in ULD loss computation.
|
| 71 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
| 72 |
+
Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed.
|
| 73 |
+
vllm_mode (`str`, *optional*, defaults to `"server"`):
|
| 74 |
+
Mode for student vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or `"colocate"`
|
| 75 |
+
(run vLLM in the same process).
|
| 76 |
+
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
|
| 77 |
+
Host of the vLLM server for the student model (if `vllm_mode="server"`).
|
| 78 |
+
vllm_server_port (`int`, *optional*, defaults to `8001`):
|
| 79 |
+
Port of the vLLM server for the student model (if `vllm_mode="server"`).
|
| 80 |
+
vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
|
| 81 |
+
Timeout for connecting to the student vLLM server (if `vllm_mode="server"`).
|
| 82 |
+
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
| 83 |
+
GPU memory utilization for the colocated student vLLM engine (if `vllm_mode="colocate"`). It is recommended
|
| 84 |
+
to set this to a low value if the student and teacher models share the same GPU.
|
| 85 |
+
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
|
| 86 |
+
Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).
|
| 87 |
+
vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`):
|
| 88 |
+
Regex for vLLM structured outputs for the student model.
|
| 89 |
+
vllm_sync_frequency (`int`, *optional*, defaults to `1`):
|
| 90 |
+
Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after
|
| 91 |
+
every step.
|
| 92 |
+
vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
|
| 93 |
+
Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU memory usage
|
| 94 |
+
low, but waking the engine adds host–device transfer latency.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
|
| 98 |
+
|
| 99 |
+
# Parameters whose default values are overridden from TrainingArguments
|
| 100 |
+
learning_rate: float = field(
|
| 101 |
+
default=1e-7,
|
| 102 |
+
metadata={"help": "The initial learning rate for AdamW."},
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# GOLD-specific parameters
|
| 106 |
+
temperature: float = field(
|
| 107 |
+
default=0.9,
|
| 108 |
+
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
| 109 |
+
)
|
| 110 |
+
top_p: float = field(
|
| 111 |
+
default=0.95,
|
| 112 |
+
metadata={
|
| 113 |
+
"help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to "
|
| 114 |
+
"`top_p` or higher are kept for generation."
|
| 115 |
+
},
|
| 116 |
+
)
|
| 117 |
+
top_k: int = field(
|
| 118 |
+
default=0,
|
| 119 |
+
metadata={
|
| 120 |
+
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, "
|
| 121 |
+
"top-k-filtering is disabled and all tokens are considered."
|
| 122 |
+
},
|
| 123 |
+
)
|
| 124 |
+
lmbda: float = field(
|
| 125 |
+
default=0.5,
|
| 126 |
+
metadata={
|
| 127 |
+
"help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
|
| 128 |
+
"student-generated outputs)."
|
| 129 |
+
},
|
| 130 |
+
)
|
| 131 |
+
beta: float = field(
|
| 132 |
+
default=0.5,
|
| 133 |
+
metadata={
|
| 134 |
+
"help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
|
| 135 |
+
"loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
|
| 136 |
+
"Divergence."
|
| 137 |
+
},
|
| 138 |
+
)
|
| 139 |
+
max_completion_length: int = field(
|
| 140 |
+
default=128,
|
| 141 |
+
metadata={"help": "Maximum number of tokens to generate per completion."},
|
| 142 |
+
)
|
| 143 |
+
student_model_revision: str = field(
|
| 144 |
+
default="main",
|
| 145 |
+
metadata={
|
| 146 |
+
"help": "Revision of the student model to use. If not specified, the default revision of the model will be used."
|
| 147 |
+
},
|
| 148 |
+
)
|
| 149 |
+
teacher_model_name_or_path: str | None = field(
|
| 150 |
+
default=None,
|
| 151 |
+
metadata={
|
| 152 |
+
"help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
|
| 153 |
+
"model being trained."
|
| 154 |
+
},
|
| 155 |
+
)
|
| 156 |
+
teacher_model_init_kwargs: dict[str, Any] | None = field(
|
| 157 |
+
default=None,
|
| 158 |
+
metadata={
|
| 159 |
+
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
|
| 160 |
+
"teacher model from a string."
|
| 161 |
+
},
|
| 162 |
+
)
|
| 163 |
+
teacher_tokenizer_name_or_path: str | None = field(
|
| 164 |
+
default=None,
|
| 165 |
+
metadata={
|
| 166 |
+
"help": "Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same "
|
| 167 |
+
"tokenizer as the student model (not recommended for cross-tokenizer distillation)."
|
| 168 |
+
},
|
| 169 |
+
)
|
| 170 |
+
disable_dropout: bool = field(
|
| 171 |
+
default=True,
|
| 172 |
+
metadata={"help": "Whether to disable dropouts in `model`."},
|
| 173 |
+
)
|
| 174 |
+
seq_kd: bool = field(
|
| 175 |
+
default=False,
|
| 176 |
+
metadata={
|
| 177 |
+
"help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
|
| 178 |
+
"FT on teacher-generated output)."
|
| 179 |
+
},
|
| 180 |
+
)
|
| 181 |
+
steps_per_generation: int | None = field(
|
| 182 |
+
default=None,
|
| 183 |
+
metadata={
|
| 184 |
+
"help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps."
|
| 185 |
+
},
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# ULD Loss parameters
|
| 189 |
+
use_uld_loss: bool = field(
|
| 190 |
+
default=False,
|
| 191 |
+
metadata={
|
| 192 |
+
"help": "Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss."
|
| 193 |
+
},
|
| 194 |
+
)
|
| 195 |
+
use_extended_uld: bool = field(
|
| 196 |
+
default=True,
|
| 197 |
+
metadata={
|
| 198 |
+
"help": (
|
| 199 |
+
"Whether to enable extended ULD alignment that uses tokenizers to align and merge token "
|
| 200 |
+
"probabilities across student and teacher tokenizations. When True, the trainer will compute "
|
| 201 |
+
"token mappings and merge probabilities for split tokens; when False, ULD will use simple "
|
| 202 |
+
"positional truncation like in the original ULD paper."
|
| 203 |
+
)
|
| 204 |
+
},
|
| 205 |
+
)
|
| 206 |
+
uld_use_hybrid_loss: bool = field(
|
| 207 |
+
default=False,
|
| 208 |
+
metadata={
|
| 209 |
+
"help": (
|
| 210 |
+
"Whether to use a hybrid loss that combines ULD loss and JSD loss. When True, the final loss is a "
|
| 211 |
+
"a combination of JSD for known token mappings and ULD for unknown token mappings."
|
| 212 |
+
)
|
| 213 |
+
},
|
| 214 |
+
)
|
| 215 |
+
uld_hybrid_matched_weight: float | None = field(
|
| 216 |
+
default=None,
|
| 217 |
+
metadata={
|
| 218 |
+
"help": (
|
| 219 |
+
"Weight for the matched token loss component when using hybrid ULD + JSD loss. This weight scales "
|
| 220 |
+
"the JSD loss computed over tokens that have a direct mapping between student and teacher "
|
| 221 |
+
"tokenizations. If None, uses adaptive weighting based on vocabulary overlap. Must be set together "
|
| 222 |
+
"with uld_hybrid_unmatched_weight (both None or both float)."
|
| 223 |
+
)
|
| 224 |
+
},
|
| 225 |
+
)
|
| 226 |
+
uld_hybrid_unmatched_weight: float | None = field(
|
| 227 |
+
default=None,
|
| 228 |
+
metadata={
|
| 229 |
+
"help": (
|
| 230 |
+
"Weight for the unmatched token loss component when using hybrid ULD + JSD loss. This weight scales "
|
| 231 |
+
"the ULD loss computed over tokens that do not have a direct mapping between student and teacher "
|
| 232 |
+
"tokenizations. If None, uses adaptive weighting based on vocabulary overlap. Must be set together "
|
| 233 |
+
"with uld_hybrid_matched_weight (both None or both float)."
|
| 234 |
+
)
|
| 235 |
+
},
|
| 236 |
+
)
|
| 237 |
+
uld_crossentropy_weight: float = field(
|
| 238 |
+
default=0.0,
|
| 239 |
+
metadata={"help": "Weight for the cross-entropy loss component in ULD loss."},
|
| 240 |
+
)
|
| 241 |
+
uld_distillation_weight: float = field(
|
| 242 |
+
default=1.0,
|
| 243 |
+
metadata={"help": "Weight for the distillation loss component in ULD loss."},
|
| 244 |
+
)
|
| 245 |
+
uld_student_temperature: float = field(
|
| 246 |
+
default=1.0,
|
| 247 |
+
metadata={"help": "Temperature for student logits in ULD loss computation."},
|
| 248 |
+
)
|
| 249 |
+
uld_teacher_temperature: float = field(
|
| 250 |
+
default=1.0,
|
| 251 |
+
metadata={"help": "Temperature for teacher logits in ULD loss computation."},
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
uld_skip_student_eos: bool = field(
|
| 255 |
+
default=True,
|
| 256 |
+
metadata={"help": "Whether to skip EOS token for student in ULD loss computation."},
|
| 257 |
+
)
|
| 258 |
+
uld_skip_teacher_eos: bool = field(
|
| 259 |
+
default=True,
|
| 260 |
+
metadata={"help": "Whether to skip EOS token for teacher in ULD loss computation."},
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# transformers paged attention
|
| 264 |
+
use_transformers_paged: bool = field(
|
| 265 |
+
default=False,
|
| 266 |
+
metadata={
|
| 267 |
+
"help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the "
|
| 268 |
+
"`transformers` paged implementation will be used for generation instead of the default padded "
|
| 269 |
+
"implementation."
|
| 270 |
+
},
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# vLLM parameters
|
| 274 |
+
use_vllm: bool = field(
|
| 275 |
+
default=False,
|
| 276 |
+
metadata={"help": "Whether to use vLLM for generating completions. Requires `vllm` to be installed."},
|
| 277 |
+
)
|
| 278 |
+
vllm_mode: str = field(
|
| 279 |
+
default="server",
|
| 280 |
+
metadata={
|
| 281 |
+
"help": 'Mode for vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).'
|
| 282 |
+
},
|
| 283 |
+
)
|
| 284 |
+
vllm_server_host: str = field(
|
| 285 |
+
default="0.0.0.0",
|
| 286 |
+
metadata={"help": 'Host of the vLLM server when `vllm_mode="server"`.'},
|
| 287 |
+
)
|
| 288 |
+
vllm_server_port: int = field(
|
| 289 |
+
default=8001,
|
| 290 |
+
metadata={"help": 'Port of the vLLM server when `vllm_mode="server"`.'},
|
| 291 |
+
)
|
| 292 |
+
vllm_server_timeout: float = field(
|
| 293 |
+
default=240.0,
|
| 294 |
+
metadata={"help": 'Timeout (in seconds) for connecting to the vLLM server when `vllm_mode="server"`.'},
|
| 295 |
+
)
|
| 296 |
+
vllm_gpu_memory_utilization: float = field(
|
| 297 |
+
default=0.9,
|
| 298 |
+
metadata={
|
| 299 |
+
"help": 'GPU memory utilization for the colocated vLLM engine when `vllm_mode="colocate"`. Lower values reduce contention when sharing a device with the student/teacher models.'
|
| 300 |
+
},
|
| 301 |
+
)
|
| 302 |
+
vllm_tensor_parallel_size: int = field(
|
| 303 |
+
default=1,
|
| 304 |
+
metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'},
|
| 305 |
+
)
|
| 306 |
+
vllm_structured_outputs_regex: str | None = field(
|
| 307 |
+
default=None,
|
| 308 |
+
metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."},
|
| 309 |
+
)
|
| 310 |
+
vllm_sync_frequency: int = field(
|
| 311 |
+
default=1,
|
| 312 |
+
metadata={
|
| 313 |
+
"help": "Frequency (in training steps) to synchronize model weights to the vLLM engine. Set to 1 to sync after every step."
|
| 314 |
+
},
|
| 315 |
+
)
|
| 316 |
+
vllm_enable_sleep_mode: bool = field(
|
| 317 |
+
default=False,
|
| 318 |
+
metadata={
|
| 319 |
+
"help": "Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU "
|
| 320 |
+
"memory usage low, but waking the engine adds host–device transfer latency."
|
| 321 |
+
},
|
| 322 |
+
)
|
| 323 |
+
# Parameters that control the logging
|
| 324 |
+
log_completions: bool = field(
|
| 325 |
+
default=False,
|
| 326 |
+
metadata={
|
| 327 |
+
"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
|
| 328 |
+
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
|
| 329 |
+
},
|
| 330 |
+
)
|
| 331 |
+
log_completions_steps: int = field(
|
| 332 |
+
default=100,
|
| 333 |
+
metadata={
|
| 334 |
+
"help": "Number of steps between logging (prompt, completion) pairs. Only used if `log_completions` is "
|
| 335 |
+
"set to `True`."
|
| 336 |
+
},
|
| 337 |
+
)
|
| 338 |
+
num_completions_to_print: int | None = field(
|
| 339 |
+
default=None,
|
| 340 |
+
metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
|
| 341 |
+
)
|
| 342 |
+
wandb_entity: str | None = field(
|
| 343 |
+
default=None,
|
| 344 |
+
metadata={"help": ("The entity to store runs under.")},
|
| 345 |
+
)
|
| 346 |
+
wandb_project: str | None = field(
|
| 347 |
+
default=None,
|
| 348 |
+
metadata={"help": ("The project to store runs under.")},
|
| 349 |
+
)
|
| 350 |
+
wandb_run_group: str | None = field(
|
| 351 |
+
default=None,
|
| 352 |
+
metadata={"help": ("The group to store runs under.")},
|
| 353 |
+
)
|
| 354 |
+
wandb_log_unique_prompts: bool = field(
|
| 355 |
+
default=True,
|
| 356 |
+
metadata={
|
| 357 |
+
"help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.")
|
| 358 |
+
},
|
| 359 |
+
)
|
| 360 |
+
callbacks: list[str] = field(
|
| 361 |
+
default_factory=lambda: [],
|
| 362 |
+
metadata={"help": "The callbacks to run during training."},
|
| 363 |
+
)
|
| 364 |
+
hub_model_revision: str | None = field(
|
| 365 |
+
default="main", metadata={"help": "The Hub model branch to push the model to."}
|
| 366 |
+
)
|
| 367 |
+
num_completions_to_print: int = field(default=5, metadata={"help": "Number of completions to print."})
|
| 368 |
+
overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
|
| 369 |
+
push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
|
| 370 |
+
trl_project: str = field(
|
| 371 |
+
default="smollm3",
|
| 372 |
+
metadata={
|
| 373 |
+
"help": "The TRL project to use for evaluation. This is used to determine the path to the evaluation script."
|
| 374 |
+
},
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
def __post_init__(self):
|
| 378 |
+
super().__post_init__()
|
| 379 |
+
# check lmbda and beta are in the range [0, 1]
|
| 380 |
+
if self.lmbda < 0.0 or self.lmbda > 1.0:
|
| 381 |
+
raise ValueError("lmbda must be in the range [0.0, 1.0].")
|
| 382 |
+
if self.beta < 0.0 or self.beta > 1.0:
|
| 383 |
+
raise ValueError("beta must be in the range [0.0, 1.0].")
|
| 384 |
+
|
| 385 |
+
# Validate that max_length is sufficient for max_completion_length
|
| 386 |
+
if self.max_length is not None and self.max_completion_length >= self.max_length:
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"max_completion_length ({self.max_completion_length}) must be smaller than max_length ({self.max_length}) "
|
| 389 |
+
f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length."
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
if self.steps_per_generation is None:
|
| 393 |
+
self.steps_per_generation = self.gradient_accumulation_steps
|
| 394 |
+
|
| 395 |
+
# Validate ULD parameters
|
| 396 |
+
if self.use_uld_loss:
|
| 397 |
+
if self.uld_crossentropy_weight < 0.0:
|
| 398 |
+
raise ValueError("uld_crossentropy_weight must be non-negative.")
|
| 399 |
+
if self.uld_distillation_weight < 0.0:
|
| 400 |
+
raise ValueError("uld_distillation_weight must be non-negative.")
|
| 401 |
+
if self.uld_student_temperature <= 0.0:
|
| 402 |
+
raise ValueError("uld_student_temperature must be positive.")
|
| 403 |
+
if self.uld_teacher_temperature <= 0.0:
|
| 404 |
+
raise ValueError("uld_teacher_temperature must be positive.")
|
| 405 |
+
|
| 406 |
+
# Validate hybrid loss weights - both must be None or both must be set
|
| 407 |
+
if self.uld_use_hybrid_loss:
|
| 408 |
+
if (self.uld_hybrid_matched_weight is None) != (self.uld_hybrid_unmatched_weight is None):
|
| 409 |
+
raise ValueError(
|
| 410 |
+
"uld_hybrid_matched_weight and uld_hybrid_unmatched_weight must both be None (for adaptive "
|
| 411 |
+
"weighting) or both be set to numeric values. Got uld_hybrid_matched_weight="
|
| 412 |
+
f"{self.uld_hybrid_matched_weight} and uld_hybrid_unmatched_weight="
|
| 413 |
+
f"{self.uld_hybrid_unmatched_weight}."
|
| 414 |
+
)
|
| 415 |
+
if self.uld_hybrid_matched_weight is not None:
|
| 416 |
+
if self.uld_hybrid_matched_weight < 0.0:
|
| 417 |
+
raise ValueError("uld_hybrid_matched_weight must be non-negative.")
|
| 418 |
+
if self.uld_hybrid_unmatched_weight < 0.0:
|
| 419 |
+
raise ValueError("uld_hybrid_unmatched_weight must be non-negative.")
|
ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
|
| 16 |
+
from .grpo_with_replay_buffer_trainer import GRPOWithReplayBufferTrainer, ReplayBuffer
|
ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
|
| 17 |
+
from ...trainer.grpo_config import GRPOConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class GRPOWithReplayBufferConfig(GRPOConfig):
|
| 22 |
+
"""
|
| 23 |
+
New Parameters:
|
| 24 |
+
replay_buffer_size (`int`, *optional*, defaults to `0`):
|
| 25 |
+
A cache that stores the rollouts with the highest advantage scores and variance per group. If a new
|
| 26 |
+
group has 0 variance, it is replaced with a group sampled from the replay buffer.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
replay_buffer_size: int = field(
|
| 30 |
+
default=64,
|
| 31 |
+
metadata={
|
| 32 |
+
"help": "A cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer."
|
| 33 |
+
},
|
| 34 |
+
)
|
ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
ADDED
|
@@ -0,0 +1,731 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import heapq
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from accelerate.utils import gather_object
|
| 20 |
+
|
| 21 |
+
from ...data_utils import apply_chat_template, prepare_multimodal_messages
|
| 22 |
+
from ...models.utils import disable_gradient_checkpointing
|
| 23 |
+
from ...trainer.grpo_trainer import GRPOTrainer
|
| 24 |
+
from ...trainer.utils import nanmax, nanmin, nanstd, pad
|
| 25 |
+
from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ReplayBuffer:
|
| 29 |
+
"""
|
| 30 |
+
A simple replay buffer to store and sample previously seen rollouts.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self, max_size: int):
|
| 34 |
+
self.max_size = max_size
|
| 35 |
+
self.heap = [] # Min-heap of (score, data) tuples
|
| 36 |
+
|
| 37 |
+
def add(self, scores: list[float], data: list[dict]):
|
| 38 |
+
for score, datum in zip(scores, data, strict=True):
|
| 39 |
+
if len(self.heap) < self.max_size:
|
| 40 |
+
heapq.heappush(self.heap, (score, datum))
|
| 41 |
+
else:
|
| 42 |
+
# Only add if score is better than worst (minimum) item
|
| 43 |
+
if score > self.heap[0][0]:
|
| 44 |
+
heapq.heapreplace(self.heap, (score, datum))
|
| 45 |
+
|
| 46 |
+
def sample(self, num_samples: int) -> list[dict[str, torch.Tensor]]:
|
| 47 |
+
if not self.heap:
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
# Sample by normalized scores
|
| 51 |
+
scores = torch.tensor([item[0] for item in self.heap], dtype=torch.float32)
|
| 52 |
+
probabilities = scores / scores.sum()
|
| 53 |
+
replacement = False
|
| 54 |
+
if num_samples > len(self.heap):
|
| 55 |
+
replacement = True
|
| 56 |
+
chosen_indices = torch.multinomial(probabilities, num_samples, replacement=replacement).tolist()
|
| 57 |
+
return [self.heap[i][1] for i in chosen_indices]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class GRPOWithReplayBufferTrainer(GRPOTrainer):
|
| 61 |
+
def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **kwargs):
|
| 62 |
+
super().__init__(args=args, **kwargs)
|
| 63 |
+
self.replay_buffer = ReplayBuffer(args.replay_buffer_size) if args.replay_buffer_size > 0 else None
|
| 64 |
+
|
| 65 |
+
def _generate_and_score_completions(
|
| 66 |
+
self, inputs: list[dict[str, torch.Tensor | Any]]
|
| 67 |
+
) -> dict[str, torch.Tensor | Any]:
|
| 68 |
+
device = self.accelerator.device
|
| 69 |
+
mode = "train" if self.model.training else "eval"
|
| 70 |
+
|
| 71 |
+
prompts = [x["prompt"] for x in inputs]
|
| 72 |
+
|
| 73 |
+
if "images" in inputs[0]:
|
| 74 |
+
images = [example.get("images") for example in inputs]
|
| 75 |
+
elif "image" in inputs[0]:
|
| 76 |
+
images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
|
| 77 |
+
else:
|
| 78 |
+
images = None
|
| 79 |
+
# Transformers requires at least one image in the batch, otherwise it throws an error
|
| 80 |
+
if images is not None and all(img_list == [] for img_list in images):
|
| 81 |
+
images = None
|
| 82 |
+
|
| 83 |
+
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
|
| 84 |
+
# [{"role": "user", "content": "What color is the sky?"}] to
|
| 85 |
+
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
|
| 86 |
+
if images is not None:
|
| 87 |
+
prompts = [
|
| 88 |
+
prepare_multimodal_messages(prompt, image_list)
|
| 89 |
+
for prompt, image_list in zip(prompts, images, strict=True)
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
(
|
| 93 |
+
prompt_ids_list,
|
| 94 |
+
completion_ids_list,
|
| 95 |
+
tool_mask_list,
|
| 96 |
+
completions,
|
| 97 |
+
num_items_in_batch,
|
| 98 |
+
sampling_per_token_logps_list,
|
| 99 |
+
extra_fields,
|
| 100 |
+
) = self._generate(prompts)
|
| 101 |
+
|
| 102 |
+
# Convert lists of token IDs to padded tensors
|
| 103 |
+
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
| 104 |
+
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
|
| 105 |
+
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
|
| 106 |
+
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
|
| 107 |
+
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
|
| 108 |
+
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
|
| 109 |
+
completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
|
| 110 |
+
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
|
| 111 |
+
if sampling_per_token_logps_list is not None:
|
| 112 |
+
sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
|
| 113 |
+
sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
|
| 114 |
+
else:
|
| 115 |
+
sampling_per_token_logps = None
|
| 116 |
+
if self.tools:
|
| 117 |
+
tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list]
|
| 118 |
+
tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere
|
| 119 |
+
|
| 120 |
+
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
|
| 121 |
+
if self.mask_truncated_completions:
|
| 122 |
+
eos_and_pad = [self.eos_token_id, self.pad_token_id]
|
| 123 |
+
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
|
| 124 |
+
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
|
| 125 |
+
|
| 126 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 127 |
+
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
|
| 128 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
| 129 |
+
|
| 130 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 131 |
+
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
|
| 132 |
+
|
| 133 |
+
num_images = [len(img_list) for img_list in images] if images is not None else None
|
| 134 |
+
|
| 135 |
+
# Get forward_kwargs for models with multimodal inputs
|
| 136 |
+
if images is not None:
|
| 137 |
+
prompts_text = [
|
| 138 |
+
apply_chat_template(
|
| 139 |
+
{"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
|
| 140 |
+
)["prompt"]
|
| 141 |
+
for prompt in prompts
|
| 142 |
+
]
|
| 143 |
+
prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
|
| 144 |
+
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
| 145 |
+
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
| 146 |
+
else:
|
| 147 |
+
forward_kwargs = {}
|
| 148 |
+
|
| 149 |
+
# If token_type_ids are used, extend them with zeros for the completion part
|
| 150 |
+
if "token_type_ids" in forward_kwargs:
|
| 151 |
+
token_type_ids = forward_kwargs["token_type_ids"]
|
| 152 |
+
forward_kwargs["token_type_ids"] = torch.cat(
|
| 153 |
+
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a
|
| 157 |
+
# torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
|
| 158 |
+
# Temporarily disable checkpointing to avoid this warning during inference.
|
| 159 |
+
with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
|
| 160 |
+
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
|
| 161 |
+
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
|
| 162 |
+
# samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
|
| 163 |
+
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
|
| 164 |
+
# old_per_token_logps to None.
|
| 165 |
+
# When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
|
| 166 |
+
# distribution mismatch between vLLM and the training model can be large and harm the training.
|
| 167 |
+
generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
|
| 168 |
+
if self.args.gradient_accumulation_steps % generate_every != 0 or (
|
| 169 |
+
self.use_vllm and self.vllm_importance_sampling_correction
|
| 170 |
+
):
|
| 171 |
+
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
|
| 172 |
+
self.model,
|
| 173 |
+
prompt_completion_ids,
|
| 174 |
+
attention_mask,
|
| 175 |
+
logits_to_keep,
|
| 176 |
+
batch_size,
|
| 177 |
+
num_images=num_images,
|
| 178 |
+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
|
| 179 |
+
)
|
| 180 |
+
else:
|
| 181 |
+
old_per_token_logps = None
|
| 182 |
+
|
| 183 |
+
# Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
|
| 184 |
+
if self.use_vllm and self.vllm_importance_sampling_correction:
|
| 185 |
+
importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps)
|
| 186 |
+
importance_sampling_ratio = torch.clamp(
|
| 187 |
+
importance_sampling_ratio, max=self.vllm_importance_sampling_cap
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# Compute the per-token log probabilities for the reference model
|
| 191 |
+
if self.beta != 0.0:
|
| 192 |
+
if self.ref_model is not None:
|
| 193 |
+
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
|
| 194 |
+
self.ref_model,
|
| 195 |
+
prompt_completion_ids,
|
| 196 |
+
attention_mask,
|
| 197 |
+
logits_to_keep,
|
| 198 |
+
batch_size=batch_size,
|
| 199 |
+
num_images=num_images,
|
| 200 |
+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
|
| 201 |
+
)
|
| 202 |
+
else:
|
| 203 |
+
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
| 204 |
+
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
|
| 205 |
+
self.model,
|
| 206 |
+
prompt_completion_ids,
|
| 207 |
+
attention_mask,
|
| 208 |
+
logits_to_keep,
|
| 209 |
+
batch_size=batch_size,
|
| 210 |
+
num_images=num_images,
|
| 211 |
+
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
|
| 212 |
+
)
|
| 213 |
+
else:
|
| 214 |
+
ref_per_token_logps = None
|
| 215 |
+
|
| 216 |
+
# Decode
|
| 217 |
+
prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
|
| 218 |
+
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 219 |
+
|
| 220 |
+
# Merge extra_fields from rollout_func into inputs for reward functions
|
| 221 |
+
if extra_fields:
|
| 222 |
+
for i, inp in enumerate(inputs):
|
| 223 |
+
for key, values in extra_fields.items():
|
| 224 |
+
if isinstance(values, list) and i < len(values):
|
| 225 |
+
inp[key] = values[i]
|
| 226 |
+
elif not isinstance(values, list):
|
| 227 |
+
inp[key] = values
|
| 228 |
+
|
| 229 |
+
# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
|
| 230 |
+
# important because rewards will be normalized per group, and completions are distributed. We will later slice
|
| 231 |
+
# rewards_per_func to extract each process's subset.
|
| 232 |
+
rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
|
| 233 |
+
|
| 234 |
+
# Apply weights to each reward function's output and sum
|
| 235 |
+
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
|
| 236 |
+
|
| 237 |
+
# Compute grouped-wise rewards
|
| 238 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 239 |
+
|
| 240 |
+
# Normalize the rewards to compute the advantages
|
| 241 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 242 |
+
advantages = rewards - mean_grouped_rewards
|
| 243 |
+
|
| 244 |
+
grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 245 |
+
grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 246 |
+
|
| 247 |
+
if self.scale_rewards in ["group", "none"]:
|
| 248 |
+
# If self.scale_rewards = "none", we'll still log group level std
|
| 249 |
+
std_rewards = grouped_std_rewards.clone()
|
| 250 |
+
elif self.scale_rewards == "batch":
|
| 251 |
+
# Compute global std
|
| 252 |
+
std_rewards = rewards.std().expand_as(rewards)
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError(
|
| 255 |
+
f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
|
| 259 |
+
if self.scale_rewards != "none":
|
| 260 |
+
advantages = advantages / (std_rewards + 1e-4)
|
| 261 |
+
|
| 262 |
+
# Slice to keep only the local part of the data
|
| 263 |
+
process_slice = slice(
|
| 264 |
+
self.accelerator.process_index * len(prompts),
|
| 265 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 266 |
+
)
|
| 267 |
+
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
|
| 268 |
+
advantages = advantages[process_slice]
|
| 269 |
+
grouped_std_rewards = grouped_std_rewards[process_slice]
|
| 270 |
+
|
| 271 |
+
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
|
| 272 |
+
for i, reward_func_name in enumerate(self.reward_func_names):
|
| 273 |
+
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
|
| 274 |
+
self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
|
| 275 |
+
std_func_rewards = nanstd(rewards_per_func[:, i]).item()
|
| 276 |
+
self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
|
| 277 |
+
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
|
| 278 |
+
self._metrics[mode]["reward_std"].append(std_rewards.mean().item())
|
| 279 |
+
self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
|
| 280 |
+
|
| 281 |
+
# Log prompt and completion texts
|
| 282 |
+
self._logs["prompt"].extend(gather_object(prompts_text))
|
| 283 |
+
self._logs["completion"].extend(gather_object(completions_text))
|
| 284 |
+
for i, name in enumerate(self.reward_func_names):
|
| 285 |
+
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
|
| 286 |
+
self._logs["advantages"].extend(all_process_advantages.tolist())
|
| 287 |
+
|
| 288 |
+
if images is not None:
|
| 289 |
+
self._logs["images"].extend(gather_object(images))
|
| 290 |
+
|
| 291 |
+
if self.use_vllm and self.vllm_importance_sampling_correction:
|
| 292 |
+
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
|
| 293 |
+
mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool()
|
| 294 |
+
delta = delta[mask]
|
| 295 |
+
mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
|
| 296 |
+
max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
|
| 297 |
+
self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
|
| 298 |
+
self.accelerator.gather(mean_delta).mean().item()
|
| 299 |
+
)
|
| 300 |
+
self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
|
| 301 |
+
self.accelerator.gather(max_delta).max().item()
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
flat_is_ratio = importance_sampling_ratio[mask]
|
| 305 |
+
min_importance_sampling_ratio = (
|
| 306 |
+
torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
|
| 307 |
+
)
|
| 308 |
+
mean_importance_sampling_ratio = (
|
| 309 |
+
torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
|
| 310 |
+
)
|
| 311 |
+
max_importance_sampling_ratio = (
|
| 312 |
+
torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
|
| 313 |
+
)
|
| 314 |
+
self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
|
| 315 |
+
nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
|
| 316 |
+
)
|
| 317 |
+
self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
|
| 318 |
+
self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
|
| 319 |
+
)
|
| 320 |
+
self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
|
| 321 |
+
nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
|
| 322 |
+
)
|
| 323 |
+
outputs_after_sampling_buffer = self.update_with_replay_buffer(
|
| 324 |
+
advantages,
|
| 325 |
+
grouped_std_rewards,
|
| 326 |
+
prompt_ids,
|
| 327 |
+
prompt_mask,
|
| 328 |
+
completion_ids,
|
| 329 |
+
completion_mask,
|
| 330 |
+
forward_kwargs,
|
| 331 |
+
num_items_in_batch,
|
| 332 |
+
old_per_token_logps,
|
| 333 |
+
ref_per_token_logps,
|
| 334 |
+
importance_sampling_ratio if self.use_vllm and self.vllm_importance_sampling_correction else None,
|
| 335 |
+
)
|
| 336 |
+
if outputs_after_sampling_buffer is not None:
|
| 337 |
+
return outputs_after_sampling_buffer
|
| 338 |
+
else:
|
| 339 |
+
output = {
|
| 340 |
+
"prompt_ids": prompt_ids,
|
| 341 |
+
"prompt_mask": prompt_mask,
|
| 342 |
+
"completion_ids": completion_ids,
|
| 343 |
+
"completion_mask": completion_mask,
|
| 344 |
+
"advantages": advantages,
|
| 345 |
+
"num_items_in_batch": num_items_in_batch,
|
| 346 |
+
}
|
| 347 |
+
if old_per_token_logps is not None:
|
| 348 |
+
output["old_per_token_logps"] = old_per_token_logps
|
| 349 |
+
if self.use_vllm and self.vllm_importance_sampling_correction:
|
| 350 |
+
output["importance_sampling_ratio"] = importance_sampling_ratio
|
| 351 |
+
if ref_per_token_logps is not None:
|
| 352 |
+
output["ref_per_token_logps"] = ref_per_token_logps
|
| 353 |
+
if "pixel_values" in forward_kwargs:
|
| 354 |
+
output["pixel_values"] = forward_kwargs["pixel_values"]
|
| 355 |
+
if "image_grid_thw" in forward_kwargs:
|
| 356 |
+
output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
|
| 357 |
+
if "pixel_attention_mask" in forward_kwargs:
|
| 358 |
+
output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
|
| 359 |
+
if "image_sizes" in forward_kwargs:
|
| 360 |
+
output["image_sizes"] = forward_kwargs["image_sizes"]
|
| 361 |
+
if "token_type_ids" in forward_kwargs:
|
| 362 |
+
output["token_type_ids"] = forward_kwargs["token_type_ids"]
|
| 363 |
+
if images is not None:
|
| 364 |
+
output["num_images"] = num_images
|
| 365 |
+
if self.tools:
|
| 366 |
+
output["tool_mask"] = tool_mask
|
| 367 |
+
return output
|
| 368 |
+
|
| 369 |
+
def slice_group_data(
|
| 370 |
+
self, data: torch.Tensor, mask: torch.Tensor, group_idx: int
|
| 371 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 372 |
+
"""
|
| 373 |
+
Slices the input data and mask tensors for a specific group index. Also trims the sequence length to the
|
| 374 |
+
maximum length in the group based on the mask.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
data: Tensor of shape (num_groups * num_generations, seq_length)
|
| 378 |
+
mask: Tensor of shape (num_groups * num_generations, seq_length)
|
| 379 |
+
group_idx: Index of the group to slice
|
| 380 |
+
Returns:
|
| 381 |
+
Tuple of (sliced_data, sliced_mask) for the specified group, with sequence length trimmed to the maximum
|
| 382 |
+
length in the group.
|
| 383 |
+
"""
|
| 384 |
+
start_idx = group_idx * self.num_generations
|
| 385 |
+
end_idx = (group_idx + 1) * self.num_generations
|
| 386 |
+
group_data = data[start_idx:end_idx]
|
| 387 |
+
group_mask = mask[start_idx:end_idx]
|
| 388 |
+
group_max_len = group_mask.sum(dim=1).max().item()
|
| 389 |
+
return group_data[:, :group_max_len], group_mask[:, :group_max_len]
|
| 390 |
+
|
| 391 |
+
def update_replay_buffer(
|
| 392 |
+
self,
|
| 393 |
+
groups_with_variance: torch.Tensor,
|
| 394 |
+
group_advantages: torch.Tensor,
|
| 395 |
+
group_std_rewards: torch.Tensor,
|
| 396 |
+
prompt_ids: torch.Tensor,
|
| 397 |
+
prompt_mask: torch.Tensor,
|
| 398 |
+
completion_ids: torch.Tensor,
|
| 399 |
+
completion_mask: torch.Tensor,
|
| 400 |
+
forward_kwargs: dict,
|
| 401 |
+
optional_vision_fields: list[str] = None,
|
| 402 |
+
old_per_token_logps: torch.Tensor | None = None,
|
| 403 |
+
ref_per_token_logps: torch.Tensor | None = None,
|
| 404 |
+
importance_sampling_ratio: float | None = None,
|
| 405 |
+
) -> None:
|
| 406 |
+
"""
|
| 407 |
+
Update the replay buffer with groups that have reward variance (std > 0).
|
| 408 |
+
|
| 409 |
+
Args:
|
| 410 |
+
groups_with_variance: Boolean tensor indicating which groups have reward variance
|
| 411 |
+
group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values
|
| 412 |
+
std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group
|
| 413 |
+
prompt_ids: Tensor containing prompt token IDs
|
| 414 |
+
prompt_mask: Tensor containing prompt attention masks
|
| 415 |
+
completion_ids: Tensor containing completion token IDs
|
| 416 |
+
completion_mask: Tensor containing completion attention masks
|
| 417 |
+
forward_kwargs: Dictionary containing additional prompt inputs (vision data, etc.)
|
| 418 |
+
optional_vision_fields: List of optional vision-related fields to include if present in forward_kwargs
|
| 419 |
+
old_per_token_logps: Optional tensor of old per-token log probabilities
|
| 420 |
+
ref_per_token_logps: Optional tensor of reference per-token log probabilities
|
| 421 |
+
importance_sampling_ratio: Optional importance sampling correction ratio
|
| 422 |
+
"""
|
| 423 |
+
# Prepare buffered outputs for groups with variance
|
| 424 |
+
buffered_outputs = []
|
| 425 |
+
for _, group_idx in enumerate(groups_with_variance.nonzero(as_tuple=True)[0].unique().tolist()):
|
| 426 |
+
group_prompt_ids, group_prompt_mask = self.slice_group_data(prompt_ids, prompt_mask, group_idx)
|
| 427 |
+
group_completion_ids, group_completion_mask = self.slice_group_data(
|
| 428 |
+
completion_ids, completion_mask, group_idx
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
# Store unpadded data in the buffer
|
| 432 |
+
buffered_output = {
|
| 433 |
+
"prompt_ids": group_prompt_ids,
|
| 434 |
+
"completion_ids": group_completion_ids,
|
| 435 |
+
"advantages": group_advantages[group_idx].tolist(),
|
| 436 |
+
"prompt_mask": group_prompt_mask,
|
| 437 |
+
"completion_mask": group_completion_mask,
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
# Add optional fields if they exist
|
| 441 |
+
optional_fields = {
|
| 442 |
+
"old_per_token_logps": old_per_token_logps if old_per_token_logps is not None else None,
|
| 443 |
+
"ref_per_token_logps": ref_per_token_logps if ref_per_token_logps is not None else None,
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
for field_name, field_data in optional_fields.items():
|
| 447 |
+
if field_data is not None:
|
| 448 |
+
buffered_output[field_name] = self.slice_group_data(field_data, completion_mask, group_idx)[0]
|
| 449 |
+
|
| 450 |
+
# Add importance sampling if needed
|
| 451 |
+
if self.use_vllm and self.vllm_importance_sampling_correction:
|
| 452 |
+
buffered_output["importance_sampling_ratio"] = importance_sampling_ratio
|
| 453 |
+
|
| 454 |
+
if optional_vision_fields:
|
| 455 |
+
# Add vision-related fields if they exist
|
| 456 |
+
for field_name in optional_vision_fields:
|
| 457 |
+
if field_name in forward_kwargs:
|
| 458 |
+
buffered_output[field_name] = self.slice_group_data(
|
| 459 |
+
forward_kwargs[field_name], prompt_mask, group_idx
|
| 460 |
+
)[0]
|
| 461 |
+
|
| 462 |
+
buffered_outputs.append(buffered_output)
|
| 463 |
+
|
| 464 |
+
if groups_with_variance.any():
|
| 465 |
+
# Calculate replay buffer scores for groups with variance
|
| 466 |
+
replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
|
| 467 |
+
# Add all groups to replay buffer at once (batch operation)
|
| 468 |
+
self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs)
|
| 469 |
+
|
| 470 |
+
def sample_from_replay_buffer(
|
| 471 |
+
self, num_samples: int, optional_vision_fields: list[str] = None, optional_tensor_fields: list[str] = None
|
| 472 |
+
) -> list[dict]:
|
| 473 |
+
"""
|
| 474 |
+
Sample groups from the replay buffer.
|
| 475 |
+
|
| 476 |
+
Args:
|
| 477 |
+
num_samples: Number of samples to draw from the replay buffer
|
| 478 |
+
optional_vision_fields: List of optional vision-related fields to include if present in sampled data
|
| 479 |
+
optional_tensor_fields: List of optional tensor fields to include if present in sampled data
|
| 480 |
+
Returns:
|
| 481 |
+
List of sampled data dictionaries from the replay buffer
|
| 482 |
+
"""
|
| 483 |
+
sampled = self.replay_buffer.sample(num_samples=num_samples)
|
| 484 |
+
|
| 485 |
+
# Extract and concatenate sampled data
|
| 486 |
+
sampled_data = {
|
| 487 |
+
"prompt_ids": [],
|
| 488 |
+
"prompt_mask": [],
|
| 489 |
+
"completion_ids": [],
|
| 490 |
+
"completion_mask": [],
|
| 491 |
+
"advantages": [],
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
all_optional_fields = (optional_tensor_fields or []) + (optional_vision_fields or [])
|
| 495 |
+
# Initialize containers for optional fields if they exist in sampled data
|
| 496 |
+
for field in all_optional_fields:
|
| 497 |
+
if sampled and field in sampled[0]:
|
| 498 |
+
sampled_data[field] = []
|
| 499 |
+
|
| 500 |
+
# Extract data from each sampled item
|
| 501 |
+
for item in sampled:
|
| 502 |
+
# Handle core fields
|
| 503 |
+
for key in ["prompt_ids", "prompt_mask", "completion_ids", "completion_mask"]:
|
| 504 |
+
sampled_data[key].append(item[key])
|
| 505 |
+
|
| 506 |
+
# Handle advantages (list, not tensor)
|
| 507 |
+
sampled_data["advantages"].append(item["advantages"])
|
| 508 |
+
|
| 509 |
+
# Handle optional fields
|
| 510 |
+
for field in all_optional_fields:
|
| 511 |
+
if field in item:
|
| 512 |
+
sampled_data[field].append(item[field])
|
| 513 |
+
|
| 514 |
+
return sampled_data
|
| 515 |
+
|
| 516 |
+
def update_with_replay_buffer(
|
| 517 |
+
self,
|
| 518 |
+
group_advantages: torch.Tensor,
|
| 519 |
+
group_std_rewards: torch.Tensor,
|
| 520 |
+
prompt_ids: torch.Tensor,
|
| 521 |
+
prompt_mask: torch.Tensor,
|
| 522 |
+
completion_ids: torch.Tensor,
|
| 523 |
+
completion_mask: torch.Tensor,
|
| 524 |
+
forward_kwargs: dict,
|
| 525 |
+
num_items_in_batch: int,
|
| 526 |
+
old_per_token_logps: torch.Tensor | None = None,
|
| 527 |
+
ref_per_token_logps: torch.Tensor | None = None,
|
| 528 |
+
importance_sampling_ratio: float | None = None,
|
| 529 |
+
) -> None:
|
| 530 |
+
"""
|
| 531 |
+
Update current batch data with samples from replay buffer.
|
| 532 |
+
|
| 533 |
+
Groups with reward variance (std > 0) are added to the replay buffer and then replaced with samples from the
|
| 534 |
+
buffer to improve training stability.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values
|
| 538 |
+
std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group
|
| 539 |
+
prompt_ids: Tensor containing prompt token IDs
|
| 540 |
+
prompt_mask: Tensor containing prompt attention masks
|
| 541 |
+
completion_ids: Tensor containing completion token IDs
|
| 542 |
+
completion_mask: Tensor containing completion attention masks
|
| 543 |
+
forward_kwargs: Dictionary containing additional prompt inputs (vision data, etc.)
|
| 544 |
+
num_items_in_batch: Number of items in the current batch
|
| 545 |
+
old_per_token_logps: Optional tensor of old per-token log probabilities
|
| 546 |
+
ref_per_token_logps: Optional tensor of reference per-token log probabilities
|
| 547 |
+
importance_sampling_ratio: Optional importance sampling correction ratio
|
| 548 |
+
"""
|
| 549 |
+
if self.replay_buffer.max_size <= 0:
|
| 550 |
+
return
|
| 551 |
+
|
| 552 |
+
# Groups to consider for adding to the replay buffer
|
| 553 |
+
groups_with_variance = group_std_rewards.max(dim=0).values > 0
|
| 554 |
+
# Groups to replace from the replay buffer
|
| 555 |
+
groups_without_variance = ~groups_with_variance
|
| 556 |
+
|
| 557 |
+
# Track which optional fields are present in sampled data
|
| 558 |
+
optional_tensor_fields = ["old_per_token_logps", "ref_per_token_logps"]
|
| 559 |
+
vision_fields = ["pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes"]
|
| 560 |
+
|
| 561 |
+
self.update_replay_buffer(
|
| 562 |
+
groups_with_variance,
|
| 563 |
+
group_advantages,
|
| 564 |
+
group_std_rewards,
|
| 565 |
+
prompt_ids,
|
| 566 |
+
prompt_mask,
|
| 567 |
+
completion_ids,
|
| 568 |
+
completion_mask,
|
| 569 |
+
forward_kwargs,
|
| 570 |
+
vision_fields,
|
| 571 |
+
old_per_token_logps,
|
| 572 |
+
ref_per_token_logps,
|
| 573 |
+
importance_sampling_ratio,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Sample from replay buffer to replace groups with variance
|
| 577 |
+
num_groups_to_replace = groups_without_variance.sum().item()
|
| 578 |
+
if not num_groups_to_replace:
|
| 579 |
+
return
|
| 580 |
+
|
| 581 |
+
sampled_data = self.sample_from_replay_buffer(
|
| 582 |
+
num_samples=num_groups_to_replace,
|
| 583 |
+
optional_vision_fields=vision_fields,
|
| 584 |
+
optional_tensor_fields=optional_tensor_fields,
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
# Pad sampled data if they are shorter than the current batch sequences
|
| 588 |
+
# Or pad the current batch if sampled are longer
|
| 589 |
+
current_batch_prompt_seq_len = prompt_ids.size(1)
|
| 590 |
+
current_batch_completion_seq_len = completion_ids.size(1)
|
| 591 |
+
|
| 592 |
+
groups_to_replace_idxs = groups_with_variance.logical_not().nonzero(as_tuple=True)[0].unique().tolist()
|
| 593 |
+
|
| 594 |
+
# Determine target (max) sequence lengths once
|
| 595 |
+
sampled_prompt_lengths = [t.size(1) for t in sampled_data["prompt_ids"]]
|
| 596 |
+
sampled_completion_lengths = [t.size(1) for t in sampled_data["completion_ids"]]
|
| 597 |
+
target_prompt_len = max([current_batch_prompt_seq_len] + sampled_prompt_lengths)
|
| 598 |
+
target_completion_len = max([current_batch_completion_seq_len] + sampled_completion_lengths)
|
| 599 |
+
|
| 600 |
+
# If any sampled prompt is longer, pad the whole batch prompt tensors once (left padding)
|
| 601 |
+
if target_prompt_len > current_batch_prompt_seq_len:
|
| 602 |
+
prompt_ids = pad(
|
| 603 |
+
list(prompt_ids.unbind(0)),
|
| 604 |
+
padding_value=self.pad_token_id,
|
| 605 |
+
pad_to_multiple_of=target_prompt_len,
|
| 606 |
+
padding_side="left",
|
| 607 |
+
)
|
| 608 |
+
prompt_mask = pad(
|
| 609 |
+
list(prompt_mask.unbind(0)), padding_value=0, pad_to_multiple_of=target_prompt_len, padding_side="left"
|
| 610 |
+
)
|
| 611 |
+
# If any sampled completion is longer, pad the whole batch completion tensors once (right padding)
|
| 612 |
+
if target_completion_len > current_batch_completion_seq_len:
|
| 613 |
+
completion_ids = pad(
|
| 614 |
+
list(completion_ids.unbind(0)),
|
| 615 |
+
padding_value=self.pad_token_id,
|
| 616 |
+
pad_to_multiple_of=target_completion_len,
|
| 617 |
+
padding_side="right",
|
| 618 |
+
)
|
| 619 |
+
completion_mask = pad(
|
| 620 |
+
list(completion_mask.unbind(0)),
|
| 621 |
+
padding_value=0,
|
| 622 |
+
pad_to_multiple_of=target_completion_len,
|
| 623 |
+
padding_side="right",
|
| 624 |
+
)
|
| 625 |
+
if old_per_token_logps is not None:
|
| 626 |
+
old_per_token_logps = pad(
|
| 627 |
+
list(old_per_token_logps.unbind(0)),
|
| 628 |
+
padding_value=0.0,
|
| 629 |
+
pad_to_multiple_of=target_completion_len,
|
| 630 |
+
padding_side="right",
|
| 631 |
+
)
|
| 632 |
+
if ref_per_token_logps is not None:
|
| 633 |
+
ref_per_token_logps = pad(
|
| 634 |
+
list(ref_per_token_logps.unbind(0)),
|
| 635 |
+
padding_value=0.0,
|
| 636 |
+
pad_to_multiple_of=target_completion_len,
|
| 637 |
+
padding_side="right",
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
# Replace per-group data, padding only sampled groups that are shorter than the target
|
| 641 |
+
for i, group_idx in enumerate(groups_to_replace_idxs):
|
| 642 |
+
start_idx = group_idx * self.num_generations
|
| 643 |
+
end_idx = (group_idx + 1) * self.num_generations
|
| 644 |
+
idx_range = slice(start_idx, end_idx)
|
| 645 |
+
|
| 646 |
+
# Pad sampled prompt to target length if needed
|
| 647 |
+
if sampled_data["prompt_ids"][i].size(1) < target_prompt_len:
|
| 648 |
+
sampled_data["prompt_ids"][i] = pad(
|
| 649 |
+
sampled_data["prompt_ids"][i],
|
| 650 |
+
padding_value=self.pad_token_id,
|
| 651 |
+
pad_to_multiple_of=target_prompt_len,
|
| 652 |
+
padding_side="left",
|
| 653 |
+
)
|
| 654 |
+
sampled_data["prompt_mask"][i] = pad(
|
| 655 |
+
sampled_data["prompt_mask"][i],
|
| 656 |
+
padding_value=0,
|
| 657 |
+
pad_to_multiple_of=target_prompt_len,
|
| 658 |
+
padding_side="left",
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
# Pad sampled completion to target length if needed
|
| 662 |
+
if sampled_data["completion_ids"][i].size(1) < target_completion_len:
|
| 663 |
+
sampled_data["completion_ids"][i] = pad(
|
| 664 |
+
sampled_data["completion_ids"][i],
|
| 665 |
+
padding_value=self.pad_token_id,
|
| 666 |
+
pad_to_multiple_of=target_completion_len,
|
| 667 |
+
padding_side="right",
|
| 668 |
+
)
|
| 669 |
+
sampled_data["completion_mask"][i] = pad(
|
| 670 |
+
sampled_data["completion_mask"][i],
|
| 671 |
+
padding_value=0,
|
| 672 |
+
pad_to_multiple_of=target_completion_len,
|
| 673 |
+
padding_side="right",
|
| 674 |
+
)
|
| 675 |
+
if "old_per_token_logps" in sampled_data:
|
| 676 |
+
sampled_data["old_per_token_logps"][i] = pad(
|
| 677 |
+
sampled_data["old_per_token_logps"][i],
|
| 678 |
+
padding_value=0.0,
|
| 679 |
+
pad_to_multiple_of=target_completion_len,
|
| 680 |
+
padding_side="right",
|
| 681 |
+
)
|
| 682 |
+
if "ref_per_token_logps" in sampled_data:
|
| 683 |
+
sampled_data["ref_per_token_logps"][i] = pad(
|
| 684 |
+
sampled_data["ref_per_token_logps"][i],
|
| 685 |
+
padding_value=0.0,
|
| 686 |
+
pad_to_multiple_of=target_completion_len,
|
| 687 |
+
padding_side="right",
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
# Assign (replace) group slice
|
| 691 |
+
prompt_ids[idx_range] = sampled_data["prompt_ids"][i]
|
| 692 |
+
prompt_mask[idx_range] = sampled_data["prompt_mask"][i]
|
| 693 |
+
completion_ids[idx_range] = sampled_data["completion_ids"][i]
|
| 694 |
+
completion_mask[idx_range] = sampled_data["completion_mask"][i]
|
| 695 |
+
group_advantages[group_idx] = sampled_data["advantages"][i]
|
| 696 |
+
|
| 697 |
+
if "old_per_token_logps" in sampled_data:
|
| 698 |
+
old_per_token_logps[idx_range] = sampled_data["old_per_token_logps"][i]
|
| 699 |
+
if "ref_per_token_logps" in sampled_data:
|
| 700 |
+
ref_per_token_logps[idx_range] = sampled_data["ref_per_token_logps"][i]
|
| 701 |
+
|
| 702 |
+
for field in vision_fields:
|
| 703 |
+
if field in sampled_data and field in forward_kwargs:
|
| 704 |
+
forward_kwargs[field][idx_range] = sampled_data[field][i]
|
| 705 |
+
|
| 706 |
+
# Prepare final outputs after sampling and replacement
|
| 707 |
+
outputs_after_sampling_buffer = {
|
| 708 |
+
"prompt_ids": prompt_ids,
|
| 709 |
+
"prompt_mask": prompt_mask,
|
| 710 |
+
"completion_ids": completion_ids,
|
| 711 |
+
"completion_mask": completion_mask,
|
| 712 |
+
"advantages": group_advantages,
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
# Replace optional tensor fields if they exist
|
| 716 |
+
for field in optional_tensor_fields:
|
| 717 |
+
if field in sampled_data:
|
| 718 |
+
outputs_after_sampling_buffer[field] = (
|
| 719 |
+
old_per_token_logps if field == "old_per_token_logps" else ref_per_token_logps
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
# Replace vision fields if they exist
|
| 723 |
+
for field in vision_fields:
|
| 724 |
+
if field in sampled_data and field in forward_kwargs:
|
| 725 |
+
outputs_after_sampling_buffer[field] = forward_kwargs[field]
|
| 726 |
+
|
| 727 |
+
outputs_after_sampling_buffer["num_items_in_batch"] = num_items_in_batch
|
| 728 |
+
if self.use_vllm and self.vllm_importance_sampling_correction:
|
| 729 |
+
outputs_after_sampling_buffer["importance_sampling_ratio"] = importance_sampling_ratio
|
| 730 |
+
|
| 731 |
+
return outputs_after_sampling_buffer
|
ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .grpo_trainer import GRPOTrainer
|
ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
|
| 18 |
+
from ...trainer.utils import nanmax, nanmin
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class GRPOTrainer(_GRPOTrainer):
|
| 22 |
+
def _compute_loss(self, model, inputs):
|
| 23 |
+
# Compute the per-token log probabilities for the model
|
| 24 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 25 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 26 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 27 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 28 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 29 |
+
|
| 30 |
+
# Compute the per_token_logps and the entropy at each position in the completion
|
| 31 |
+
per_token_logps, entropies = self._get_per_token_logps_and_entropies(
|
| 32 |
+
model,
|
| 33 |
+
input_ids,
|
| 34 |
+
attention_mask,
|
| 35 |
+
logits_to_keep,
|
| 36 |
+
compute_entropy=True,
|
| 37 |
+
pixel_values=inputs.get("pixel_values"),
|
| 38 |
+
image_grid_thw=inputs.get("image_grid_thw"),
|
| 39 |
+
num_images=inputs.get("num_images"),
|
| 40 |
+
pixel_attention_mask=inputs.get("pixel_attention_mask"),
|
| 41 |
+
image_sizes=inputs.get("image_sizes"),
|
| 42 |
+
token_type_ids=inputs.get("token_type_ids"),
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if self.top_entropy_quantile < 1.0:
|
| 46 |
+
entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
|
| 47 |
+
else:
|
| 48 |
+
entropy_mask = None
|
| 49 |
+
|
| 50 |
+
# Compute the KL divergence between the model and the reference model
|
| 51 |
+
if self.beta != 0.0:
|
| 52 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
| 53 |
+
per_token_kl = (
|
| 54 |
+
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Compute the loss
|
| 58 |
+
advantages = inputs["advantages"]
|
| 59 |
+
# When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
|
| 60 |
+
# old_per_token_logps == per_token_logps. In this case we can skip its computation
|
| 61 |
+
# (see _generate_and_score_completions) and instead use per_token_logps.detach().
|
| 62 |
+
# The exception is when using vLLM, where we always compute old_per_token_logps
|
| 63 |
+
# for importance sampling
|
| 64 |
+
old_per_token_logps = inputs.get("old_per_token_logps")
|
| 65 |
+
old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
|
| 66 |
+
|
| 67 |
+
log_ratio = per_token_logps - old_per_token_logps
|
| 68 |
+
if self.importance_sampling_level == "token":
|
| 69 |
+
log_importance_weights = log_ratio
|
| 70 |
+
elif self.importance_sampling_level == "sequence":
|
| 71 |
+
log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
|
| 72 |
+
log_importance_weights = log_importance_weights.unsqueeze(-1)
|
| 73 |
+
elif self.importance_sampling_level == "sequence_token":
|
| 74 |
+
# GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
|
| 75 |
+
seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
|
| 76 |
+
seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient
|
| 77 |
+
log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
|
| 78 |
+
else:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
|
| 81 |
+
"and 'sequence'."
|
| 82 |
+
)
|
| 83 |
+
# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
|
| 84 |
+
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
|
| 85 |
+
|
| 86 |
+
coef_1 = torch.exp(log_importance_weights)
|
| 87 |
+
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
|
| 88 |
+
|
| 89 |
+
# Two-sided clipping
|
| 90 |
+
if self.args.delta is not None:
|
| 91 |
+
coef_1 = torch.clamp(coef_1, max=self.args.delta)
|
| 92 |
+
|
| 93 |
+
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
| 94 |
+
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
| 95 |
+
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
| 96 |
+
if entropy_mask is not None:
|
| 97 |
+
per_token_loss = per_token_loss * entropy_mask
|
| 98 |
+
|
| 99 |
+
if self.use_vllm and self.vllm_importance_sampling_correction:
|
| 100 |
+
per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
|
| 101 |
+
|
| 102 |
+
if self.beta != 0.0:
|
| 103 |
+
per_token_loss = per_token_loss + self.beta * per_token_kl
|
| 104 |
+
|
| 105 |
+
mode = "train" if self.model.training else "eval"
|
| 106 |
+
if self.loss_type == "grpo":
|
| 107 |
+
loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
|
| 108 |
+
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
|
| 109 |
+
loss = loss / normalizer
|
| 110 |
+
elif self.loss_type == "bnpo":
|
| 111 |
+
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
|
| 112 |
+
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
|
| 113 |
+
loss = loss / normalizer
|
| 114 |
+
elif self.loss_type == "dr_grpo":
|
| 115 |
+
loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
|
| 116 |
+
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
|
| 117 |
+
loss = loss / normalizer
|
| 118 |
+
elif self.loss_type == "dapo":
|
| 119 |
+
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
|
| 120 |
+
loss = (per_token_loss * completion_mask).sum() / normalizer
|
| 121 |
+
else:
|
| 122 |
+
raise ValueError(f"Unknown loss type: {self.loss_type}")
|
| 123 |
+
|
| 124 |
+
# Log the metrics
|
| 125 |
+
completion_token_count = completion_mask.sum().clamp(min=1.0)
|
| 126 |
+
|
| 127 |
+
def masked_batch_mean(x):
|
| 128 |
+
if x.shape[1] == 1: # when importance_sampling_level == "sequence"
|
| 129 |
+
return x.mean()
|
| 130 |
+
else:
|
| 131 |
+
return (x * completion_mask).sum() / completion_token_count
|
| 132 |
+
|
| 133 |
+
if self.beta != 0.0:
|
| 134 |
+
mean_kl = masked_batch_mean(per_token_kl)
|
| 135 |
+
self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
|
| 136 |
+
|
| 137 |
+
mean_entropy = masked_batch_mean(entropies)
|
| 138 |
+
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
|
| 139 |
+
|
| 140 |
+
# Compute the clipped probability ratios
|
| 141 |
+
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
|
| 142 |
+
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
|
| 143 |
+
is_region_clipped = is_low_clipped | is_high_clipped
|
| 144 |
+
|
| 145 |
+
low_clip = masked_batch_mean(is_low_clipped.float())
|
| 146 |
+
high_clip = masked_batch_mean(is_high_clipped.float())
|
| 147 |
+
clip_ratio = masked_batch_mean(is_region_clipped.float())
|
| 148 |
+
|
| 149 |
+
gathered_low_clip = self.accelerator.gather(low_clip)
|
| 150 |
+
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
|
| 151 |
+
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
|
| 152 |
+
gathered_high_clip = self.accelerator.gather(high_clip)
|
| 153 |
+
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
|
| 154 |
+
self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
|
| 155 |
+
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
|
| 156 |
+
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
|
| 157 |
+
return loss
|
ICL/RL/trl_source/trl/experimental/judges/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .judges import (
|
| 16 |
+
AllTrueJudge,
|
| 17 |
+
BaseBinaryJudge,
|
| 18 |
+
BaseJudge,
|
| 19 |
+
BasePairwiseJudge,
|
| 20 |
+
BaseRankJudge,
|
| 21 |
+
HfPairwiseJudge,
|
| 22 |
+
OpenAIPairwiseJudge,
|
| 23 |
+
PairRMJudge,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
"AllTrueJudge",
|
| 29 |
+
"BaseBinaryJudge",
|
| 30 |
+
"BaseJudge",
|
| 31 |
+
"BasePairwiseJudge",
|
| 32 |
+
"BaseRankJudge",
|
| 33 |
+
"HfPairwiseJudge",
|
| 34 |
+
"OpenAIPairwiseJudge",
|
| 35 |
+
"PairRMJudge",
|
| 36 |
+
]
|
ICL/RL/trl_source/trl/experimental/judges/judges.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import concurrent.futures
|
| 16 |
+
import logging
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
from accelerate import Accelerator
|
| 21 |
+
from huggingface_hub import InferenceClient
|
| 22 |
+
from packaging.version import Version
|
| 23 |
+
from transformers.utils import is_openai_available
|
| 24 |
+
|
| 25 |
+
from ...import_utils import is_llm_blender_available
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
DEFAULT_PAIRWISE_SYSTEM_PROMPT = '''I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.
|
| 29 |
+
|
| 30 |
+
## Instruction
|
| 31 |
+
|
| 32 |
+
{{
|
| 33 |
+
"instruction": """{prompt}""",
|
| 34 |
+
}}
|
| 35 |
+
|
| 36 |
+
## Model Outputs
|
| 37 |
+
|
| 38 |
+
Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.
|
| 39 |
+
|
| 40 |
+
{{
|
| 41 |
+
{{
|
| 42 |
+
"model_identifier": "0",
|
| 43 |
+
"output": """{response0}"""
|
| 44 |
+
}},
|
| 45 |
+
{{
|
| 46 |
+
"model_identifier": "1",
|
| 47 |
+
"output": """{response1}"""
|
| 48 |
+
}}
|
| 49 |
+
}}
|
| 50 |
+
|
| 51 |
+
## Task
|
| 52 |
+
|
| 53 |
+
Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).
|
| 54 |
+
'''
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _ensure_llm_blender_importable() -> None:
|
| 58 |
+
"""
|
| 59 |
+
Pre-import shim to work around a known `llm-blender` issue.
|
| 60 |
+
|
| 61 |
+
As of `llm-blender` v0.0.2 (see upstream issue: https://github.com/yuchenlin/LLM-Blender/issues/33), importing
|
| 62 |
+
`llm_blender` may fail on `transformers` >= 5.0.0 because it unconditionally accesses
|
| 63 |
+
`transformers.utils.hub.TRANSFORMERS_CACHE`.
|
| 64 |
+
|
| 65 |
+
We set this attribute to a dummy value before importing `llm_blender` so that the import succeeds. This helper is
|
| 66 |
+
intentionally a no-op on older `transformers` versions.
|
| 67 |
+
|
| 68 |
+
This shim can be removed once the upstream issue is fixed and the minimum required `llm-blender` version includes
|
| 69 |
+
that fix.
|
| 70 |
+
"""
|
| 71 |
+
import transformers.utils.hub
|
| 72 |
+
|
| 73 |
+
if Version(transformers.__version__) >= Version("5.0.0"):
|
| 74 |
+
transformers.utils.hub.TRANSFORMERS_CACHE = None # unused; just needs to exist
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class BaseJudge(ABC):
|
| 78 |
+
"""
|
| 79 |
+
Base class for judges. The subclasses of this class should implement the `judge` method.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
@abstractmethod
|
| 83 |
+
def judge(self, prompts: list[str], completions: list[str], shuffle_order: bool = True) -> list:
|
| 84 |
+
raise NotImplementedError("Judge subclasses must implement the `judge` method.")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class BaseRankJudge(ABC):
|
| 88 |
+
"""
|
| 89 |
+
Base class for LLM ranking judges.
|
| 90 |
+
|
| 91 |
+
**Example**:
|
| 92 |
+
```python
|
| 93 |
+
class MyRankJudge(BaseRankJudge):
|
| 94 |
+
def judge(self, prompts, completions, shuffle_order=True):
|
| 95 |
+
return ... # Your ranking logic here
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
judge = MyRankJudge()
|
| 99 |
+
judge.judge(
|
| 100 |
+
prompts=["The capital of France is", "The capital of Germany is"],
|
| 101 |
+
completions=[[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]],
|
| 102 |
+
) # [[0, 1, 2], [1, 0]]
|
| 103 |
+
```
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
@abstractmethod
|
| 107 |
+
def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[list[int]]:
|
| 108 |
+
"""
|
| 109 |
+
Judge the completion for the given prompts and return the ranks of each completion.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
prompts (`list[str]`):
|
| 113 |
+
List of prompts.
|
| 114 |
+
completions (`list[list[str]]`):
|
| 115 |
+
List of completions list, where each element is a list of completions for the corresponding prompt.
|
| 116 |
+
shuffle_order (`bool`, *optional*, defaults to `True`):
|
| 117 |
+
Whether to shuffle the order of the completions to avoid positional bias.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
`list[list[int]]`:
|
| 121 |
+
List of lists of idxs, where each list contains the ranks of the completions for the corresponding
|
| 122 |
+
prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the
|
| 123 |
+
third, and then the first.
|
| 124 |
+
"""
|
| 125 |
+
raise NotImplementedError("Judge subclasses must implement the `judge` method.")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class BasePairwiseJudge(BaseJudge):
|
| 129 |
+
"""
|
| 130 |
+
Base class for pairwise judges.
|
| 131 |
+
"""
|
| 132 |
+
|
| 133 |
+
@abstractmethod
|
| 134 |
+
def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
|
| 135 |
+
"""
|
| 136 |
+
Judge the completion pairs for the given prompts.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
prompts (`list[str]`):
|
| 140 |
+
List of prompts.
|
| 141 |
+
completions (`list[list[str]]`):
|
| 142 |
+
List of completions pairs, where each element is a pair of completions for the corresponding prompt.
|
| 143 |
+
shuffle_order (`bool`, *optional*, defaults to `True`):
|
| 144 |
+
Whether to shuffle the order of the completions to avoid positional bias.
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
`list[int]`:
|
| 148 |
+
List of idxs, where each idx is the rank of the best completion for the corresponding prompt. E.g., `1`
|
| 149 |
+
means that the second completion (`idx=1`) is the best.
|
| 150 |
+
|
| 151 |
+
Note:
|
| 152 |
+
If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the
|
| 153 |
+
preference has failed. For instance, this could occur if the underlying language model returned an invalid
|
| 154 |
+
answer. In such cases, the caller should handle these invalid indices appropriately, possibly by
|
| 155 |
+
implementing fallback logic or error handling.
|
| 156 |
+
"""
|
| 157 |
+
raise NotImplementedError("Judge subclasses must implement the `judge` method.")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class BaseBinaryJudge(BaseJudge):
|
| 161 |
+
"""
|
| 162 |
+
Base class for binary judges.
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
@abstractmethod
|
| 166 |
+
def judge(
|
| 167 |
+
self,
|
| 168 |
+
prompts: list[str],
|
| 169 |
+
completions: list[str],
|
| 170 |
+
gold_completions: list[str] | None = None,
|
| 171 |
+
shuffle_order: bool = True,
|
| 172 |
+
) -> list[int]:
|
| 173 |
+
"""
|
| 174 |
+
Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.
|
| 175 |
+
|
| 176 |
+
This base class should be used to implement binary evaluations as done in section 4.1.4 of the [CGPO
|
| 177 |
+
paper](https://huggingface.co/papers/2409.20370). It is relevant for assessing whether a prompt-completion pair
|
| 178 |
+
satisfies a specific constraint.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
prompts (`list[str]`): List of prompts.
|
| 182 |
+
completions (`list[str]`): List of completions.
|
| 183 |
+
gold_completions (`list[str]`, `optional`): List of gold completions if it exists.
|
| 184 |
+
shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
list[int]: A list of binary labels:
|
| 188 |
+
- 1 indicates that the completion satisfies the evaluated constraint.
|
| 189 |
+
- 0 indicates that the completion does not satisfy the evaluated constraint.
|
| 190 |
+
|
| 191 |
+
Note:
|
| 192 |
+
If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference
|
| 193 |
+
has failed. For instance, this could occur if the underlying language model or rule based constraint
|
| 194 |
+
returned an invalid answer. In such cases, the caller should handle these invalid indices appropriately,
|
| 195 |
+
possibly by implementing fallback logic or error handling.
|
| 196 |
+
"""
|
| 197 |
+
raise NotImplementedError("Judge subclasses must implement the `judge` method.")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class PairRMJudge(BasePairwiseJudge):
|
| 201 |
+
# docstyle-ignore
|
| 202 |
+
"""
|
| 203 |
+
LLM judge based on the PairRM model from AllenAI.
|
| 204 |
+
|
| 205 |
+
This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise
|
| 206 |
+
comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the
|
| 207 |
+
default Accelerator device.
|
| 208 |
+
|
| 209 |
+
**Attributes**:
|
| 210 |
+
|
| 211 |
+
blender (`llm_blender.Blender`):
|
| 212 |
+
An instance of the Blender class from llm-blender.
|
| 213 |
+
|
| 214 |
+
**Example**:
|
| 215 |
+
```python
|
| 216 |
+
>>> pairrm_judge = PairRMJudge()
|
| 217 |
+
>>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
|
| 218 |
+
>>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
|
| 219 |
+
>>> results = pairrm_judge.judge(prompts, completions)
|
| 220 |
+
>>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second)
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
> [!TIP]
|
| 224 |
+
> This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`.
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
def __init__(self):
|
| 228 |
+
if not is_llm_blender_available():
|
| 229 |
+
raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.")
|
| 230 |
+
import transformers
|
| 231 |
+
|
| 232 |
+
if Version(transformers.__version__) >= Version("5.0.0"):
|
| 233 |
+
raise RuntimeError(
|
| 234 |
+
"llm-blender currently supports transformers < 5.0.0. Please install a compatible version: `pip install 'transformers<5.0.0'`. Check the issue tracker for updates: https://github.com/huggingface/trl/issues/4918"
|
| 235 |
+
)
|
| 236 |
+
_ensure_llm_blender_importable()
|
| 237 |
+
import llm_blender
|
| 238 |
+
|
| 239 |
+
self.blender = llm_blender.Blender()
|
| 240 |
+
self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device)
|
| 241 |
+
|
| 242 |
+
def judge(
|
| 243 |
+
self,
|
| 244 |
+
prompts: list[str],
|
| 245 |
+
completions: list[list[str]],
|
| 246 |
+
shuffle_order: bool = True,
|
| 247 |
+
return_scores: bool = False,
|
| 248 |
+
temperature: float = 1.0,
|
| 249 |
+
) -> list[int | float]:
|
| 250 |
+
"""
|
| 251 |
+
Judge the completion pairs for the given prompts using the PairRM model.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
prompts (`list[str]`):
|
| 255 |
+
List of prompts to judge.
|
| 256 |
+
completions (`list[list[str]]`):
|
| 257 |
+
List of completion pairs for each prompt.
|
| 258 |
+
shuffle_order (`bool`, *optional*, defaults to `True`):
|
| 259 |
+
Whether to shuffle the order of the completions to avoid positional bias.
|
| 260 |
+
return_scores (`bool`, *optional*, defaults to `False`):
|
| 261 |
+
If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*).
|
| 262 |
+
temperature (`float`, *optional*, defaults to `1.0`):
|
| 263 |
+
Temperature for scaling logits if `return_scores` is True.
|
| 264 |
+
|
| 265 |
+
Returns:
|
| 266 |
+
`list[int | float]`:
|
| 267 |
+
If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which
|
| 268 |
+
completion is preferred. If `return_scores` is `True`, returns softmax probabilities for the first
|
| 269 |
+
completion.
|
| 270 |
+
|
| 271 |
+
Raises:
|
| 272 |
+
`ValueError`:
|
| 273 |
+
If the number of completions per prompt is not exactly 2.
|
| 274 |
+
|
| 275 |
+
Note:
|
| 276 |
+
Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred).
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
if len(completions[0]) != 2:
|
| 280 |
+
raise ValueError("PairRM judge requires exactly 2 completions per prompt.")
|
| 281 |
+
|
| 282 |
+
# Shuffle the order of the completions to avoid positional bias
|
| 283 |
+
if shuffle_order:
|
| 284 |
+
flip_mask = np.random.choice([True, False], size=len(prompts))
|
| 285 |
+
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
|
| 286 |
+
|
| 287 |
+
# Rank the completions
|
| 288 |
+
ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True)
|
| 289 |
+
if not return_scores:
|
| 290 |
+
ranks -= 1 # PairRM rank is 1-indexed, so we subtract 1 to make it 0-indexed
|
| 291 |
+
else:
|
| 292 |
+
# scale the logits by temperature
|
| 293 |
+
ranks /= temperature
|
| 294 |
+
|
| 295 |
+
# Flip back the ranks or scores to the original order if needed
|
| 296 |
+
if shuffle_order:
|
| 297 |
+
ranks[flip_mask] = ranks[flip_mask][:, ::-1]
|
| 298 |
+
|
| 299 |
+
# Return the ranks or score probability
|
| 300 |
+
if return_scores:
|
| 301 |
+
logit_max = np.amax(ranks, axis=-1, keepdims=True)
|
| 302 |
+
exp_logit_shifted = np.exp(ranks - logit_max)
|
| 303 |
+
probs = exp_logit_shifted / np.sum(exp_logit_shifted, axis=-1, keepdims=True)
|
| 304 |
+
return probs[:, 0].tolist()
|
| 305 |
+
else:
|
| 306 |
+
return ranks[:, 0].tolist()
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class HfPairwiseJudge(BasePairwiseJudge):
|
| 310 |
+
"""
|
| 311 |
+
Pairwise judge based on the Hugging Face API with chat completion.
|
| 312 |
+
|
| 313 |
+
This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`):
|
| 317 |
+
Model to use for the judge.
|
| 318 |
+
token (`str`, *optional*):
|
| 319 |
+
Hugging Face API token to use for the [`huggingface_hub.InferenceClient`].
|
| 320 |
+
system_prompt (`str`, *optional*):
|
| 321 |
+
The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
|
| 322 |
+
prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
|
| 323 |
+
inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
|
| 324 |
+
response.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
def __init__(
|
| 328 |
+
self,
|
| 329 |
+
model="meta-llama/Meta-Llama-3-70B-Instruct",
|
| 330 |
+
token: str | None = None,
|
| 331 |
+
system_prompt: str | None = None,
|
| 332 |
+
):
|
| 333 |
+
self.client = InferenceClient(model=model, token=token)
|
| 334 |
+
self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
|
| 335 |
+
|
| 336 |
+
def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
|
| 337 |
+
# Shuffle the order of the completions to avoid positional bias
|
| 338 |
+
if shuffle_order:
|
| 339 |
+
flip_mask = np.random.choice([True, False], size=len(prompts))
|
| 340 |
+
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
|
| 341 |
+
|
| 342 |
+
# Define a function to get the rank for a single prompt, will be called concurrently
|
| 343 |
+
def get_rank(prompt, candidates):
|
| 344 |
+
content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1])
|
| 345 |
+
completion = self.client.chat_completion(messages=[{"role": "user", "content": content}], max_tokens=1)
|
| 346 |
+
response = completion.choices[0].message.content
|
| 347 |
+
if response in ["0", "1"]:
|
| 348 |
+
return int(response)
|
| 349 |
+
else:
|
| 350 |
+
logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.")
|
| 351 |
+
return -1
|
| 352 |
+
|
| 353 |
+
# Call the completions concurrently
|
| 354 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 355 |
+
ranks = list(executor.map(get_rank, prompts, completions))
|
| 356 |
+
|
| 357 |
+
# Flip back the ranks to the original order if needed
|
| 358 |
+
if shuffle_order:
|
| 359 |
+
ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)]
|
| 360 |
+
|
| 361 |
+
# Return the ranks
|
| 362 |
+
return ranks
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class OpenAIPairwiseJudge(BasePairwiseJudge):
|
| 366 |
+
"""
|
| 367 |
+
Judge based on the OpenAI API.
|
| 368 |
+
|
| 369 |
+
This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`):
|
| 373 |
+
Model to use for the judge.
|
| 374 |
+
system_prompt (`str`, *optional*):
|
| 375 |
+
System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
|
| 376 |
+
prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
|
| 377 |
+
inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
|
| 378 |
+
response.
|
| 379 |
+
max_requests (`int` or `None`, *optional*, defaults to `1000`):
|
| 380 |
+
Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit.
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
def __init__(
|
| 384 |
+
self, model="gpt-4-turbo-preview", system_prompt: str | None = None, max_requests: int | None = 1_000
|
| 385 |
+
):
|
| 386 |
+
if not is_openai_available():
|
| 387 |
+
raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.")
|
| 388 |
+
from openai import OpenAI
|
| 389 |
+
|
| 390 |
+
self.client = OpenAI()
|
| 391 |
+
self.model = model
|
| 392 |
+
self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
|
| 393 |
+
self.max_requests = max_requests
|
| 394 |
+
self.num_requests = 0
|
| 395 |
+
self._warned = False
|
| 396 |
+
|
| 397 |
+
def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
|
| 398 |
+
# Check if the limit of requests is reached, if so, use random choice instead
|
| 399 |
+
if self.max_requests is not None and self.num_requests >= self.max_requests:
|
| 400 |
+
if not self._warned: # Print the warning only once
|
| 401 |
+
logging.warning(
|
| 402 |
+
f"Reached the maximum number of requests ({self.max_requests}). From now on, returning -1 instead. "
|
| 403 |
+
" To increase the limit, set `max_requests` to a higher value, or to `None` for no limit."
|
| 404 |
+
)
|
| 405 |
+
self._warned = True
|
| 406 |
+
return [-1] * len(prompts)
|
| 407 |
+
|
| 408 |
+
# Shuffle the order of the completions to avoid positional bias
|
| 409 |
+
if shuffle_order:
|
| 410 |
+
flip_mask = np.random.choice([True, False], size=len(prompts))
|
| 411 |
+
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
|
| 412 |
+
|
| 413 |
+
# Define a function to get the rank for a single prompt, will be called concurrently
|
| 414 |
+
def get_rank(prompt, candidates):
|
| 415 |
+
content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1])
|
| 416 |
+
messages = [{"role": "user", "content": content}]
|
| 417 |
+
completion = self.client.chat.completions.create(model=self.model, messages=messages, max_tokens=1)
|
| 418 |
+
response = completion.choices[0].message.content
|
| 419 |
+
if response in ["0", "1"]:
|
| 420 |
+
return int(response)
|
| 421 |
+
else:
|
| 422 |
+
logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.")
|
| 423 |
+
return -1
|
| 424 |
+
|
| 425 |
+
# Call the completions concurrently
|
| 426 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 427 |
+
ranks = list(executor.map(get_rank, prompts, completions))
|
| 428 |
+
|
| 429 |
+
# Flip back the ranks to the original order if needed
|
| 430 |
+
if shuffle_order:
|
| 431 |
+
ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)]
|
| 432 |
+
|
| 433 |
+
# Update the number of requests
|
| 434 |
+
self.num_requests += len(prompts)
|
| 435 |
+
|
| 436 |
+
# Return the ranks
|
| 437 |
+
return ranks
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class AllTrueJudge(BaseBinaryJudge):
|
| 441 |
+
"""
|
| 442 |
+
Unify the decision of multiple [`experimental.judges.BaseBinaryJudge`] instances.
|
| 443 |
+
|
| 444 |
+
Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`. If any judge
|
| 445 |
+
returns `-1`, indicating a failure in its process, this judge will also return `-1`.
|
| 446 |
+
|
| 447 |
+
Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370).
|
| 448 |
+
|
| 449 |
+
Args:
|
| 450 |
+
judges (`list` of [`experimental.judges.BaseBinaryJudge`]):
|
| 451 |
+
A list of [`experimental.judges.BaseBinaryJudge`] instances whose decisions will be unified.
|
| 452 |
+
"""
|
| 453 |
+
|
| 454 |
+
def __init__(self, judges: list[BaseBinaryJudge]):
|
| 455 |
+
self.judges = judges
|
| 456 |
+
|
| 457 |
+
def judge(
|
| 458 |
+
self,
|
| 459 |
+
prompts: list[str],
|
| 460 |
+
completions: list[str],
|
| 461 |
+
gold_completions: list[str] | None = None,
|
| 462 |
+
shuffle_order: bool = True,
|
| 463 |
+
) -> list[int]:
|
| 464 |
+
all_binary_judgments = [
|
| 465 |
+
judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges
|
| 466 |
+
]
|
| 467 |
+
output = []
|
| 468 |
+
for binary_judgments in zip(*all_binary_judgments, strict=True):
|
| 469 |
+
# Check that all values are in {0, 1, -1}
|
| 470 |
+
if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments):
|
| 471 |
+
raise ValueError(
|
| 472 |
+
f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}."
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
# Unify the decision
|
| 476 |
+
if -1 in binary_judgments:
|
| 477 |
+
output.append(-1)
|
| 478 |
+
elif all(binary_judgment == 1 for binary_judgment in binary_judgments):
|
| 479 |
+
output.append(1)
|
| 480 |
+
else:
|
| 481 |
+
output.append(0)
|
| 482 |
+
return output
|
ICL/RL/trl_source/trl/experimental/kto/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .kto_config import KTOConfig
|
| 16 |
+
from .kto_trainer import KTOTrainer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["KTOConfig", "KTOTrainer"]
|
ICL/RL/trl_source/trl/experimental/kto/kto_config.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from transformers import TrainingArguments
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class KTOConfig(TrainingArguments):
|
| 23 |
+
r"""
|
| 24 |
+
Configuration class for the [`experimental.kto.KTOTrainer`].
|
| 25 |
+
|
| 26 |
+
This class includes only the parameters that are specific to KTO training. For a full list of training arguments,
|
| 27 |
+
please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
|
| 28 |
+
differ from those in [`~transformers.TrainingArguments`].
|
| 29 |
+
|
| 30 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 31 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 32 |
+
command line.
|
| 33 |
+
|
| 34 |
+
Parameters:
|
| 35 |
+
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
| 36 |
+
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
|
| 37 |
+
to use the default data collator.
|
| 38 |
+
beta (`float`, *optional*, defaults to `0.1`):
|
| 39 |
+
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
|
| 40 |
+
reference model.
|
| 41 |
+
loss_type (`str`, *optional*, defaults to `"kto"`):
|
| 42 |
+
Type of loss to use. Possible values are:
|
| 43 |
+
|
| 44 |
+
- `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
|
| 45 |
+
- `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the
|
| 46 |
+
[APO](https://huggingface.co/papers/2408.06266) paper.
|
| 47 |
+
|
| 48 |
+
desirable_weight (`float`, *optional*, defaults to `1.0`):
|
| 49 |
+
Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
|
| 50 |
+
undesirable_weight (`float`, *optional*, defaults to `1.0`):
|
| 51 |
+
Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
|
| 52 |
+
generate_during_eval (`bool`, *optional*, defaults to `False`):
|
| 53 |
+
If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
|
| 54 |
+
during evaluation.
|
| 55 |
+
precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
|
| 56 |
+
Whether to precompute reference model log probabilities for training and evaluation datasets. This is
|
| 57 |
+
useful when training without the reference model to reduce the total GPU memory needed.
|
| 58 |
+
model_init_kwargs (`dict[str, Any]`, *optional*):
|
| 59 |
+
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
|
| 60 |
+
string.
|
| 61 |
+
dataset_num_proc: (`int`, *optional*):
|
| 62 |
+
Number of processes to use for processing the dataset.
|
| 63 |
+
disable_dropout (`bool`, *optional*, defaults to `True`):
|
| 64 |
+
Whether to disable dropout in the model and reference model.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
|
| 68 |
+
|
| 69 |
+
# Parameters whose default values are overridden from TrainingArguments
|
| 70 |
+
learning_rate: float = field(
|
| 71 |
+
default=1e-6,
|
| 72 |
+
metadata={"help": "The initial learning rate for AdamW."},
|
| 73 |
+
)
|
| 74 |
+
logging_steps: float = field(
|
| 75 |
+
default=10,
|
| 76 |
+
metadata={
|
| 77 |
+
"help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
|
| 78 |
+
"will be interpreted as ratio of total training steps."
|
| 79 |
+
},
|
| 80 |
+
)
|
| 81 |
+
gradient_checkpointing: bool = field(
|
| 82 |
+
default=True,
|
| 83 |
+
metadata={
|
| 84 |
+
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
| 85 |
+
},
|
| 86 |
+
)
|
| 87 |
+
bf16: bool | None = field(
|
| 88 |
+
default=None,
|
| 89 |
+
metadata={
|
| 90 |
+
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
| 91 |
+
"architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
|
| 92 |
+
"`fp16` is not set."
|
| 93 |
+
},
|
| 94 |
+
)
|
| 95 |
+
# Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue
|
| 96 |
+
# was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary
|
| 97 |
+
# workaround here, which can be removed once we drop support for versions older than 4.57.5.
|
| 98 |
+
lr_scheduler_kwargs: dict | str | None = field(
|
| 99 |
+
default=None,
|
| 100 |
+
metadata={
|
| 101 |
+
"help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard "
|
| 102 |
+
"restarts."
|
| 103 |
+
},
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
max_length: int | None = field(
|
| 107 |
+
default=1024,
|
| 108 |
+
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
|
| 109 |
+
)
|
| 110 |
+
beta: float = field(
|
| 111 |
+
default=0.1,
|
| 112 |
+
metadata={
|
| 113 |
+
"help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
|
| 114 |
+
"the reference model."
|
| 115 |
+
},
|
| 116 |
+
)
|
| 117 |
+
loss_type: str = field(
|
| 118 |
+
default="kto",
|
| 119 |
+
metadata={
|
| 120 |
+
"help": "Type of loss to use.",
|
| 121 |
+
"choices": ["kto", "apo_zero_unpaired"],
|
| 122 |
+
},
|
| 123 |
+
)
|
| 124 |
+
desirable_weight: float = field(
|
| 125 |
+
default=1.0,
|
| 126 |
+
metadata={
|
| 127 |
+
"help": "Desirable losses are weighed by this factor to counter unequal number of desirable and "
|
| 128 |
+
"undesirable pairs.",
|
| 129 |
+
},
|
| 130 |
+
)
|
| 131 |
+
undesirable_weight: float = field(
|
| 132 |
+
default=1.0,
|
| 133 |
+
metadata={
|
| 134 |
+
"help": "Undesirable losses are weighed by this factor to counter unequal number of desirable and "
|
| 135 |
+
"undesirable pairs.",
|
| 136 |
+
},
|
| 137 |
+
)
|
| 138 |
+
generate_during_eval: bool = field(
|
| 139 |
+
default=False,
|
| 140 |
+
metadata={
|
| 141 |
+
"help": "If `True`, generates and logs completions from both the model and the reference model to W&B "
|
| 142 |
+
"during evaluation."
|
| 143 |
+
},
|
| 144 |
+
)
|
| 145 |
+
disable_dropout: bool = field(
|
| 146 |
+
default=True,
|
| 147 |
+
metadata={"help": "Whether to disable dropout in the model."},
|
| 148 |
+
)
|
| 149 |
+
precompute_ref_log_probs: bool = field(
|
| 150 |
+
default=False,
|
| 151 |
+
metadata={
|
| 152 |
+
"help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
|
| 153 |
+
"This is useful when training without the reference model to reduce the total GPU memory needed."
|
| 154 |
+
},
|
| 155 |
+
)
|
| 156 |
+
model_init_kwargs: dict[str, Any] | None = field(
|
| 157 |
+
default=None,
|
| 158 |
+
metadata={
|
| 159 |
+
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
|
| 160 |
+
"from a string."
|
| 161 |
+
},
|
| 162 |
+
)
|
| 163 |
+
dataset_num_proc: int | None = field(
|
| 164 |
+
default=None,
|
| 165 |
+
metadata={"help": "Number of processes to use for processing the dataset."},
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def __post_init__(self):
|
| 169 |
+
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
|
| 170 |
+
|
| 171 |
+
super().__post_init__()
|
ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py
ADDED
|
@@ -0,0 +1,1511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
import random
|
| 17 |
+
import textwrap
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
+
from contextlib import contextmanager, nullcontext
|
| 21 |
+
from operator import itemgetter
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import TYPE_CHECKING, Any, Literal
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import pandas as pd
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
import transformers
|
| 31 |
+
from accelerate import PartialState, logging
|
| 32 |
+
from accelerate.utils import tqdm
|
| 33 |
+
from datasets import Dataset, concatenate_datasets
|
| 34 |
+
from packaging.version import Version
|
| 35 |
+
from torch import autocast
|
| 36 |
+
from torch.utils.data import DataLoader, SequentialSampler
|
| 37 |
+
from transformers import (
|
| 38 |
+
BaseImageProcessor,
|
| 39 |
+
DataCollator,
|
| 40 |
+
FeatureExtractionMixin,
|
| 41 |
+
PreTrainedModel,
|
| 42 |
+
PreTrainedTokenizerBase,
|
| 43 |
+
ProcessorMixin,
|
| 44 |
+
TrainerCallback,
|
| 45 |
+
TrainingArguments,
|
| 46 |
+
is_comet_available,
|
| 47 |
+
is_wandb_available,
|
| 48 |
+
)
|
| 49 |
+
from transformers.trainer_utils import EvalLoopOutput, has_length
|
| 50 |
+
from transformers.utils import is_peft_available
|
| 51 |
+
|
| 52 |
+
from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
|
| 53 |
+
from ...import_utils import is_liger_kernel_available
|
| 54 |
+
from ...models.utils import create_reference_model, peft_module_casting_to_bf16, prepare_deepspeed
|
| 55 |
+
from ...trainer.base_trainer import BaseTrainer
|
| 56 |
+
from ...trainer.utils import (
|
| 57 |
+
create_model_from_path,
|
| 58 |
+
disable_dropout_in_model,
|
| 59 |
+
log_table_to_comet_experiment,
|
| 60 |
+
pad_to_length,
|
| 61 |
+
selective_log_softmax,
|
| 62 |
+
)
|
| 63 |
+
from ..utils import DPODataCollatorWithPadding
|
| 64 |
+
from .kto_config import KTOConfig
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if is_liger_kernel_available():
|
| 68 |
+
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
|
| 69 |
+
|
| 70 |
+
if is_peft_available():
|
| 71 |
+
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
|
| 72 |
+
|
| 73 |
+
if is_wandb_available():
|
| 74 |
+
import wandb
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if TYPE_CHECKING:
|
| 78 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
logger = logging.get_logger(__name__)
|
| 82 |
+
|
| 83 |
+
RUNNING_NAME = "running.pt"
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
| 87 |
+
"""
|
| 88 |
+
Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of
|
| 89 |
+
completions. For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the
|
| 90 |
+
same set as the matched outputs y used to estimate the rewards in that batch, just paired with different x.
|
| 91 |
+
"""
|
| 92 |
+
batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
|
| 93 |
+
batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
|
| 94 |
+
return batch
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _tokenize(
|
| 98 |
+
batch: dict[str, list[Any]],
|
| 99 |
+
tokenizer: "PreTrainedTokenizer",
|
| 100 |
+
) -> dict[str, list[Any]]:
|
| 101 |
+
"""Tokenize a batch from a KTO specific dataset."""
|
| 102 |
+
prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
|
| 103 |
+
prompt_input_ids = prompt_tokenized["input_ids"]
|
| 104 |
+
prompt_attention_mask = prompt_tokenized["attention_mask"]
|
| 105 |
+
prompt_and_completion = [
|
| 106 |
+
prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True)
|
| 107 |
+
]
|
| 108 |
+
full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False)
|
| 109 |
+
full_input_ids = full_tokenized["input_ids"]
|
| 110 |
+
full_attention_mask = full_tokenized["attention_mask"]
|
| 111 |
+
|
| 112 |
+
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)]
|
| 113 |
+
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)]
|
| 114 |
+
|
| 115 |
+
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
| 116 |
+
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)]
|
| 117 |
+
# Prepare input tokens for token by token comparison
|
| 118 |
+
full_input_ids = [np.array(f) for f in full_input_ids]
|
| 119 |
+
for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True):
|
| 120 |
+
if len(full) != len(concat):
|
| 121 |
+
raise ValueError(
|
| 122 |
+
"The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length."
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
| 126 |
+
# can be merged together when tokenizing prompt+answer. This could result
|
| 127 |
+
# on the last token from the prompt being different when tokenized on its own
|
| 128 |
+
# vs when done as prompt+answer.
|
| 129 |
+
response_token_ids_start_idx = [len(p) for p in prompt_input_ids]
|
| 130 |
+
|
| 131 |
+
# If tokenized prompt is different than both prompt+answer, then it means the
|
| 132 |
+
# last token has changed due to merging.
|
| 133 |
+
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)):
|
| 134 |
+
if not np.array_equal(p, f[:r]):
|
| 135 |
+
response_token_ids_start_idx[idx] -= 1
|
| 136 |
+
|
| 137 |
+
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
|
| 138 |
+
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
|
| 139 |
+
|
| 140 |
+
for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True):
|
| 141 |
+
if len(p) != len(m):
|
| 142 |
+
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
| 143 |
+
|
| 144 |
+
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
|
| 145 |
+
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
|
| 146 |
+
|
| 147 |
+
output = dict(
|
| 148 |
+
prompt_input_ids=prompt_input_ids,
|
| 149 |
+
prompt_attention_mask=prompt_attention_mask,
|
| 150 |
+
answer_input_ids=answer_input_ids,
|
| 151 |
+
answer_attention_mask=answer_attention_mask,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return output
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict:
|
| 158 |
+
"""Process tokens of a KTO specific dataset.
|
| 159 |
+
|
| 160 |
+
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
|
| 161 |
+
completion responses is/are too long. We truncate from the end (completion) to fit within max_length.
|
| 162 |
+
|
| 163 |
+
We also create the labels for the completion responses, which are of length equal to the sum of the length of the
|
| 164 |
+
prompt and the completion response, with `-100` for the prompt tokens.
|
| 165 |
+
"""
|
| 166 |
+
prompt = example["prompt"]
|
| 167 |
+
completion = example["completion"]
|
| 168 |
+
|
| 169 |
+
batch = {
|
| 170 |
+
f"{kwargs['prefix']}prompt": prompt,
|
| 171 |
+
f"{kwargs['prefix']}completion": completion,
|
| 172 |
+
f"{kwargs['prefix']}label": example["label"],
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
# Check issues below for more details
|
| 176 |
+
# 1. https://github.com/huggingface/trl/issues/907
|
| 177 |
+
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
| 178 |
+
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
| 179 |
+
|
| 180 |
+
if not isinstance(prompt, str):
|
| 181 |
+
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
| 182 |
+
|
| 183 |
+
if not isinstance(completion, str):
|
| 184 |
+
raise ValueError(f"completion should be an str but got {type(completion)}")
|
| 185 |
+
|
| 186 |
+
# keys of format prompt_* refers to just the prompt and answer_* refers to just the answer
|
| 187 |
+
all_tokens = {
|
| 188 |
+
"prompt_input_ids": example["prompt_input_ids"],
|
| 189 |
+
"prompt_attention_mask": example["prompt_attention_mask"],
|
| 190 |
+
"answer_input_ids": example["answer_input_ids"],
|
| 191 |
+
"answer_attention_mask": example["answer_attention_mask"],
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
# calculate max length by checking if BOS/EOS is already there
|
| 195 |
+
max_length = kwargs["max_length"]
|
| 196 |
+
bos_token_id = kwargs["tokenizer"].bos_token_id
|
| 197 |
+
eos_token_id = kwargs["tokenizer"].eos_token_id
|
| 198 |
+
if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]:
|
| 199 |
+
max_length -= 1
|
| 200 |
+
if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
|
| 201 |
+
max_length -= 1
|
| 202 |
+
|
| 203 |
+
# if combined sequence is too long, truncate the completion (answer) from the end
|
| 204 |
+
prompt_length = len(all_tokens["prompt_input_ids"])
|
| 205 |
+
completion_length = len(all_tokens["answer_input_ids"])
|
| 206 |
+
if prompt_length + completion_length > max_length:
|
| 207 |
+
max_completion_length = max_length - prompt_length
|
| 208 |
+
for k in ["answer_input_ids", "answer_attention_mask"]:
|
| 209 |
+
all_tokens[k] = all_tokens[k][:max_completion_length]
|
| 210 |
+
|
| 211 |
+
# all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens
|
| 212 |
+
batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
|
| 213 |
+
batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"]
|
| 214 |
+
batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"]
|
| 215 |
+
batch[f"{kwargs['prefix']}completion_attention_mask"] = (
|
| 216 |
+
all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# add BOS, which affects both prompt and the full completion
|
| 220 |
+
if bos_token_id is not None:
|
| 221 |
+
if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
|
| 222 |
+
batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
|
| 223 |
+
f"{kwargs['prefix']}prompt_input_ids"
|
| 224 |
+
]
|
| 225 |
+
batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
|
| 226 |
+
batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
|
| 227 |
+
f"{kwargs['prefix']}completion_input_ids"
|
| 228 |
+
]
|
| 229 |
+
batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
|
| 230 |
+
f"{kwargs['prefix']}completion_attention_mask"
|
| 231 |
+
]
|
| 232 |
+
# add EOS, which affects only the full completion
|
| 233 |
+
if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
|
| 234 |
+
batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
|
| 235 |
+
eos_token_id
|
| 236 |
+
]
|
| 237 |
+
batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[
|
| 238 |
+
f"{kwargs['prefix']}completion_attention_mask"
|
| 239 |
+
] + [1]
|
| 240 |
+
|
| 241 |
+
batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:]
|
| 242 |
+
batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [-100] * len(
|
| 243 |
+
batch[f"{kwargs['prefix']}prompt_input_ids"]
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return batch
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class KTOTrainer(BaseTrainer):
|
| 250 |
+
r"""
|
| 251 |
+
Initialize KTOTrainer.
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
model ([`~transformers.PreTrainedModel`]):
|
| 255 |
+
The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`].
|
| 256 |
+
ref_model ([`~transformers.PreTrainedModel`]):
|
| 257 |
+
Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation
|
| 258 |
+
and loss. If no reference model is provided, the trainer will create a reference model with the same
|
| 259 |
+
architecture as the model to be optimized.
|
| 260 |
+
args ([`experimental.kto.KTOConfig`]):
|
| 261 |
+
The arguments to use for training.
|
| 262 |
+
train_dataset ([`~datasets.Dataset`]):
|
| 263 |
+
The dataset to use for training.
|
| 264 |
+
eval_dataset ([`~datasets.Dataset`]):
|
| 265 |
+
The dataset to use for evaluation.
|
| 266 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
|
| 267 |
+
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
| 268 |
+
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
| 269 |
+
reuse the fine-tuned model.
|
| 270 |
+
data_collator ([`~transformers.DataCollator`], *optional*):
|
| 271 |
+
The data collator to use for training. If None is specified, the default data collator
|
| 272 |
+
([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the
|
| 273 |
+
maximum length of the sequences in the batch, given a dataset of paired sequences.
|
| 274 |
+
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
| 275 |
+
The model initializer to use for training. If None is specified, the default model initializer will be
|
| 276 |
+
used.
|
| 277 |
+
callbacks (`list[transformers.TrainerCallback]`):
|
| 278 |
+
The callbacks to use for training.
|
| 279 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
| 280 |
+
The optimizer and scheduler to use for training.
|
| 281 |
+
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
| 282 |
+
The function to use to preprocess the logits before computing the metrics.
|
| 283 |
+
peft_config (`dict`, defaults to `None`):
|
| 284 |
+
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
|
| 285 |
+
a PEFT model.
|
| 286 |
+
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
| 287 |
+
The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
|
| 288 |
+
metric values.
|
| 289 |
+
model_adapter_name (`str`, defaults to `None`):
|
| 290 |
+
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
|
| 291 |
+
ref_adapter_name (`str`, defaults to `None`):
|
| 292 |
+
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
_tag_names = ["trl", "kto"]
|
| 296 |
+
_name = "KTO"
|
| 297 |
+
_paper = {
|
| 298 |
+
"title": "KTO: Model Alignment as Prospect Theoretic Optimization",
|
| 299 |
+
"id": "2402.01306",
|
| 300 |
+
# docstyle-ignore
|
| 301 |
+
"citation": textwrap.dedent("""\
|
| 302 |
+
@article{ethayarajh2024kto,
|
| 303 |
+
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
|
| 304 |
+
author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
|
| 305 |
+
year = 2024,
|
| 306 |
+
eprint = {arXiv:2402.01306},
|
| 307 |
+
}"""),
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
def __init__(
|
| 311 |
+
self,
|
| 312 |
+
model: PreTrainedModel | nn.Module | str = None,
|
| 313 |
+
ref_model: PreTrainedModel | nn.Module | str | None = None,
|
| 314 |
+
args: KTOConfig = None,
|
| 315 |
+
train_dataset: Dataset | None = None,
|
| 316 |
+
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
| 317 |
+
processing_class: PreTrainedTokenizerBase
|
| 318 |
+
| BaseImageProcessor
|
| 319 |
+
| FeatureExtractionMixin
|
| 320 |
+
| ProcessorMixin
|
| 321 |
+
| None = None,
|
| 322 |
+
data_collator: DataCollator | None = None,
|
| 323 |
+
model_init: Callable[[], PreTrainedModel] | None = None,
|
| 324 |
+
callbacks: list[TrainerCallback] | None = None,
|
| 325 |
+
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
| 326 |
+
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
| 327 |
+
peft_config: dict | None = None,
|
| 328 |
+
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
|
| 329 |
+
model_adapter_name: str | None = None,
|
| 330 |
+
ref_adapter_name: str | None = None,
|
| 331 |
+
):
|
| 332 |
+
if type(args) is TrainingArguments:
|
| 333 |
+
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
| 334 |
+
|
| 335 |
+
if not isinstance(model, str) and ref_model is model:
|
| 336 |
+
raise ValueError(
|
| 337 |
+
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
| 338 |
+
"same as `model`, you must mass a copy of it, or `None` if you use peft."
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Model initialization
|
| 342 |
+
if isinstance(model, str):
|
| 343 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 344 |
+
# Distributed training requires device_map=None ("auto" fails)
|
| 345 |
+
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
|
| 346 |
+
model_init_kwargs["device_map"] = None
|
| 347 |
+
model = create_model_from_path(model, **model_init_kwargs)
|
| 348 |
+
else:
|
| 349 |
+
if args.model_init_kwargs is not None:
|
| 350 |
+
logger.warning(
|
| 351 |
+
"You passed `model_init_kwargs` to the KTOConfig, but your model is already instantiated. "
|
| 352 |
+
"The `model_init_kwargs` will be ignored."
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# Reference model initialization
|
| 356 |
+
if isinstance(ref_model, str):
|
| 357 |
+
ref_model_init_kwargs = args.model_init_kwargs or {}
|
| 358 |
+
# Distributed training requires device_map=None ("auto" fails)
|
| 359 |
+
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
|
| 360 |
+
ref_model_init_kwargs["device_map"] = None
|
| 361 |
+
ref_model = create_model_from_path(ref_model, **ref_model_init_kwargs)
|
| 362 |
+
|
| 363 |
+
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
| 364 |
+
# has been called in order to properly call autocast if needed.
|
| 365 |
+
self._peft_has_been_casted_to_bf16 = False
|
| 366 |
+
|
| 367 |
+
if not is_peft_available() and peft_config is not None:
|
| 368 |
+
raise ValueError(
|
| 369 |
+
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
|
| 370 |
+
)
|
| 371 |
+
elif is_peft_available() and peft_config is not None:
|
| 372 |
+
if isinstance(model, PeftModel):
|
| 373 |
+
raise ValueError(
|
| 374 |
+
"You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
|
| 375 |
+
"merge and unload the existing adapter, save the resulting base model, and then pass that base "
|
| 376 |
+
"model along with the new `peft_config` to the trainer."
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
| 380 |
+
_support_gc_kwargs = hasattr(
|
| 381 |
+
args, "gradient_checkpointing_kwargs"
|
| 382 |
+
) and "gradient_checkpointing_kwargs" in list(
|
| 383 |
+
inspect.signature(prepare_model_for_kbit_training).parameters
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
| 387 |
+
|
| 388 |
+
if _support_gc_kwargs:
|
| 389 |
+
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
| 390 |
+
|
| 391 |
+
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
| 392 |
+
elif args.gradient_checkpointing:
|
| 393 |
+
# For backward compatibility with older versions of transformers
|
| 394 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 395 |
+
model.enable_input_require_grads()
|
| 396 |
+
else:
|
| 397 |
+
|
| 398 |
+
def make_inputs_require_grad(module, input, output):
|
| 399 |
+
output.requires_grad_(True)
|
| 400 |
+
|
| 401 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 402 |
+
|
| 403 |
+
# get peft model with the given config
|
| 404 |
+
model = get_peft_model(model, peft_config)
|
| 405 |
+
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
| 406 |
+
peft_module_casting_to_bf16(model)
|
| 407 |
+
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
| 408 |
+
self._peft_has_been_casted_to_bf16 = True
|
| 409 |
+
|
| 410 |
+
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
| 411 |
+
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
| 412 |
+
# fail or completely fail.
|
| 413 |
+
elif args.gradient_checkpointing:
|
| 414 |
+
# For backward compatibility with older versions of transformers
|
| 415 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 416 |
+
model.enable_input_require_grads()
|
| 417 |
+
else:
|
| 418 |
+
|
| 419 |
+
def make_inputs_require_grad(module, input, output):
|
| 420 |
+
output.requires_grad_(True)
|
| 421 |
+
|
| 422 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
| 423 |
+
|
| 424 |
+
if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
|
| 425 |
+
raise ValueError(
|
| 426 |
+
"`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
|
| 427 |
+
" Please install `wandb` or `comet-ml` to resolve."
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# KTO only supports causal language models, not encoder-decoder models
|
| 431 |
+
if model is not None and hasattr(model.config, "is_encoder_decoder") and model.config.is_encoder_decoder:
|
| 432 |
+
raise ValueError(
|
| 433 |
+
"KTO only supports causal language models. Encoder-decoder models are not supported. "
|
| 434 |
+
"Please use a causal LM (e.g., GPT, Llama, Mistral) instead of an encoder-decoder model (e.g., T5, BART)."
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
| 438 |
+
self.model_adapter_name = model_adapter_name
|
| 439 |
+
self.ref_adapter_name = ref_adapter_name
|
| 440 |
+
|
| 441 |
+
if ref_model:
|
| 442 |
+
self.ref_model = ref_model
|
| 443 |
+
elif self.is_peft_model or args.precompute_ref_log_probs:
|
| 444 |
+
# The `model` with adapters turned off will be used as the reference model
|
| 445 |
+
self.ref_model = None
|
| 446 |
+
else:
|
| 447 |
+
self.ref_model = create_reference_model(model)
|
| 448 |
+
|
| 449 |
+
if processing_class is None:
|
| 450 |
+
raise ValueError(
|
| 451 |
+
"max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
|
| 452 |
+
)
|
| 453 |
+
if args.max_length is None:
|
| 454 |
+
logger.warning(
|
| 455 |
+
"When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
|
| 456 |
+
" it will be set to `512` by default, but you should do it yourself in the future.",
|
| 457 |
+
)
|
| 458 |
+
max_length = 512
|
| 459 |
+
if args.max_length is not None:
|
| 460 |
+
max_length = args.max_length
|
| 461 |
+
|
| 462 |
+
if data_collator is None:
|
| 463 |
+
data_collator = DPODataCollatorWithPadding(
|
| 464 |
+
pad_token_id=processing_class.pad_token_id,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
if args.remove_unused_columns:
|
| 468 |
+
args.remove_unused_columns = False
|
| 469 |
+
# warn users
|
| 470 |
+
logger.warning(
|
| 471 |
+
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
|
| 472 |
+
" we have set it for you, but you should do it yourself in the future.",
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
self.use_dpo_data_collator = True
|
| 476 |
+
else:
|
| 477 |
+
self.use_dpo_data_collator = False
|
| 478 |
+
|
| 479 |
+
# Disable dropout in the model and reference model
|
| 480 |
+
if args.disable_dropout:
|
| 481 |
+
disable_dropout_in_model(model)
|
| 482 |
+
if self.ref_model is not None:
|
| 483 |
+
disable_dropout_in_model(self.ref_model)
|
| 484 |
+
|
| 485 |
+
self.loss_type = args.loss_type
|
| 486 |
+
self.max_length = max_length
|
| 487 |
+
self.generate_during_eval = args.generate_during_eval
|
| 488 |
+
self.processing_class = processing_class
|
| 489 |
+
self.precompute_ref_log_probs = args.precompute_ref_log_probs
|
| 490 |
+
|
| 491 |
+
# Not all losses require a KL calculation
|
| 492 |
+
self.calculate_KL = True
|
| 493 |
+
if self.loss_type in ["apo_zero_unpaired"]:
|
| 494 |
+
self.calculate_KL = False
|
| 495 |
+
|
| 496 |
+
# Since ref_logs are precomputed on the first call to get_train/eval_dataloader
|
| 497 |
+
# keep track of first called to avoid computation of future calls
|
| 498 |
+
self._precomputed_train_ref_log_probs = False
|
| 499 |
+
self._precomputed_eval_ref_log_probs = False
|
| 500 |
+
|
| 501 |
+
# metric
|
| 502 |
+
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
| 503 |
+
|
| 504 |
+
# KTO parameter
|
| 505 |
+
self.beta = args.beta
|
| 506 |
+
self.desirable_weight = args.desirable_weight
|
| 507 |
+
self.undesirable_weight = args.undesirable_weight
|
| 508 |
+
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
|
| 509 |
+
self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
|
| 510 |
+
if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
|
| 511 |
+
logger.warning(
|
| 512 |
+
"You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
|
| 513 |
+
"`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
|
| 514 |
+
"greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
|
| 515 |
+
"loss.",
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Compute that only on the main process for faster data processing.
|
| 519 |
+
# see: https://github.com/huggingface/trl/pull/1255
|
| 520 |
+
with PartialState().main_process_first():
|
| 521 |
+
# Extract the prompt if needed
|
| 522 |
+
train_dataset = train_dataset.map(
|
| 523 |
+
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
|
| 524 |
+
)
|
| 525 |
+
# Unpair the dataset if needed
|
| 526 |
+
train_dataset = maybe_unpair_preference_dataset(
|
| 527 |
+
train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
|
| 528 |
+
)
|
| 529 |
+
# Apply the chat template if needed
|
| 530 |
+
train_dataset = train_dataset.map(
|
| 531 |
+
maybe_apply_chat_template,
|
| 532 |
+
fn_kwargs={"tokenizer": processing_class},
|
| 533 |
+
num_proc=args.dataset_num_proc,
|
| 534 |
+
desc="Applying chat template to train dataset",
|
| 535 |
+
)
|
| 536 |
+
if eval_dataset is not None:
|
| 537 |
+
eval_dataset = eval_dataset.map(
|
| 538 |
+
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
|
| 539 |
+
)
|
| 540 |
+
eval_dataset = maybe_unpair_preference_dataset(
|
| 541 |
+
eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
|
| 542 |
+
)
|
| 543 |
+
eval_dataset = eval_dataset.map(
|
| 544 |
+
maybe_apply_chat_template,
|
| 545 |
+
fn_kwargs={"tokenizer": processing_class},
|
| 546 |
+
num_proc=args.dataset_num_proc,
|
| 547 |
+
desc="Applying chat template to eval dataset",
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
# Tokenize and prepare the training datasets
|
| 551 |
+
train_dataset = train_dataset.map(
|
| 552 |
+
_tokenize,
|
| 553 |
+
batched=True,
|
| 554 |
+
fn_kwargs={"tokenizer": self.processing_class},
|
| 555 |
+
num_proc=args.dataset_num_proc,
|
| 556 |
+
desc="Tokenizing train dataset",
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
fn_kwargs = {
|
| 560 |
+
"prefix": "",
|
| 561 |
+
"tokenizer": self.processing_class,
|
| 562 |
+
"max_length": self.max_length,
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
train_dataset = train_dataset.map(
|
| 566 |
+
_process_tokens,
|
| 567 |
+
fn_kwargs=fn_kwargs,
|
| 568 |
+
num_proc=args.dataset_num_proc,
|
| 569 |
+
desc="Processing tokenized train dataset",
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
# Tokenize and prepare the eval datasets
|
| 573 |
+
if eval_dataset is not None:
|
| 574 |
+
eval_dataset = eval_dataset.map(
|
| 575 |
+
_tokenize,
|
| 576 |
+
fn_kwargs={"tokenizer": self.processing_class},
|
| 577 |
+
batched=True,
|
| 578 |
+
num_proc=args.dataset_num_proc,
|
| 579 |
+
desc="Tokenizing eval dataset",
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
eval_dataset = eval_dataset.map(
|
| 583 |
+
_process_tokens,
|
| 584 |
+
fn_kwargs=fn_kwargs,
|
| 585 |
+
num_proc=args.dataset_num_proc,
|
| 586 |
+
desc="Processing tokenized eval dataset",
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
# Get KL datasets if needed
|
| 590 |
+
if self.calculate_KL:
|
| 591 |
+
if args.per_device_train_batch_size <= 1:
|
| 592 |
+
raise ValueError(
|
| 593 |
+
"Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
# create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
|
| 597 |
+
# i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
|
| 598 |
+
train_kl_dataset = train_dataset.map(
|
| 599 |
+
_get_kl_dataset,
|
| 600 |
+
batched=True,
|
| 601 |
+
batch_size=args.per_device_train_batch_size,
|
| 602 |
+
num_proc=args.dataset_num_proc,
|
| 603 |
+
desc="Extracting KL train dataset",
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
fn_kwargs["prefix"] = "KL_"
|
| 607 |
+
train_kl_dataset = train_kl_dataset.map(
|
| 608 |
+
_process_tokens,
|
| 609 |
+
fn_kwargs=fn_kwargs,
|
| 610 |
+
num_proc=args.dataset_num_proc,
|
| 611 |
+
remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
|
| 612 |
+
desc="Processing tokenized train KL dataset",
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
# merge the datasets
|
| 616 |
+
train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)
|
| 617 |
+
|
| 618 |
+
if eval_dataset is not None:
|
| 619 |
+
# Get KL dataset
|
| 620 |
+
eval_kl_dataset = eval_dataset.map(
|
| 621 |
+
_get_kl_dataset,
|
| 622 |
+
batched=True,
|
| 623 |
+
batch_size=args.per_device_train_batch_size,
|
| 624 |
+
num_proc=args.dataset_num_proc,
|
| 625 |
+
desc="Extracting eval KL dataset",
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
eval_kl_dataset = eval_kl_dataset.map(
|
| 629 |
+
_process_tokens,
|
| 630 |
+
fn_kwargs=fn_kwargs,
|
| 631 |
+
num_proc=args.dataset_num_proc,
|
| 632 |
+
remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
|
| 633 |
+
desc="Processing tokenized eval KL dataset",
|
| 634 |
+
)
|
| 635 |
+
|
| 636 |
+
# merge the datasets
|
| 637 |
+
eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)
|
| 638 |
+
|
| 639 |
+
# calculate dataset desirability balance
|
| 640 |
+
num_desirable = max(sum(train_dataset["label"]), 1)
|
| 641 |
+
num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary
|
| 642 |
+
|
| 643 |
+
if num_desirable != num_undesirable:
|
| 644 |
+
# The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
|
| 645 |
+
des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
|
| 646 |
+
des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
|
| 647 |
+
und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
|
| 648 |
+
und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)
|
| 649 |
+
|
| 650 |
+
des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
|
| 651 |
+
und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound
|
| 652 |
+
|
| 653 |
+
if not (des_weight_in_range or und_weight_in_range):
|
| 654 |
+
logger.warning(
|
| 655 |
+
"You have different amounts of desirable/positive and undesirable/negative examples but the "
|
| 656 |
+
"weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
|
| 657 |
+
f"on your data, we recommend EITHER "
|
| 658 |
+
f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
|
| 659 |
+
f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
|
| 660 |
+
"See the documentation on how to optimally set these weights.",
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
|
| 664 |
+
# never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
|
| 665 |
+
# (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
|
| 666 |
+
# default to the recommended non-reentrant behavior here, while preserving any user-provided value.
|
| 667 |
+
if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
|
| 668 |
+
args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 669 |
+
args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
|
| 670 |
+
|
| 671 |
+
super().__init__(
|
| 672 |
+
model=model,
|
| 673 |
+
args=args,
|
| 674 |
+
data_collator=data_collator,
|
| 675 |
+
train_dataset=train_dataset,
|
| 676 |
+
eval_dataset=eval_dataset,
|
| 677 |
+
processing_class=processing_class,
|
| 678 |
+
model_init=model_init,
|
| 679 |
+
compute_metrics=compute_metrics,
|
| 680 |
+
callbacks=callbacks,
|
| 681 |
+
optimizers=optimizers,
|
| 682 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 686 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 687 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 688 |
+
self.model_accepts_loss_kwargs = False
|
| 689 |
+
|
| 690 |
+
# Add tags for models that have been loaded with the correct transformers version
|
| 691 |
+
if hasattr(self.model, "add_model_tags"):
|
| 692 |
+
self.model.add_model_tags(self._tag_names)
|
| 693 |
+
|
| 694 |
+
if not hasattr(self, "accelerator"):
|
| 695 |
+
raise AttributeError(
|
| 696 |
+
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
# Deepspeed Zero-3 does not support precompute_ref_log_probs
|
| 700 |
+
if self.is_deepspeed_enabled:
|
| 701 |
+
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
|
| 702 |
+
raise ValueError(
|
| 703 |
+
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
if self.ref_model is None:
|
| 707 |
+
if not (self.is_peft_model or self.precompute_ref_log_probs):
|
| 708 |
+
raise ValueError(
|
| 709 |
+
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
|
| 710 |
+
)
|
| 711 |
+
else:
|
| 712 |
+
if self.is_deepspeed_enabled:
|
| 713 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 714 |
+
else:
|
| 715 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 716 |
+
|
| 717 |
+
# Import Liger kernel if enabled
|
| 718 |
+
if self.args.use_liger_kernel:
|
| 719 |
+
if not is_liger_kernel_available():
|
| 720 |
+
raise ImportError(
|
| 721 |
+
"You set `use_liger_kernel=True` but the liger kernel is not available. "
|
| 722 |
+
"Please install liger-kernel first: `pip install liger-kernel`"
|
| 723 |
+
)
|
| 724 |
+
if self.loss_type in ["apo_zero_unpaired"]:
|
| 725 |
+
raise ValueError(
|
| 726 |
+
"You cannot set `loss_type='apo_zero_unpaired'` with liger-kernel."
|
| 727 |
+
"Only KTO loss is supported with liger-kernel."
|
| 728 |
+
)
|
| 729 |
+
if self.precompute_ref_log_probs:
|
| 730 |
+
raise ValueError(
|
| 731 |
+
"You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set "
|
| 732 |
+
"`precompute_ref_log_probs=False`."
|
| 733 |
+
)
|
| 734 |
+
if self.is_peft_model or self.ref_adapter_name is not None:
|
| 735 |
+
raise ValueError(
|
| 736 |
+
"You cannot use `use_liger_kernel=True` with Peft models. Please set `use_liger_kernel=False`."
|
| 737 |
+
)
|
| 738 |
+
self.kto_loss_fn = LigerFusedLinearKTOLoss(beta=self.beta, use_ref_model=(self.ref_model is not None))
|
| 739 |
+
|
| 740 |
+
@contextmanager
|
| 741 |
+
def null_ref_context(self):
|
| 742 |
+
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
| 743 |
+
with (
|
| 744 |
+
self.accelerator.unwrap_model(self.model).disable_adapter()
|
| 745 |
+
if self.is_peft_model and not self.ref_adapter_name
|
| 746 |
+
else nullcontext()
|
| 747 |
+
):
|
| 748 |
+
if self.ref_adapter_name:
|
| 749 |
+
self.model.set_adapter(self.ref_adapter_name)
|
| 750 |
+
yield
|
| 751 |
+
if self.ref_adapter_name:
|
| 752 |
+
self.model.set_adapter(self.model_adapter_name or "default")
|
| 753 |
+
|
| 754 |
+
def get_train_dataloader(self) -> DataLoader:
|
| 755 |
+
"""
|
| 756 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
| 757 |
+
|
| 758 |
+
Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`.
|
| 759 |
+
"""
|
| 760 |
+
|
| 761 |
+
if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
|
| 762 |
+
dataloader_params = {
|
| 763 |
+
"batch_size": self.args.per_device_train_batch_size,
|
| 764 |
+
"collate_fn": self.data_collator,
|
| 765 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 766 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 767 |
+
"shuffle": False,
|
| 768 |
+
}
|
| 769 |
+
|
| 770 |
+
# prepare dataloader
|
| 771 |
+
data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
|
| 772 |
+
reference_completion_logps = []
|
| 773 |
+
reference_KL_logps = []
|
| 774 |
+
|
| 775 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
|
| 776 |
+
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
| 777 |
+
|
| 778 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 779 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 780 |
+
|
| 781 |
+
if self.calculate_KL:
|
| 782 |
+
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
| 783 |
+
reference_KL_logps.append(reference_KL_logp.cpu())
|
| 784 |
+
|
| 785 |
+
self.train_dataset = self.train_dataset.add_column(
|
| 786 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
if self.calculate_KL:
|
| 790 |
+
self.train_dataset = self.train_dataset.add_column(
|
| 791 |
+
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
self._precomputed_train_ref_log_probs = True
|
| 795 |
+
|
| 796 |
+
return super().get_train_dataloader()
|
| 797 |
+
|
| 798 |
+
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
| 799 |
+
"""
|
| 800 |
+
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
| 801 |
+
|
| 802 |
+
Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`.
|
| 803 |
+
|
| 804 |
+
Args:
|
| 805 |
+
eval_dataset (`torch.utils.data.Dataset`, *optional*):
|
| 806 |
+
If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
|
| 807 |
+
by the `model.forward()` method are automatically removed. It must implement `__len__`.
|
| 808 |
+
"""
|
| 809 |
+
if eval_dataset is None and self.eval_dataset is None:
|
| 810 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
| 811 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
| 812 |
+
|
| 813 |
+
if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
|
| 814 |
+
dataloader_params = {
|
| 815 |
+
"batch_size": self.args.per_device_eval_batch_size,
|
| 816 |
+
"collate_fn": self.data_collator,
|
| 817 |
+
"num_workers": self.args.dataloader_num_workers,
|
| 818 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
| 819 |
+
"shuffle": False,
|
| 820 |
+
}
|
| 821 |
+
|
| 822 |
+
# prepare dataloader
|
| 823 |
+
data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))
|
| 824 |
+
|
| 825 |
+
reference_completion_logps = []
|
| 826 |
+
reference_KL_logps = []
|
| 827 |
+
|
| 828 |
+
for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
|
| 829 |
+
reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)
|
| 830 |
+
|
| 831 |
+
reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
|
| 832 |
+
reference_completion_logps.append(reference_completion_logp.cpu())
|
| 833 |
+
|
| 834 |
+
if self.calculate_KL:
|
| 835 |
+
reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
|
| 836 |
+
reference_KL_logps.append(reference_KL_logp.cpu())
|
| 837 |
+
|
| 838 |
+
eval_dataset = eval_dataset.add_column(
|
| 839 |
+
name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
|
| 840 |
+
)
|
| 841 |
+
if self.calculate_KL:
|
| 842 |
+
eval_dataset = eval_dataset.add_column(
|
| 843 |
+
name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
|
| 844 |
+
)
|
| 845 |
+
|
| 846 |
+
# Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs
|
| 847 |
+
if self.eval_dataset is not None:
|
| 848 |
+
self.eval_dataset = eval_dataset
|
| 849 |
+
self._precomputed_eval_ref_log_probs = True
|
| 850 |
+
|
| 851 |
+
return super().get_eval_dataloader(eval_dataset=eval_dataset)
|
| 852 |
+
|
| 853 |
+
def compute_reference_log_probs(self, padded_batch: dict) -> dict:
|
| 854 |
+
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
|
| 855 |
+
with torch.no_grad():
|
| 856 |
+
if self.ref_model is None:
|
| 857 |
+
with self.null_ref_context():
|
| 858 |
+
completion_logits = self.model(
|
| 859 |
+
padded_batch["completion_input_ids"],
|
| 860 |
+
attention_mask=padded_batch["completion_attention_mask"],
|
| 861 |
+
).logits
|
| 862 |
+
|
| 863 |
+
if self.calculate_KL:
|
| 864 |
+
KL_logits = self.model(
|
| 865 |
+
padded_batch["KL_completion_input_ids"],
|
| 866 |
+
attention_mask=padded_batch["KL_completion_attention_mask"],
|
| 867 |
+
).logits
|
| 868 |
+
else:
|
| 869 |
+
completion_logits = self.ref_model(
|
| 870 |
+
padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
|
| 871 |
+
).logits
|
| 872 |
+
|
| 873 |
+
if self.calculate_KL:
|
| 874 |
+
KL_logits = self.ref_model(
|
| 875 |
+
padded_batch["KL_completion_input_ids"],
|
| 876 |
+
attention_mask=padded_batch["KL_completion_attention_mask"],
|
| 877 |
+
).logits
|
| 878 |
+
|
| 879 |
+
completion_logps = self.get_batch_logps(
|
| 880 |
+
completion_logits,
|
| 881 |
+
padded_batch["completion_labels"],
|
| 882 |
+
average_log_prob=False,
|
| 883 |
+
)
|
| 884 |
+
|
| 885 |
+
if self.calculate_KL:
|
| 886 |
+
KL_logps = self.get_batch_logps(
|
| 887 |
+
KL_logits,
|
| 888 |
+
padded_batch["KL_completion_labels"],
|
| 889 |
+
average_log_prob=False,
|
| 890 |
+
)
|
| 891 |
+
else:
|
| 892 |
+
KL_logps = None
|
| 893 |
+
|
| 894 |
+
return completion_logps, KL_logps
|
| 895 |
+
|
| 896 |
+
@staticmethod
|
| 897 |
+
def get_batch_logps(
|
| 898 |
+
logits: torch.FloatTensor,
|
| 899 |
+
labels: torch.LongTensor,
|
| 900 |
+
average_log_prob: bool = False,
|
| 901 |
+
) -> torch.FloatTensor:
|
| 902 |
+
"""Compute the log probabilities of the given labels under the given logits.
|
| 903 |
+
|
| 904 |
+
Args:
|
| 905 |
+
logits:
|
| 906 |
+
Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
| 907 |
+
labels:
|
| 908 |
+
Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored.
|
| 909 |
+
Shape: (batch_size, sequence_length)
|
| 910 |
+
average_log_prob:
|
| 911 |
+
If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
|
| 912 |
+
log probabilities of the (non-masked) tokens.
|
| 913 |
+
|
| 914 |
+
Returns:
|
| 915 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
|
| 916 |
+
given logits.
|
| 917 |
+
"""
|
| 918 |
+
if logits.shape[:-1] != labels.shape:
|
| 919 |
+
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
| 920 |
+
|
| 921 |
+
# For causal LM, shift labels and logits by one position
|
| 922 |
+
labels = labels[:, 1:].clone()
|
| 923 |
+
logits = logits[:, :-1, :]
|
| 924 |
+
|
| 925 |
+
loss_mask = labels != -100
|
| 926 |
+
|
| 927 |
+
# dummy token; we'll ignore the losses on these tokens later
|
| 928 |
+
labels[labels == -100] = 0
|
| 929 |
+
|
| 930 |
+
per_token_logps = selective_log_softmax(logits, labels)
|
| 931 |
+
|
| 932 |
+
if average_log_prob:
|
| 933 |
+
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
| 934 |
+
else:
|
| 935 |
+
return (per_token_logps * loss_mask).sum(-1)
|
| 936 |
+
|
| 937 |
+
def forward(
|
| 938 |
+
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
|
| 939 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 940 |
+
KL_logps = self._compute_kl_logps(model, batch)
|
| 941 |
+
|
| 942 |
+
model_kwargs = {}
|
| 943 |
+
if self.aux_loss_enabled:
|
| 944 |
+
model_kwargs["output_router_logits"] = True
|
| 945 |
+
|
| 946 |
+
outputs = model(
|
| 947 |
+
batch["completion_input_ids"],
|
| 948 |
+
attention_mask=batch["completion_attention_mask"],
|
| 949 |
+
**model_kwargs,
|
| 950 |
+
)
|
| 951 |
+
completion_logits = outputs.logits
|
| 952 |
+
|
| 953 |
+
completion_logps = self.get_batch_logps(
|
| 954 |
+
completion_logits,
|
| 955 |
+
batch["completion_labels"],
|
| 956 |
+
average_log_prob=False,
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
if completion_logps.shape[0] != len(batch["label"]):
|
| 960 |
+
raise ValueError(
|
| 961 |
+
"There is a mismatch between the number of examples in this batch and the number of "
|
| 962 |
+
"examples for which an output sequence was predicted."
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
# Use torch.nonzero for efficient tensor index selection
|
| 966 |
+
device = completion_logits.device
|
| 967 |
+
labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)
|
| 968 |
+
chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
|
| 969 |
+
rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)
|
| 970 |
+
|
| 971 |
+
# Use index_select for efficient CUDA operations
|
| 972 |
+
chosen_logps = completion_logps.index_select(0, chosen_idx)
|
| 973 |
+
rejected_logps = completion_logps.index_select(0, rejected_idx)
|
| 974 |
+
|
| 975 |
+
chosen_logits = completion_logits.index_select(0, chosen_idx)
|
| 976 |
+
rejected_logits = completion_logits.index_select(0, rejected_idx)
|
| 977 |
+
|
| 978 |
+
if self.aux_loss_enabled:
|
| 979 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
|
| 980 |
+
else:
|
| 981 |
+
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
| 982 |
+
|
| 983 |
+
def kto_loss(
|
| 984 |
+
self,
|
| 985 |
+
policy_chosen_logps: torch.FloatTensor,
|
| 986 |
+
policy_rejected_logps: torch.FloatTensor,
|
| 987 |
+
policy_KL_logps: torch.FloatTensor,
|
| 988 |
+
reference_chosen_logps: torch.FloatTensor,
|
| 989 |
+
reference_rejected_logps: torch.FloatTensor,
|
| 990 |
+
reference_KL_logps: torch.FloatTensor,
|
| 991 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
| 992 |
+
"""Compute the KTO loss for a batch of policy and reference model log probabilities.
|
| 993 |
+
|
| 994 |
+
Args:
|
| 995 |
+
policy_chosen_logps:
|
| 996 |
+
Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 997 |
+
policy_rejected_logps:
|
| 998 |
+
Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
|
| 999 |
+
policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
|
| 1000 |
+
reference_chosen_logps:
|
| 1001 |
+
Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
|
| 1002 |
+
reference_rejected_logps:
|
| 1003 |
+
Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in
|
| 1004 |
+
batch_size,)
|
| 1005 |
+
reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)
|
| 1006 |
+
|
| 1007 |
+
Returns:
|
| 1008 |
+
A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). The losses tensor contains the KTO
|
| 1009 |
+
loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
|
| 1010 |
+
the chosen and rejected responses, respectively. The KL tensor contains the detached KL divergence estimate
|
| 1011 |
+
between the policy and reference models.
|
| 1012 |
+
"""
|
| 1013 |
+
if self.calculate_KL:
|
| 1014 |
+
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
|
| 1015 |
+
kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
|
| 1016 |
+
else:
|
| 1017 |
+
kl = torch.zeros(1).to(policy_chosen_logps.device)
|
| 1018 |
+
|
| 1019 |
+
# Chosen losses
|
| 1020 |
+
if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
|
| 1021 |
+
chosen_logratios = policy_chosen_logps - reference_chosen_logps
|
| 1022 |
+
|
| 1023 |
+
if self.loss_type == "kto":
|
| 1024 |
+
# Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
|
| 1025 |
+
chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
|
| 1026 |
+
elif self.loss_type == "apo_zero_unpaired":
|
| 1027 |
+
# Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
|
| 1028 |
+
# Use this loss when you believe the chosen outputs are better than your model's default output
|
| 1029 |
+
chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)
|
| 1030 |
+
|
| 1031 |
+
chosen_rewards = self.beta * chosen_logratios.detach()
|
| 1032 |
+
|
| 1033 |
+
else:
|
| 1034 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1035 |
+
chosen_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1036 |
+
chosen_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1037 |
+
|
| 1038 |
+
# Rejected losses
|
| 1039 |
+
if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
|
| 1040 |
+
rejected_logratios = policy_rejected_logps - reference_rejected_logps
|
| 1041 |
+
|
| 1042 |
+
if self.loss_type == "kto":
|
| 1043 |
+
rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
|
| 1044 |
+
elif self.loss_type == "apo_zero_unpaired":
|
| 1045 |
+
rejected_losses = F.sigmoid(self.beta * rejected_logratios)
|
| 1046 |
+
|
| 1047 |
+
rejected_rewards = self.beta * rejected_logratios.detach()
|
| 1048 |
+
else:
|
| 1049 |
+
# lists can't be empty -- if they are, then accelerate.gather will hang
|
| 1050 |
+
rejected_losses = torch.Tensor([]).to(self.accelerator.device)
|
| 1051 |
+
rejected_rewards = torch.Tensor([]).to(self.accelerator.device)
|
| 1052 |
+
|
| 1053 |
+
losses = torch.cat(
|
| 1054 |
+
(self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
|
| 1055 |
+
0,
|
| 1056 |
+
)
|
| 1057 |
+
|
| 1058 |
+
return losses, chosen_rewards, rejected_rewards, kl
|
| 1059 |
+
|
| 1060 |
+
def _compute_kl_logps(self, model, batch):
|
| 1061 |
+
"""Compute KL log probabilities for a given batch."""
|
| 1062 |
+
KL_logps = None
|
| 1063 |
+
if self.calculate_KL:
|
| 1064 |
+
KL_model_kwargs = {
|
| 1065 |
+
"input_ids": batch["KL_completion_input_ids"],
|
| 1066 |
+
"attention_mask": batch["KL_completion_attention_mask"],
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
with torch.no_grad():
|
| 1070 |
+
KL_logits = model(**KL_model_kwargs).logits
|
| 1071 |
+
|
| 1072 |
+
KL_logps = self.get_batch_logps(
|
| 1073 |
+
KL_logits,
|
| 1074 |
+
batch["KL_completion_labels"],
|
| 1075 |
+
average_log_prob=False,
|
| 1076 |
+
)
|
| 1077 |
+
return KL_logps
|
| 1078 |
+
|
| 1079 |
+
def _compute_loss_liger(self, model, batch):
|
| 1080 |
+
"""
|
| 1081 |
+
Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss.
|
| 1082 |
+
|
| 1083 |
+
Args:
|
| 1084 |
+
model:
|
| 1085 |
+
The policy model used for generating log probabilities and outputs. It could be an encoder-decoder
|
| 1086 |
+
model or a regular language model.
|
| 1087 |
+
batch: A dictionary containing the input data and labels for the batch.
|
| 1088 |
+
|
| 1089 |
+
Returns:
|
| 1090 |
+
A dictionary containing the following keys:
|
| 1091 |
+
- "loss": The computed KTO loss for the batch.
|
| 1092 |
+
- "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model.
|
| 1093 |
+
- "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model.
|
| 1094 |
+
- "chosen_logps": Log probabilities of the chosen responses from the policy model.
|
| 1095 |
+
- "rejected_logps": Log probabilities of the rejected responses from the policy model.
|
| 1096 |
+
- "chosen_rewards": Rewards for the chosen responses.
|
| 1097 |
+
- "rejected_rewards": Rewards for the rejected responses.
|
| 1098 |
+
- "kl": The KL divergence between the policy and reference models (detached).
|
| 1099 |
+
|
| 1100 |
+
If auxiliary loss is enabled, the dictionary will also include:
|
| 1101 |
+
- "aux_loss": The auxiliary loss from the model outputs.
|
| 1102 |
+
"""
|
| 1103 |
+
policy_KL_logps = self._compute_kl_logps(model, batch)
|
| 1104 |
+
reference_KL_logps = self._compute_kl_logps(self.ref_model, batch)
|
| 1105 |
+
if self.calculate_KL:
|
| 1106 |
+
kl = (policy_KL_logps - reference_KL_logps).mean().detach()
|
| 1107 |
+
kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
|
| 1108 |
+
else:
|
| 1109 |
+
kl = torch.zeros(1).to(self.accelerator.device)
|
| 1110 |
+
|
| 1111 |
+
model_kwargs = {}
|
| 1112 |
+
if self.aux_loss_enabled:
|
| 1113 |
+
model_kwargs["output_router_logits"] = True
|
| 1114 |
+
|
| 1115 |
+
# skip the lm head and get the last hidden state
|
| 1116 |
+
base_model = model.get_decoder()
|
| 1117 |
+
outputs = base_model(
|
| 1118 |
+
batch["completion_input_ids"],
|
| 1119 |
+
attention_mask=batch["completion_attention_mask"],
|
| 1120 |
+
use_cache=False,
|
| 1121 |
+
**model_kwargs,
|
| 1122 |
+
)
|
| 1123 |
+
|
| 1124 |
+
# reference model
|
| 1125 |
+
ref_base_model = self.ref_model.get_decoder()
|
| 1126 |
+
ref_outputs = ref_base_model(
|
| 1127 |
+
batch["completion_input_ids"],
|
| 1128 |
+
attention_mask=batch["completion_attention_mask"],
|
| 1129 |
+
use_cache=False,
|
| 1130 |
+
**model_kwargs,
|
| 1131 |
+
)
|
| 1132 |
+
lm_head = model.get_output_embeddings()
|
| 1133 |
+
ref_lm_head = self.ref_model.get_output_embeddings()
|
| 1134 |
+
|
| 1135 |
+
(
|
| 1136 |
+
loss,
|
| 1137 |
+
(
|
| 1138 |
+
chosen_logps_sum,
|
| 1139 |
+
rejected_logps_sum,
|
| 1140 |
+
chosen_logits_sum,
|
| 1141 |
+
rejected_logits_sum,
|
| 1142 |
+
chosen_rewards_sum,
|
| 1143 |
+
rejected_rewards_sum,
|
| 1144 |
+
),
|
| 1145 |
+
) = self.kto_loss_fn(
|
| 1146 |
+
_input=outputs.last_hidden_state[:, :-1],
|
| 1147 |
+
lin_weight=lm_head.weight,
|
| 1148 |
+
target=batch["completion_labels"][:, 1:],
|
| 1149 |
+
bias=lm_head.bias if hasattr(lm_head, "bias") else None,
|
| 1150 |
+
preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device),
|
| 1151 |
+
ref_input=ref_outputs.last_hidden_state[:, :-1],
|
| 1152 |
+
ref_weight=ref_lm_head.weight,
|
| 1153 |
+
ref_bias=ref_lm_head.bias if hasattr(lm_head, "bias") else None,
|
| 1154 |
+
kl=kl,
|
| 1155 |
+
)
|
| 1156 |
+
|
| 1157 |
+
output = {
|
| 1158 |
+
"loss": loss,
|
| 1159 |
+
"chosen_logits_sum": chosen_logits_sum,
|
| 1160 |
+
"rejected_logits_sum": rejected_logits_sum,
|
| 1161 |
+
"chosen_logps_sum": chosen_logps_sum,
|
| 1162 |
+
"rejected_logps_sum": rejected_logps_sum,
|
| 1163 |
+
"chosen_rewards_sum": chosen_rewards_sum,
|
| 1164 |
+
"rejected_rewards_sum": rejected_rewards_sum,
|
| 1165 |
+
"kl": kl,
|
| 1166 |
+
}
|
| 1167 |
+
if self.aux_loss_enabled:
|
| 1168 |
+
output["aux_loss"] = outputs.aux_loss
|
| 1169 |
+
|
| 1170 |
+
return output
|
| 1171 |
+
|
| 1172 |
+
def get_batch_loss_metrics(
|
| 1173 |
+
self,
|
| 1174 |
+
model,
|
| 1175 |
+
batch: dict[str, list | torch.LongTensor],
|
| 1176 |
+
):
|
| 1177 |
+
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
| 1178 |
+
metrics = {}
|
| 1179 |
+
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
| 1180 |
+
|
| 1181 |
+
labels = torch.tensor(batch["label"])
|
| 1182 |
+
num_chosen = labels.sum().to(self.accelerator.device)
|
| 1183 |
+
num_rejected = (len(labels) - num_chosen).to(self.accelerator.device)
|
| 1184 |
+
|
| 1185 |
+
if self.args.use_liger_kernel:
|
| 1186 |
+
model_output = self._compute_loss_liger(model, batch)
|
| 1187 |
+
losses = model_output["loss"]
|
| 1188 |
+
policy_chosen_logits = model_output["chosen_logits_sum"]
|
| 1189 |
+
policy_rejected_logits = model_output["rejected_logits_sum"]
|
| 1190 |
+
policy_chosen_logps = model_output["chosen_logps_sum"]
|
| 1191 |
+
policy_rejected_logps = model_output["rejected_logps_sum"]
|
| 1192 |
+
chosen_rewards = model_output["chosen_rewards_sum"]
|
| 1193 |
+
rejected_rewards = model_output["rejected_rewards_sum"]
|
| 1194 |
+
kl = model_output["kl"]
|
| 1195 |
+
if self.aux_loss_enabled:
|
| 1196 |
+
aux_loss = model_output["aux_loss"]
|
| 1197 |
+
else:
|
| 1198 |
+
forward_output = self.forward(model, batch)
|
| 1199 |
+
(
|
| 1200 |
+
policy_chosen_logps,
|
| 1201 |
+
policy_rejected_logps,
|
| 1202 |
+
policy_chosen_logits,
|
| 1203 |
+
policy_rejected_logits,
|
| 1204 |
+
policy_KL_logps,
|
| 1205 |
+
) = forward_output[:5]
|
| 1206 |
+
if self.aux_loss_enabled:
|
| 1207 |
+
aux_loss = forward_output[5]
|
| 1208 |
+
|
| 1209 |
+
# if reference_logps in batch use them, otherwise use the reference model
|
| 1210 |
+
if "reference_logps" in batch:
|
| 1211 |
+
# Convert Python lists to tensor indices for efficient CUDA operations
|
| 1212 |
+
device = batch["reference_logps"].device
|
| 1213 |
+
labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)
|
| 1214 |
+
chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
|
| 1215 |
+
rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)
|
| 1216 |
+
|
| 1217 |
+
# Use index_select for efficient CUDA operations
|
| 1218 |
+
reference_chosen_logps = batch["reference_logps"].index_select(0, chosen_idx)
|
| 1219 |
+
reference_rejected_logps = batch["reference_logps"].index_select(0, rejected_idx)
|
| 1220 |
+
if self.calculate_KL:
|
| 1221 |
+
reference_KL_logps = batch["reference_KL_logps"]
|
| 1222 |
+
else:
|
| 1223 |
+
reference_KL_logps = None
|
| 1224 |
+
else:
|
| 1225 |
+
with torch.no_grad():
|
| 1226 |
+
if self.ref_model is None:
|
| 1227 |
+
with self.null_ref_context():
|
| 1228 |
+
(
|
| 1229 |
+
reference_chosen_logps,
|
| 1230 |
+
reference_rejected_logps,
|
| 1231 |
+
_,
|
| 1232 |
+
_,
|
| 1233 |
+
reference_KL_logps,
|
| 1234 |
+
) = self.forward(self.model, batch)[:5]
|
| 1235 |
+
else:
|
| 1236 |
+
(
|
| 1237 |
+
reference_chosen_logps,
|
| 1238 |
+
reference_rejected_logps,
|
| 1239 |
+
_,
|
| 1240 |
+
_,
|
| 1241 |
+
reference_KL_logps,
|
| 1242 |
+
) = self.forward(self.ref_model, batch)[:5]
|
| 1243 |
+
|
| 1244 |
+
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
| 1245 |
+
policy_chosen_logps,
|
| 1246 |
+
policy_rejected_logps,
|
| 1247 |
+
policy_KL_logps,
|
| 1248 |
+
reference_chosen_logps,
|
| 1249 |
+
reference_rejected_logps,
|
| 1250 |
+
reference_KL_logps,
|
| 1251 |
+
)
|
| 1252 |
+
|
| 1253 |
+
metrics["kl"] = kl.item()
|
| 1254 |
+
|
| 1255 |
+
all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
|
| 1256 |
+
all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()
|
| 1257 |
+
|
| 1258 |
+
if all_num_chosen > 0:
|
| 1259 |
+
metrics["rewards/chosen_sum"] = (
|
| 1260 |
+
self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
|
| 1261 |
+
)
|
| 1262 |
+
metrics["logps/chosen_sum"] = (
|
| 1263 |
+
self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
|
| 1264 |
+
)
|
| 1265 |
+
metrics["logits/chosen_sum"] = (
|
| 1266 |
+
self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
|
| 1267 |
+
)
|
| 1268 |
+
metrics["count/chosen"] = all_num_chosen
|
| 1269 |
+
|
| 1270 |
+
if all_num_rejected > 0:
|
| 1271 |
+
metrics["rewards/rejected_sum"] = (
|
| 1272 |
+
self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
|
| 1273 |
+
)
|
| 1274 |
+
metrics["logps/rejected_sum"] = (
|
| 1275 |
+
self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
|
| 1276 |
+
)
|
| 1277 |
+
metrics["logits/rejected_sum"] = (
|
| 1278 |
+
self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
|
| 1279 |
+
)
|
| 1280 |
+
metrics["count/rejected"] = all_num_rejected
|
| 1281 |
+
|
| 1282 |
+
loss = losses.nanmean()
|
| 1283 |
+
if self.aux_loss_enabled:
|
| 1284 |
+
loss += self.aux_loss_coef * aux_loss
|
| 1285 |
+
|
| 1286 |
+
return loss, metrics
|
| 1287 |
+
|
| 1288 |
+
def compute_loss(
|
| 1289 |
+
self,
|
| 1290 |
+
model: PreTrainedModel | nn.Module,
|
| 1291 |
+
inputs: dict[str, torch.Tensor | Any],
|
| 1292 |
+
return_outputs=False,
|
| 1293 |
+
num_items_in_batch=None,
|
| 1294 |
+
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
| 1295 |
+
compute_loss_context_manager = (
|
| 1296 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1297 |
+
)
|
| 1298 |
+
|
| 1299 |
+
with compute_loss_context_manager:
|
| 1300 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1301 |
+
|
| 1302 |
+
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
| 1303 |
+
loss = loss.to(self.args.device)
|
| 1304 |
+
# force log the metrics
|
| 1305 |
+
if self.accelerator.is_main_process:
|
| 1306 |
+
self.store_metrics(metrics, train_eval="train")
|
| 1307 |
+
|
| 1308 |
+
if return_outputs:
|
| 1309 |
+
return (loss, metrics)
|
| 1310 |
+
return loss
|
| 1311 |
+
|
| 1312 |
+
def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
| 1313 |
+
for key, value in metrics.items():
|
| 1314 |
+
self._stored_metrics[train_eval][key].append(value)
|
| 1315 |
+
|
| 1316 |
+
def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
|
| 1317 |
+
if dataset is None:
|
| 1318 |
+
dataset = self.train_dataset
|
| 1319 |
+
if dataset is None or not has_length(dataset):
|
| 1320 |
+
return None
|
| 1321 |
+
return SequentialSampler(dataset)
|
| 1322 |
+
|
| 1323 |
+
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
| 1324 |
+
"""Generate samples from the model and reference model for the given batch of inputs."""
|
| 1325 |
+
|
| 1326 |
+
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
| 1327 |
+
# the torch amp context manager as some hidden states are silently casted to full precision.
|
| 1328 |
+
generate_context_manager = (
|
| 1329 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1330 |
+
)
|
| 1331 |
+
|
| 1332 |
+
with generate_context_manager:
|
| 1333 |
+
policy_output = model.generate(
|
| 1334 |
+
input_ids=batch["prompt_input_ids"],
|
| 1335 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 1336 |
+
max_length=self.max_length,
|
| 1337 |
+
do_sample=True,
|
| 1338 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 1339 |
+
)
|
| 1340 |
+
|
| 1341 |
+
# if reference_output in batch use that otherwise use the reference model
|
| 1342 |
+
if "reference_output" in batch:
|
| 1343 |
+
reference_output = batch["reference_output"]
|
| 1344 |
+
else:
|
| 1345 |
+
if self.ref_model is None:
|
| 1346 |
+
with self.null_ref_context():
|
| 1347 |
+
reference_output = self.model.generate(
|
| 1348 |
+
input_ids=batch["prompt_input_ids"],
|
| 1349 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 1350 |
+
max_length=self.max_length,
|
| 1351 |
+
do_sample=True,
|
| 1352 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 1353 |
+
)
|
| 1354 |
+
else:
|
| 1355 |
+
reference_output = self.ref_model.generate(
|
| 1356 |
+
input_ids=batch["prompt_input_ids"],
|
| 1357 |
+
attention_mask=batch["prompt_attention_mask"],
|
| 1358 |
+
max_length=self.max_length,
|
| 1359 |
+
do_sample=True,
|
| 1360 |
+
pad_token_id=self.processing_class.pad_token_id,
|
| 1361 |
+
)
|
| 1362 |
+
|
| 1363 |
+
policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
|
| 1364 |
+
policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
|
| 1365 |
+
|
| 1366 |
+
reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
|
| 1367 |
+
reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)
|
| 1368 |
+
|
| 1369 |
+
return policy_output_decoded, reference_output_decoded
|
| 1370 |
+
|
| 1371 |
+
def prediction_step(
|
| 1372 |
+
self,
|
| 1373 |
+
model: PreTrainedModel | nn.Module,
|
| 1374 |
+
inputs: dict[str, torch.Tensor | Any],
|
| 1375 |
+
prediction_loss_only: bool,
|
| 1376 |
+
ignore_keys: list[str] | None = None,
|
| 1377 |
+
):
|
| 1378 |
+
if ignore_keys is None:
|
| 1379 |
+
if hasattr(model, "config"):
|
| 1380 |
+
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
| 1381 |
+
else:
|
| 1382 |
+
ignore_keys = []
|
| 1383 |
+
|
| 1384 |
+
prediction_context_manager = (
|
| 1385 |
+
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
| 1386 |
+
)
|
| 1387 |
+
with torch.no_grad(), prediction_context_manager:
|
| 1388 |
+
loss, metrics = self.get_batch_loss_metrics(model, inputs)
|
| 1389 |
+
|
| 1390 |
+
# force log the metrics
|
| 1391 |
+
if self.accelerator.is_main_process:
|
| 1392 |
+
self.store_metrics(metrics, train_eval="eval")
|
| 1393 |
+
|
| 1394 |
+
if prediction_loss_only:
|
| 1395 |
+
return (loss.detach(), None, None)
|
| 1396 |
+
|
| 1397 |
+
# logits for the chosen and rejected samples from model
|
| 1398 |
+
logits_dict = {}
|
| 1399 |
+
if "logits/chosen_sum" in metrics:
|
| 1400 |
+
logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
|
| 1401 |
+
if "logits/rejected_sum" in metrics:
|
| 1402 |
+
logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
|
| 1403 |
+
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
|
| 1404 |
+
logits = torch.tensor(logits, device=self.accelerator.device)
|
| 1405 |
+
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
| 1406 |
+
|
| 1407 |
+
return (loss.detach(), logits, labels)
|
| 1408 |
+
|
| 1409 |
+
def evaluation_loop(
|
| 1410 |
+
self,
|
| 1411 |
+
dataloader: DataLoader,
|
| 1412 |
+
description: str,
|
| 1413 |
+
prediction_loss_only: bool | None = None,
|
| 1414 |
+
ignore_keys: list[str] | None = None,
|
| 1415 |
+
metric_key_prefix: str = "eval",
|
| 1416 |
+
) -> EvalLoopOutput:
|
| 1417 |
+
"""
|
| 1418 |
+
Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
|
| 1419 |
+
`Trainer.evaluate()` and `Trainer.predict()`.
|
| 1420 |
+
|
| 1421 |
+
Works both with or without labels.
|
| 1422 |
+
"""
|
| 1423 |
+
|
| 1424 |
+
# Sample and save to game log if requested (for one batch to save time)
|
| 1425 |
+
if self.generate_during_eval:
|
| 1426 |
+
# Generate random indices within the range of the total number of samples
|
| 1427 |
+
num_samples = len(dataloader.dataset)
|
| 1428 |
+
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
| 1429 |
+
|
| 1430 |
+
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
| 1431 |
+
random_batch_dataset = dataloader.dataset.select(random_indices)
|
| 1432 |
+
random_batch = self.data_collator(random_batch_dataset)
|
| 1433 |
+
random_batch = self._prepare_inputs(random_batch)
|
| 1434 |
+
|
| 1435 |
+
target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device)
|
| 1436 |
+
target_indices = torch.where(~target_labels)[0]
|
| 1437 |
+
target_batch = {
|
| 1438 |
+
"prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
|
| 1439 |
+
"prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
|
| 1440 |
+
"prompt": itemgetter(*target_indices)(random_batch["prompt"]),
|
| 1441 |
+
}
|
| 1442 |
+
policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
|
| 1443 |
+
|
| 1444 |
+
table = pd.DataFrame(
|
| 1445 |
+
columns=["Prompt", "Policy", "Ref Model"],
|
| 1446 |
+
data=[
|
| 1447 |
+
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
| 1448 |
+
for prompt, pol, ref in zip(
|
| 1449 |
+
target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True
|
| 1450 |
+
)
|
| 1451 |
+
],
|
| 1452 |
+
)
|
| 1453 |
+
if "wandb" in self.args.report_to:
|
| 1454 |
+
wandb.log({"game_log": wandb.Table(data=table)})
|
| 1455 |
+
|
| 1456 |
+
if "comet_ml" in self.args.report_to:
|
| 1457 |
+
log_table_to_comet_experiment(
|
| 1458 |
+
name="game_log.csv",
|
| 1459 |
+
table=table,
|
| 1460 |
+
)
|
| 1461 |
+
|
| 1462 |
+
# Base evaluation
|
| 1463 |
+
initial_output = super().evaluation_loop(
|
| 1464 |
+
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
| 1465 |
+
)
|
| 1466 |
+
|
| 1467 |
+
return initial_output
|
| 1468 |
+
|
| 1469 |
+
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
| 1470 |
+
"""
|
| 1471 |
+
Log `logs` on the various objects watching training, including stored metrics.
|
| 1472 |
+
|
| 1473 |
+
Args:
|
| 1474 |
+
logs (`dict[str, float]`):
|
| 1475 |
+
The values to log.
|
| 1476 |
+
start_time (`float`, *optional*):
|
| 1477 |
+
Start time of the training.
|
| 1478 |
+
"""
|
| 1479 |
+
# logs either has 'loss' or 'eval_loss'
|
| 1480 |
+
train_eval = "train" if "loss" in logs else "eval"
|
| 1481 |
+
# train metrics should have no prefix, eval should have 'eval_'
|
| 1482 |
+
prefix = "eval_" if train_eval == "eval" else ""
|
| 1483 |
+
# accumulate average metrics from sums and lengths
|
| 1484 |
+
for split in ["chosen", "rejected"]:
|
| 1485 |
+
if f"count/{split}" in self._stored_metrics[train_eval]:
|
| 1486 |
+
count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
|
| 1487 |
+
for metric in ["rewards", "logps", "logits"]:
|
| 1488 |
+
logs[f"{prefix}{metric}/{split}"] = (
|
| 1489 |
+
torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
|
| 1490 |
+
/ count_sum
|
| 1491 |
+
)
|
| 1492 |
+
# delete obsolete metric
|
| 1493 |
+
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
| 1494 |
+
del self._stored_metrics[train_eval][f"count/{split}"]
|
| 1495 |
+
# calculate reward margin
|
| 1496 |
+
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
| 1497 |
+
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
| 1498 |
+
# Add averaged stored metrics to logs
|
| 1499 |
+
for key, metrics in self._stored_metrics[train_eval].items():
|
| 1500 |
+
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
| 1501 |
+
del self._stored_metrics[train_eval]
|
| 1502 |
+
return super().log(logs, start_time)
|
| 1503 |
+
|
| 1504 |
+
# Ensure the model card is saved along with the checkpoint
|
| 1505 |
+
def _save_checkpoint(self, model, trial):
|
| 1506 |
+
if self.args.hub_model_id is None:
|
| 1507 |
+
model_name = Path(self.args.output_dir).name
|
| 1508 |
+
else:
|
| 1509 |
+
model_name = self.args.hub_model_id.split("/")[-1]
|
| 1510 |
+
self.create_model_card(model_name=model_name)
|
| 1511 |
+
super()._save_checkpoint(model, trial)
|
ICL/RL/trl_source/trl/experimental/merge_model_callback.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from huggingface_hub import HfApi
|
| 20 |
+
from transformers import TrainerCallback
|
| 21 |
+
|
| 22 |
+
from ..import_utils import is_mergekit_available
|
| 23 |
+
from ..trainer.utils import get_config_model_id
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_mergekit_available():
|
| 27 |
+
from mergekit.config import MergeConfiguration
|
| 28 |
+
from mergekit.merge import MergeOptions, run_merge
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Logger for module-level logging
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def upload_model_to_hf(folder_path: str, repo_id: str):
|
| 36 |
+
api = HfApi()
|
| 37 |
+
# Create the repository if it doesn't exist
|
| 38 |
+
repo = api.create_repo(repo_id, repo_type="model")
|
| 39 |
+
|
| 40 |
+
# Upload the folder to the specified repository
|
| 41 |
+
api.upload_folder(
|
| 42 |
+
folder_path=folder_path,
|
| 43 |
+
repo_id=repo.repo_id,
|
| 44 |
+
repo_type=repo.repo_type,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class MergeConfig:
|
| 49 |
+
r"""
|
| 50 |
+
Configuration class for merging two models using `mergekit`.
|
| 51 |
+
|
| 52 |
+
This class provides a structured way to configure and generate merge configurations for various merge methods, such
|
| 53 |
+
as `linear`, `ties`, `dare_ties`, and `slerp`.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
method (`str`, *optional*, defaults to `"linear"`):
|
| 57 |
+
Merge method to use. Supported methods include:
|
| 58 |
+
|
| 59 |
+
- `"linear"`: Linearly combines two models with specified weights.
|
| 60 |
+
- `"ties"`: Combines two models using the TIES method with density parameters.
|
| 61 |
+
- `"dare_ties"`: A variant of TIES for domain adaptation.
|
| 62 |
+
- `"slerp"`: Combines models using spherical linear interpolation.
|
| 63 |
+
|
| 64 |
+
Note:
|
| 65 |
+
|
| 66 |
+
For more details about the merge methods and how they are implemented, see the [MergeKit GitHub
|
| 67 |
+
repository](https://github.com/arcee-ai/mergekit?tab=readme-ov-file#merge-methods).
|
| 68 |
+
|
| 69 |
+
Attributes:
|
| 70 |
+
method (`str`): The merge method to use.
|
| 71 |
+
policy_model_path (`str` or `None`): Path to the policy model.
|
| 72 |
+
target_model_path (`str` or `None`): Path to the target model.
|
| 73 |
+
policy_model_weight (`float`): Weight for the policy model (for `linear` and `ties` methods).
|
| 74 |
+
target_model_weight (`float`): Weight for the target model (for `linear` and `ties` methods).
|
| 75 |
+
policy_model_density (`list[float]`): Density parameters for the policy model (for `ties` and `dare_ties`).
|
| 76 |
+
target_model_density (`list[float]`): Density parameters for the target model (for `ties` and `dare_ties`).
|
| 77 |
+
normalize (`float` or `None`): Normalization factor for the TIES method.
|
| 78 |
+
t_values (`float` or `None`): Interpolation factor for the SLERP method.
|
| 79 |
+
dtype (`str`): Data type to use for merging, e.g., `"float16"`.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self, method: str = "linear"):
|
| 83 |
+
if not is_mergekit_available():
|
| 84 |
+
raise ImportError("MergeConfig requires the `mergekit` extra. To install, run `pip install mergekit`.")
|
| 85 |
+
self.method = method
|
| 86 |
+
self.policy_model_path = None
|
| 87 |
+
self.target_model_path = None
|
| 88 |
+
|
| 89 |
+
# Initialize relevant parameters based on the method
|
| 90 |
+
if method == "linear":
|
| 91 |
+
self.policy_model_weight = 0.5
|
| 92 |
+
self.target_model_weight = 0.5
|
| 93 |
+
self.dtype = "float16"
|
| 94 |
+
elif method == "ties":
|
| 95 |
+
self.policy_model_weight = 1.0
|
| 96 |
+
self.policy_model_density = [1.0, 0.7, 0.1]
|
| 97 |
+
self.target_model_weight = 1.0
|
| 98 |
+
self.target_model_density = [1.0]
|
| 99 |
+
self.normalize = 1.0
|
| 100 |
+
self.dtype = "float16"
|
| 101 |
+
elif method == "dare_ties":
|
| 102 |
+
self.policy_model_weight = 1.0
|
| 103 |
+
self.policy_model_density = [1.0, 0.7, 0.1]
|
| 104 |
+
self.target_model_weight = 1.0
|
| 105 |
+
self.target_model_density = [1.0]
|
| 106 |
+
self.normalize = 1.0
|
| 107 |
+
self.dtype = "float16"
|
| 108 |
+
elif method == "slerp":
|
| 109 |
+
self.t_values = 0.5
|
| 110 |
+
self.dtype = "float16"
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Unsupported merge method: {method}")
|
| 113 |
+
|
| 114 |
+
def create_merge_config_linear(self) -> "MergeConfiguration":
|
| 115 |
+
"""
|
| 116 |
+
Creates a merge configuration for a linear merge of two models with specified weights.
|
| 117 |
+
"""
|
| 118 |
+
# Create the merge configuration dictionary
|
| 119 |
+
merge_config_dict = {
|
| 120 |
+
"dtype": self.dtype,
|
| 121 |
+
"merge_method": "linear",
|
| 122 |
+
"models": [
|
| 123 |
+
{"model": self.policy_model_path, "parameters": {"weight": self.policy_model_weight}},
|
| 124 |
+
{"model": self.target_model_path, "parameters": {"weight": self.target_model_weight}},
|
| 125 |
+
],
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
# Create the MergeConfiguration from the dictionary
|
| 129 |
+
merge_config = MergeConfiguration.model_validate(merge_config_dict)
|
| 130 |
+
|
| 131 |
+
return merge_config
|
| 132 |
+
|
| 133 |
+
def create_merge_config_ties(self) -> "MergeConfiguration":
|
| 134 |
+
"""
|
| 135 |
+
Creates a merge configuration for a TIES merge of two models, with specified weights and densities.
|
| 136 |
+
"""
|
| 137 |
+
# Create the TIES merge configuration dictionary
|
| 138 |
+
merge_config_dict = {
|
| 139 |
+
"merge_method": "ties",
|
| 140 |
+
"slices": None, # Optional slices if needed
|
| 141 |
+
"models": [
|
| 142 |
+
{
|
| 143 |
+
"model": {
|
| 144 |
+
"model": {"path": self.target_model_path, "revision": None},
|
| 145 |
+
"lora": None,
|
| 146 |
+
"override_architecture": None,
|
| 147 |
+
},
|
| 148 |
+
"parameters": {"density": self.target_model_density, "weight": self.target_model_weight},
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"model": {
|
| 152 |
+
"model": {"path": self.policy_model_path, "revision": None},
|
| 153 |
+
"lora": None,
|
| 154 |
+
"override_architecture": None,
|
| 155 |
+
},
|
| 156 |
+
"parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight},
|
| 157 |
+
},
|
| 158 |
+
],
|
| 159 |
+
"parameters": {"normalize": self.normalize},
|
| 160 |
+
"base_model": {
|
| 161 |
+
"model": {"path": self.policy_model_path, "revision": None},
|
| 162 |
+
"lora": None,
|
| 163 |
+
"override_architecture": None,
|
| 164 |
+
},
|
| 165 |
+
"dtype": self.dtype,
|
| 166 |
+
"tokenizer_source": None,
|
| 167 |
+
"tokenizer": None,
|
| 168 |
+
"chat_template": None,
|
| 169 |
+
"out_dtype": None,
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
# Create the MergeConfiguration from the dictionary
|
| 173 |
+
merge_config = MergeConfiguration.model_validate(merge_config_dict)
|
| 174 |
+
|
| 175 |
+
return merge_config
|
| 176 |
+
|
| 177 |
+
def create_merge_config_dare_ties(self) -> "MergeConfiguration":
|
| 178 |
+
"""
|
| 179 |
+
Creates a merge configuration for a DARE TIES merge of two models, with specified weights and densities.
|
| 180 |
+
"""
|
| 181 |
+
# Create the DARE TIES merge configuration dictionary
|
| 182 |
+
merge_config_dict = {
|
| 183 |
+
"merge_method": "dare_ties",
|
| 184 |
+
"slices": None, # Optional slices if needed
|
| 185 |
+
"models": [
|
| 186 |
+
{
|
| 187 |
+
"model": {
|
| 188 |
+
"model": {"path": self.target_model_path, "revision": None},
|
| 189 |
+
"lora": None,
|
| 190 |
+
"override_architecture": None,
|
| 191 |
+
},
|
| 192 |
+
"parameters": {"density": self.target_model_density, "weight": self.target_model_weight},
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"model": {
|
| 196 |
+
"model": {"path": self.policy_model_path, "revision": None},
|
| 197 |
+
"lora": None,
|
| 198 |
+
"override_architecture": None,
|
| 199 |
+
},
|
| 200 |
+
"parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight},
|
| 201 |
+
},
|
| 202 |
+
],
|
| 203 |
+
"parameters": {"normalize": self.normalize},
|
| 204 |
+
"base_model": {
|
| 205 |
+
"model": {"path": self.policy_model_path, "revision": None},
|
| 206 |
+
"lora": None,
|
| 207 |
+
"override_architecture": None,
|
| 208 |
+
},
|
| 209 |
+
"dtype": self.dtype,
|
| 210 |
+
"tokenizer_source": None,
|
| 211 |
+
"tokenizer": None,
|
| 212 |
+
"chat_template": None,
|
| 213 |
+
"out_dtype": None,
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
# Create the MergeConfiguration from the dictionary
|
| 217 |
+
merge_config = MergeConfiguration.model_validate(merge_config_dict)
|
| 218 |
+
|
| 219 |
+
return merge_config
|
| 220 |
+
|
| 221 |
+
def create_merge_config_slerp(self) -> "MergeConfiguration":
|
| 222 |
+
"""
|
| 223 |
+
Creates a merge configuration for a SLERP merge of a model with a base model.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
# Create the SLERP merge configuration dictionary
|
| 227 |
+
merge_config_dict = {
|
| 228 |
+
"merge_method": "slerp",
|
| 229 |
+
"slices": None, # Optional slices if needed
|
| 230 |
+
"models": [
|
| 231 |
+
{
|
| 232 |
+
"model": {
|
| 233 |
+
"model": {"path": self.target_model_path, "revision": None},
|
| 234 |
+
"lora": None,
|
| 235 |
+
"override_architecture": None,
|
| 236 |
+
},
|
| 237 |
+
"parameters": None, # No specific parameters for SLERP model
|
| 238 |
+
}
|
| 239 |
+
],
|
| 240 |
+
"parameters": {
|
| 241 |
+
"t": self.t_values # Set the t values for SLERP
|
| 242 |
+
},
|
| 243 |
+
"base_model": {
|
| 244 |
+
"model": {"path": self.policy_model_path, "revision": None},
|
| 245 |
+
"lora": None,
|
| 246 |
+
"override_architecture": None,
|
| 247 |
+
},
|
| 248 |
+
"dtype": self.dtype,
|
| 249 |
+
"tokenizer_source": None,
|
| 250 |
+
"tokenizer": None,
|
| 251 |
+
"chat_template": None,
|
| 252 |
+
"out_dtype": None,
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
# Create the MergeConfiguration from the dictionary
|
| 256 |
+
merge_config = MergeConfiguration.model_validate(merge_config_dict)
|
| 257 |
+
|
| 258 |
+
return merge_config
|
| 259 |
+
|
| 260 |
+
def create(self) -> "MergeConfiguration":
|
| 261 |
+
if self.method == "linear":
|
| 262 |
+
return self.create_merge_config_linear()
|
| 263 |
+
elif self.method == "ties":
|
| 264 |
+
return self.create_merge_config_ties()
|
| 265 |
+
elif self.method == "dare_ties":
|
| 266 |
+
return self.create_merge_config_dare_ties()
|
| 267 |
+
elif self.method == "slerp":
|
| 268 |
+
return self.create_merge_config_slerp()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def merge_models(config: "MergeConfiguration", out_path: str):
|
| 272 |
+
"""
|
| 273 |
+
Merge two models using mergekit
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
config (`MergeConfiguration`): The merge configuration.
|
| 277 |
+
out_path (`str`): The output path for the merged model.
|
| 278 |
+
"""
|
| 279 |
+
if not is_mergekit_available():
|
| 280 |
+
raise ImportError("merge_models requires the `mergekit` extra. To install, run `pip install mergekit`.")
|
| 281 |
+
run_merge(
|
| 282 |
+
config,
|
| 283 |
+
out_path=out_path,
|
| 284 |
+
options=MergeOptions(
|
| 285 |
+
device="auto",
|
| 286 |
+
cuda=torch.cuda.is_available(),
|
| 287 |
+
copy_tokenizer=True,
|
| 288 |
+
lazy_unpickle=False,
|
| 289 |
+
low_cpu_memory=False,
|
| 290 |
+
),
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class MergeModelCallback(TrainerCallback):
|
| 295 |
+
r"""
|
| 296 |
+
A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
|
| 297 |
+
on a merge configuration.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
merge_config ([`experimental.merge_model_callback.MergeConfig`], *optional*):
|
| 301 |
+
Configuration used for the merging process. If not provided, the default
|
| 302 |
+
[`~experimental.merge_model_callback.MergeConfig`] is used.
|
| 303 |
+
merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
|
| 304 |
+
Whether to merge the model at every checkpoint.
|
| 305 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 306 |
+
Whether to push the merged model to the Hub after merging.
|
| 307 |
+
|
| 308 |
+
Example:
|
| 309 |
+
|
| 310 |
+
```python
|
| 311 |
+
from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback
|
| 312 |
+
|
| 313 |
+
config = MergeConfig()
|
| 314 |
+
merge_callback = MergeModelCallback(config)
|
| 315 |
+
trainer = DPOTrainer(..., callbacks=[merge_callback])
|
| 316 |
+
```
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
def __init__(
|
| 320 |
+
self,
|
| 321 |
+
merge_config: "MergeConfig | None" = None,
|
| 322 |
+
merge_at_every_checkpoint: bool = False,
|
| 323 |
+
push_to_hub: bool = False,
|
| 324 |
+
):
|
| 325 |
+
if not is_mergekit_available():
|
| 326 |
+
raise ImportError(
|
| 327 |
+
"MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
|
| 328 |
+
)
|
| 329 |
+
self.merge_config = merge_config or MergeConfig()
|
| 330 |
+
self.merge_at_every_checkpoint = merge_at_every_checkpoint
|
| 331 |
+
self.push_to_hub = push_to_hub
|
| 332 |
+
|
| 333 |
+
def _merge_and_maybe_push(self, output_dir, global_step, model):
|
| 334 |
+
checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
|
| 335 |
+
self.merge_config.policy_model_path = checkpoint_path
|
| 336 |
+
if self.merge_config.target_model_path is None:
|
| 337 |
+
self.merge_config.target_model_path = get_config_model_id(model.config)
|
| 338 |
+
merge_path = os.path.join(checkpoint_path, "merged")
|
| 339 |
+
|
| 340 |
+
merge_models(self.merge_config.create(), merge_path)
|
| 341 |
+
|
| 342 |
+
if self.push_to_hub:
|
| 343 |
+
repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
|
| 344 |
+
upload_model_to_hf(merge_path, repo_name)
|
| 345 |
+
|
| 346 |
+
def on_save(self, args, state, control, model=None, **kwargs):
|
| 347 |
+
if self.merge_at_every_checkpoint:
|
| 348 |
+
self._merge_and_maybe_push(args.output_dir, state.global_step, model)
|
| 349 |
+
|
| 350 |
+
def on_train_end(self, args, state, control, model=None, **kwargs):
|
| 351 |
+
if not self.merge_at_every_checkpoint:
|
| 352 |
+
self._merge_and_maybe_push(args.output_dir, state.global_step, model)
|
ICL/RL/trl_source/trl/experimental/minillm/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .minillm_config import MiniLLMConfig
|
| 16 |
+
from .minillm_trainer import MiniLLMTrainer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = ["MiniLLMConfig", "MiniLLMTrainer"]
|