Lekr0 committed
Commit 5de3d77 · verified · 1 Parent(s): e9585fc

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. ICL/RL/trl_source/.github/workflows/tests-experimental.yml +70 -0
  2. ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml +121 -0
  3. ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py +108 -0
  4. ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml +22 -0
  5. ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh +112 -0
  6. ICL/RL/trl_source/examples/scripts/online_dpo.py +159 -0
  7. ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py +506 -0
  8. ICL/RL/trl_source/examples/scripts/openenv/echo.py +248 -0
  9. ICL/RL/trl_source/examples/scripts/openenv/wordle.py +607 -0
  10. ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt +105 -0
  11. ICL/RL/trl_source/examples/scripts/ppo/ppo.py +180 -0
  12. ICL/RL/trl_source/examples/scripts/reward_modeling.py +136 -0
  13. ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py +194 -0
  14. ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc +0 -0
  15. ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc +0 -0
  16. ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc +0 -0
  17. ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc +0 -0
  18. ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc +0 -0
  19. ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml +28 -0
  20. ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml +25 -0
  21. ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml +16 -0
  22. ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml +16 -0
  23. ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml +20 -0
  24. ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml +21 -0
  25. ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml +22 -0
  26. ICL/RL/trl_source/trl/experimental/__init__.py +36 -0
  27. ICL/RL/trl_source/trl/experimental/bco/__init__.py +16 -0
  28. ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py +16 -0
  29. ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py +30 -0
  30. ICL/RL/trl_source/trl/experimental/cpo/__init__.py +19 -0
  31. ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py +207 -0
  32. ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py +1057 -0
  33. ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py +35 -0
  34. ICL/RL/trl_source/trl/experimental/gkd/__init__.py +19 -0
  35. ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py +112 -0
  36. ICL/RL/trl_source/trl/experimental/gold/__init__.py +19 -0
  37. ICL/RL/trl_source/trl/experimental/gold/gold.py +155 -0
  38. ICL/RL/trl_source/trl/experimental/gold/gold_config.py +419 -0
  39. ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py +16 -0
  40. ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py +34 -0
  41. ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +731 -0
  42. ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py +15 -0
  43. ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py +157 -0
  44. ICL/RL/trl_source/trl/experimental/judges/__init__.py +36 -0
  45. ICL/RL/trl_source/trl/experimental/judges/judges.py +482 -0
  46. ICL/RL/trl_source/trl/experimental/kto/__init__.py +19 -0
  47. ICL/RL/trl_source/trl/experimental/kto/kto_config.py +171 -0
  48. ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py +1511 -0
  49. ICL/RL/trl_source/trl/experimental/merge_model_callback.py +352 -0
  50. ICL/RL/trl_source/trl/experimental/minillm/__init__.py +19 -0
ICL/RL/trl_source/.github/workflows/tests-experimental.yml ADDED
@@ -0,0 +1,70 @@
name: Tests (experimental)

on:
  pull_request:
    paths:
      # Run only when relevant files are modified
      - "trl/experimental/**"
      - "tests/experimental/**"

env:
  TQDM_DISABLE: 1
  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
  TRL_EXPERIMENTAL_SILENCE: 1

jobs:
  check_code_quality:
    name: Check code quality
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python 3.13
        uses: actions/setup-python@v6
        with:
          python-version: 3.13
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: --all-files

  tests:
    name: Tests (experimental)
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    steps:
      - name: Git checkout
        uses: actions/checkout@v6

      - name: Set up Python 3.13
        uses: actions/setup-python@v6
        with:
          python-version: 3.13

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test_experimental
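The last step above defers to the repository's Makefile. Assuming the same `test_experimental` target exists in a local clone, a rough local reproduction of this CI job would be:

```sh
# Sketch: local equivalent of the workflow's test steps (Makefile target assumed).
uv venv && source .venv/bin/activate
uv pip install ".[dev]"
make test_experimental
```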
ICL/RL/trl_source/.github/workflows/tests_transformers_branch.yml ADDED
@@ -0,0 +1,121 @@
name: Tests against Transformers branch

on:
  workflow_dispatch:
    inputs:
      transformers_ref:
        description: "Transformers git ref (branch, tag, or commit SHA)"
        required: true
        default: "main"

env:
  TQDM_DISABLE: 1
  CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"

jobs:
  tests_transformers_branch:
    name: Tests with Transformers ${{ inputs.transformers_ref }}
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    steps:
      - name: Git checkout
        uses: actions/checkout@v6

      - name: Set up Python 3.12
        uses: actions/setup-python@v6
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"
          uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }}

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        if: github.ref == 'refs/heads/main' && always()
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results with Transformers ${{ inputs.transformers_ref }}
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

  distributed_smoke:
    name: Distributed smoke tests with Transformers ${{ inputs.transformers_ref }}
    runs-on:
      group: aws-g5-12xlarge-cache
    container:
      image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
    steps:
      - name: Git checkout
        uses: actions/checkout@v6

      - name: Set up Python 3.12
        uses: actions/setup-python@v6
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"
          uv pip install -U git+https://github.com/huggingface/transformers.git@${{ inputs.transformers_ref }}

      - name: Run distributed smoke tests
        run: |
          source .venv/bin/activate
          pytest -v tests/distributed/test_distributed.py

      - name: Post to Slack
        if: github.ref == 'refs/heads/main' && always()
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results of distributed smoke tests with Transformers ${{ inputs.transformers_ref }}
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
ICL/RL/trl_source/examples/scripts/evals/judge_tldr.py ADDED
@@ -0,0 +1,108 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl[vllm]",
# ]
# ///

from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import HfArgumentParser
from vllm import LLM, SamplingParams

from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge


"""
Examples:

python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --num_examples 1000
Model win rate: 31.40%

python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000
Model win rate: 51.60%

python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%

python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --num_examples 1000
Model win rate: 46.30%

python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000
Model win rate: 52.50%

python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 63.00%
"""


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        model_name_or_path (`str`):
            Model name or path to the model to evaluate.
        judge_model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`):
            Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or
            'meta-llama/Meta-Llama-3-70B-Instruct'.
        num_examples (`int`, *optional*):
            Number of examples to evaluate.
    """

    model_name_or_path: str = field(metadata={"help": "Model name or path to the model to evaluate."})
    judge_model: str = field(
        default="meta-llama/Meta-Llama-3-70B-Instruct",
        metadata={
            "help": "Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or "
            "'meta-llama/Meta-Llama-3-70B-Instruct'."
        },
    )
    num_examples: int | None = field(default=None, metadata={"help": "Number of examples to evaluate."})


if __name__ == "__main__":
    # Parse the arguments
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # Load the dataset
    dataset = load_dataset("trl-lib/tldr", split="validation")
    if script_args.num_examples is not None:
        dataset = dataset.select(range(script_args.num_examples))

    # Extract the prompts and reference completions
    prompts = dataset["prompt"]
    reference_completions = dataset["completion"]

    # Generate the model completions
    sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200)  # very generous max token length
    llm = LLM(model=script_args.model_name_or_path, tensor_parallel_size=1)
    outputs = llm.generate(prompts, sampling_params)
    model_completions = [output.outputs[0].text.strip() for output in outputs]

    # Judge the outputs
    if "gpt" in script_args.judge_model:
        judge = OpenAIPairwiseJudge(script_args.judge_model)
    else:
        judge = HfPairwiseJudge(script_args.judge_model)

    completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions, strict=True)]
    best_idxs = judge.judge(prompts, completions)
    model_win_rate = best_idxs.count(1) / len(best_idxs)
    print(f"Model win rate: {model_win_rate * 100:.2f}%")
ICL/RL/trl_source/examples/scripts/nemo_gym/deepspeed_zero3.yaml ADDED
@@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
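This Accelerate config is consumed via `--config_file`, exactly as the submit.sh script below does; a minimal standalone sketch (the training script name here is a placeholder):

```sh
# Sketch: launching a trainer with the ZeRO-3 config above (train.py is hypothetical).
accelerate launch --config_file deepspeed_zero3.yaml train.py
```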
ICL/RL/trl_source/examples/scripts/nemo_gym/submit.sh ADDED
@@ -0,0 +1,112 @@
#!/bin/bash
#SBATCH -A account
#SBATCH -p partition
#SBATCH -N 5
#SBATCH --gres gpu:8
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=16
#SBATCH --time=4:00:00
#SBATCH --job-name=trl_nemo_gym
#SBATCH --output=logs/%j/slurm.out
#SBATCH --error=logs/%j/slurm.err

CONTAINER_IMAGE="nvcr.io/nvidia/pytorch:25.12-py3"
MOUNTS="/path/to/mounts:/path/to/mounts"

NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

TRAIN_NODE_0="${NODELIST[0]}"
TRAIN_NODE_1="${NODELIST[1]}"
TRAIN_NODE_2="${NODELIST[2]}"
TRAIN_NODE_3="${NODELIST[3]}"
VLLM_NODE="${NODELIST[4]}"

echo "Training Nodes: $TRAIN_NODE_0, $TRAIN_NODE_1, $TRAIN_NODE_2, $TRAIN_NODE_3"
echo "vLLM Node: $VLLM_NODE"
echo "Main process IP: $TRAIN_NODE_0"

LOG_DIR="logs/${SLURM_JOB_ID}"
mkdir -p ${LOG_DIR}

echo "Starting ng_run and vLLM on ${VLLM_NODE}..."
echo "Logs will be saved to: ${LOG_DIR}"

# NOTE: If you have already set up your TRL venv, you can remove all of the pip installs and uv venv related commands below!

srun --nodes=1 --ntasks=1 --nodelist="${VLLM_NODE}" \
    --container-image="${CONTAINER_IMAGE}" \
    --container-mounts="${MOUNTS}" \
    --container-mount-home \
    bash -c "
    LOG_DIR=/path/to/logs
    mkdir -p \${LOG_DIR}

    # Install uv if not already installed
    curl -LsSf https://astral.sh/uv/install.sh | sh
    source \$HOME/.local/bin/env

    # Start nemo gym servers
    (set -x && \
     export HOME=/path/to/user && \
     export PATH=\$HOME/.local/bin:\$PATH && \
     cd /path/to/user/Gym && \
     uv venv --python 3.12 && \
     source .venv/bin/activate && \
     uv sync && \
     ray stop --force && \
     ng_run +config_paths=[responses_api_models/vllm_model/configs/vllm_model.yaml,resources_servers/workplace_assistant/configs/workplace_assistant.yaml] +head_server.host=0.0.0.0 +head_server.port=11000) > \${LOG_DIR}/ng_run.log 2>&1 &

    sleep 10

    # Start trl vllm server
    (set -x && \
     export HOME=/path/to/user && \
     export HF_HOME=/path/to/user/hf_home && \
     cd /path/to/user/trl && \
     rm -rf .venv && uv venv && source .venv/bin/activate && uv sync && uv pip install -e .[vllm] && uv pip install fastapi uvicorn && \
     python -m trl.scripts.vllm_serve \
        --model Qwen/Qwen3-4B-Instruct-2507 \
        --host 0.0.0.0 \
        --tensor-parallel-size 8 \
        --data-parallel-size 1 \
        --max-model-len 16384 \
        --gpu-memory-utilization 0.7 \
        --port 8000) > \${LOG_DIR}/vllm_serve.log 2>&1 &

    wait
    " &

echo "Waiting for nemo gym and vllm to start..."
sleep 120

echo "Launching training on 4 nodes..."

TRAIN_NODES_LIST="${TRAIN_NODE_0},${TRAIN_NODE_1},${TRAIN_NODE_2},${TRAIN_NODE_3}"

srun --nodes=4 --ntasks=4 --nodelist="${TRAIN_NODES_LIST}" \
    --container-image="${CONTAINER_IMAGE}" \
    --container-mounts="${MOUNTS}" \
    --container-mount-home \
    bash -c "
    set -x && \
    export HOME=/path/to/user && \
    export HF_HOME=/path/to/user/hf_home && \
    cd /path/to/user/trl && \
    source .venv/bin/activate && uv pip install accelerate deepspeed wandb omegaconf && \
    cd examples/scripts/nemo_gym && \
    export WANDB_API_KEY=<your wandb api key> && \
    accelerate launch \
        --config_file deepspeed_zero3.yaml \
        --num_processes 32 \
        --num_machines 4 \
        --machine_rank \$SLURM_PROCID \
        --main_process_ip ${TRAIN_NODE_0} \
        --main_process_port 29500 \
        --rdzv_backend c10d \
        train_multi_environment.py \
        --config config.yaml \
        --vllm_server_host ${VLLM_NODE} \
        --head_server_host ${VLLM_NODE}" &

wait
ICL/RL/trl_source/examples/scripts/online_dpo.py ADDED
@@ -0,0 +1,159 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl",
#     "peft",
#     "trackio",
#     "kernels",
# ]
# ///

"""
Usage:

python examples/scripts/online_dpo.py \
    --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
    --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
    --dataset_name trl-lib/tldr \
    --learning_rate 5.0e-7 \
    --output_dir pythia-1b-tldr-online-dpo \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 16 \
    --warmup_steps 0.1 \
    --missing_eos_penalty 1.0

With LoRA:
python examples/scripts/online_dpo.py \
    --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
    --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
    --dataset_name trl-lib/tldr \
    --learning_rate 5.0e-6 \
    --output_dir pythia-1b-tldr-online-dpo \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 8 \
    --warmup_steps 0.1 \
    --missing_eos_penalty 1.0 \
    --use_peft
"""

import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig

from trl import (
    LogCompletionsCallback,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge
from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer


# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")


JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge}

if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        dtype=dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    quantization_config = get_quantization_config(model_args)
    if quantization_config is not None:
        # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    if training_args.reward_model_path is not None:
        reward_model = AutoModelForSequenceClassification.from_pretrained(
            training_args.reward_model_path,
            num_labels=1,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        reward_tokenizer = AutoTokenizer.from_pretrained(
            training_args.reward_model_path,
            trust_remote_code=model_args.trust_remote_code,
            truncation=True,
            truncation_side="left",  # since we judge the completion, truncating left is more appropriate
        )
        if reward_tokenizer.pad_token_id is None:
            reward_tokenizer.pad_token = reward_tokenizer.eos_token
    else:
        reward_model = None
        reward_tokenizer = None

    if training_args.judge is not None:
        judge_cls = JUDGES[training_args.judge]
        judge = judge_cls()
    else:
        judge = None

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    trainer = OnlineDPOTrainer(
        model=model,
        reward_funcs=reward_model,
        judge=judge,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        reward_processing_classes=reward_tokenizer,
        peft_config=get_peft_config(model_args),
    )

    if training_args.eval_strategy != "no":
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
        )
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        trainer.add_callback(completions_callback)

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
ICL/RL/trl_source/examples/scripts/openenv/browsergym_llm.py ADDED
@@ -0,0 +1,506 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl[vllm]",
#     "peft",
#     "trackio",
#     "kernels",
#     "openenv-browsergym @ git+https://huggingface.co/spaces/openenv/browsergym_env",
# ]
# ///

"""
Simple script to run GRPO training with OpenEnv's BrowserGym environment and vLLM for LLMs.

This script is optimized for text-only Language Models (LLMs). It uses the accessibility
tree text from BrowserGym, making it memory-efficient.

The environment runs on a Hugging Face Space by default.

Setup (Option A - Install from HF Space, recommended):

```sh
uv pip install git+https://huggingface.co/spaces/openenv/browsergym_env
```

Setup (Option B - Clone OpenEnv repo, for development):

```sh
git clone https://github.com/meta-pytorch/OpenEnv.git
cd OpenEnv/envs/browsergym_env
uv pip install -e .
```

# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
```sh
python examples/scripts/openenv/browsergym_llm.py --vllm-mode colocate
```

# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)

# Spin up vLLM server (Terminal 1)
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-0.6B --host 0.0.0.0 --port 8001
```

# Run training (Terminal 2)
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/browsergym_llm.py --vllm-mode server --vllm-server-url http://localhost:8001
```
"""

from __future__ import annotations

import argparse
from datetime import datetime
from pathlib import Path

from browsergym_env import BrowserGymAction, BrowserGymEnv
from datasets import Dataset
from transformers import AutoTokenizer

from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run GRPO training for BrowserGym MiniWoB using OpenEnv environment.")
    parser.add_argument(
        "--model-id",
        default="Qwen/Qwen3-0.6B",
        help="Model identifier passed to GRPOTrainer for fine-tuning.",
    )
    parser.add_argument(
        "--space-url",
        type=str,
        default="https://openenv-browsergym-env.hf.space",
        help="URL for the Hugging Face Space running the BrowserGym environment.",
    )
    parser.add_argument(
        "--benchmark",
        default="miniwob",
        help="BrowserGym benchmark to use (miniwob, webarena, etc.).",
    )
    parser.add_argument(
        "--task-name",
        default="click-test",
        help="Specific task within the benchmark (e.g., click-test, click-button).",
    )
    parser.add_argument(
        "--dataset-prompt",
        default="Complete the web task successfully.",
        help="Prompt text used to seed the training dataset.",
    )
    parser.add_argument(
        "--dataset-size",
        type=int,
        default=1000,
        help="Number of entries to include in the synthetic training dataset.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=10,
        help="Maximum number of steps per episode.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=32,
        help="Maximum number of new tokens to request from vLLM for each action.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature used during rollout generation.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=50,
        help="Top-k sampling parameter forwarded to vLLM.",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Optional top-p sampling parameter forwarded to vLLM.",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=5e-6,
        help="Learning rate for GRPO training.",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0,
        help="Weight decay applied during optimization.",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=32,
        help="Gradient accumulation steps for GRPO training.",
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=10,
        help="Warmup steps for the scheduler.",
    )
    parser.add_argument(
        "--per-device-batch-size",
        type=int,
        default=1,
        help="Per-device train batch size.",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=4,
        help="Number of rollout generations per dataset prompt.",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=1,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=50,
        help="Interval (in steps) between checkpoint saves.",
    )
    parser.add_argument(
        "--save-total-limit",
        type=int,
        default=None,
        help="Maximum number of checkpoints to keep.",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory where training outputs and checkpoints are stored.",
    )
    parser.add_argument(
        "--run-name",
        default=None,
        help="Optional run name for logging systems.",
    )
    parser.add_argument(
        "--project",
        default=None,
        help="Optional project identifier for logging systems.",
    )
    parser.add_argument(
        "--vllm-mode",
        choices=("colocate", "server"),
        default="colocate",
        help="vLLM execution mode: 'colocate' or 'server'.",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8001",
        help="URL for the vLLM server (only used when --vllm-mode=server).",
    )
    parser.add_argument(
        "--logging-steps",
        type=int,
        default=1,
        help="Frequency of logging steps for GRPO training.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Enable verbose debugging output during rollouts.",
    )
    return parser.parse_args()


def sanitize_name(name: str) -> str:
    return name.replace("/", "-")


# ---------------------------------------------------------------------------
# System Prompt
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """You control a web browser through BrowserGym actions.
You must complete the given web task by interacting with the page.

Available actions:
- noop() - Do nothing
- click(bid) - Click element with BrowserGym ID (the number in brackets)
- fill(bid, text) - Fill input field with text
- send_keys(text) - Send keyboard input
- scroll(direction) - Scroll up/down

The page structure shows elements as: [bid] element_type 'element_text'
For example: [13] button 'Click Me!' means bid='13'

Reply with exactly ONE action on a single line, e.g.:
click('13')
fill('42', 'hello world')
noop()

Do not include explanations or multiple actions."""


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_user_prompt(goal: str, step_num: int, axtree: str, error: str = "") -> str:
    """Create user prompt from observation."""
    prompt_parts = [f"Step {step_num + 1}"]

    if goal:
        prompt_parts.append(f"Goal: {goal}")

    if error:
        prompt_parts.append(f"Previous action error: {error}")

    # Include accessibility tree (truncated for context)
    if axtree:
        max_len = 2000
        axtree_truncated = axtree[:max_len] + "..." if len(axtree) > max_len else axtree
        prompt_parts.append(f"Page structure:\n{axtree_truncated}")

    prompt_parts.append("What action do you take?")

    return "\n\n".join(prompt_parts)


def parse_action(response_text: str) -> str:
    """Parse BrowserGym action from model response."""
    # Extract first line that looks like an action
    for line in response_text.strip().split("\n"):
        line = line.strip()
        if "(" in line and ")" in line:
            return line

    # Fallback to noop if no valid action found
    return "noop()"


def rollout_once(
    trainer: GRPOTrainer,
    env: BrowserGymEnv,
    tokenizer: AutoTokenizer,
    dataset_prompt: str,
    max_steps: int,
    debug: bool = False,
) -> dict[str, list]:
    """Run one episode and collect training data (text-only, no screenshots)."""
    result = env.reset()
    observation = result.observation

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    step_rewards: list[float] = []
    completion_rewards: list[float] = []

    for step_num in range(max_steps):
        if result.done:
            break

        # Create prompt from observation (text-only using accessibility tree)
        goal = observation.goal or dataset_prompt
        axtree = observation.axtree_txt or ""
        error = observation.error if observation.last_action_error else ""

        user_prompt = make_user_prompt(goal, step_num, axtree, error)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Generate action with vLLM
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        prompt_ids.extend(rollout_outputs["prompt_ids"])
        completion_ids.extend(rollout_outputs["completion_ids"])
        logprobs.extend(rollout_outputs["logprobs"])

        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True
        )

        # Parse and execute action
        action_str = parse_action(completion_text)

        if debug:
            print(f"Step {step_num + 1}: {action_str}")

        # Take action in environment
        result = env.step(BrowserGymAction(action_str=action_str))
        observation = result.observation

        # Track rewards
        step_reward = float(result.reward or 0.0)
        step_rewards.append(step_reward)

        # Reward shaping: success is most important
        if result.done and step_reward > 0:
            completion_rewards.append(1.0)  # Task completed successfully
        elif result.done and step_reward == 0:
            completion_rewards.append(0.0)  # Task failed
        else:
            completion_rewards.append(step_reward)  # Intermediate reward

    # Final reward is based on task completion
    final_reward = completion_rewards[-1] if completion_rewards else 0.0

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "step_rewards": step_rewards,
        "completion_reward": final_reward,
    }


# ---------------------------------------------------------------------------
# Rewards
# ---------------------------------------------------------------------------


def reward_completion(completions: list[str], **kwargs) -> list[float]:
    """Reward for task completion."""
    rewards = kwargs.get("completion_reward") if kwargs else None
    if rewards is None:
        return [0.0 for _ in completions]
    return [float(r) for r in rewards]


# ---------------------------------------------------------------------------
# Main entrypoint
# ---------------------------------------------------------------------------


def main() -> None:
    args = parse_args()

    # Connect to BrowserGym environment via Hugging Face Space
    client = BrowserGymEnv(base_url=args.space_url)
    print(f"🌍 Using Hugging Face Space environment at: {args.space_url}")

    dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    default_output_dir = Path("outputs") / f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}"
    output_dir = Path(args.output_dir or default_output_dir)

    grpo_config = GRPOConfig(
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
        vllm_gpu_memory_utilization=0.4,
        output_dir=str(output_dir),
        num_train_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_train_batch_size=args.per_device_batch_size,
        warmup_steps=args.warmup_steps,
        num_generations=args.num_generations,
        generation_batch_size=args.num_generations,  # Must be divisible by num_generations
        max_completion_length=args.max_new_tokens,
        logging_steps=args.logging_steps,
        report_to="trackio",
        trackio_space_id=f"browsergym-grpo-{sanitize_name(args.model_id)}-{timestamp}",
        save_strategy="steps",
        save_steps=args.save_interval,
        save_total_limit=args.save_total_limit,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )

    grpo_config.run_name = args.run_name or f"run-{timestamp}"
    grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}"

    def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
        episode_prompt_ids: list[list[int]] = []
        episode_completion_ids: list[list[int]] = []
        episode_logprobs: list[list[float]] = []
        completion_rewards: list[float] = []

        if args.debug:
            print(f"\n[DEBUG] rollout_func called with {len(prompts)} prompts (LLM mode, text-only)")

        for i, prompt_text in enumerate(prompts):
            if args.debug:
                print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}")
            episode = rollout_once(
                trainer=trainer,
                env=client,
                tokenizer=trainer.processing_class,
                dataset_prompt=prompt_text,
                max_steps=args.max_steps,
                debug=args.debug,
            )
            episode_prompt_ids.append(episode["prompt_ids"])
            episode_completion_ids.append(episode["completion_ids"])
            episode_logprobs.append(episode["logprobs"])
            completion_rewards.append(episode["completion_reward"])

        return {
            "prompt_ids": episode_prompt_ids,
            "completion_ids": episode_completion_ids,
            "logprobs": episode_logprobs,
            "completion_reward": completion_rewards,
        }

    trainer = GRPOTrainer(
        model=args.model_id,
        reward_funcs=[reward_completion],
        train_dataset=dataset,
        args=grpo_config,
        rollout_func=rollout_func,
    )

    print("=" * 80)
    print("Starting GRPO training with BrowserGym environment (LLM mode)")
    print(f"Benchmark: {args.benchmark}")
    print(f"Task: {args.task_name}")
    print(f"Model: {args.model_id}")
    print("Mode: LLM (text-only, using accessibility tree)")
    print(f"Using {args.num_generations} rollouts per dataset prompt")
    print(f"Output directory: {output_dir}")
    print("=" * 80)

    try:
        trainer.train()
        print("\nTraining completed successfully!")
    finally:
        client.close()


if __name__ == "__main__":
    main()
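To make the action-parsing contract concrete, here is how the `parse_action` helper defined above behaves on typical model replies (self-contained copy for illustration):

```python
# Mirrors parse_action from the script above: the first line containing "(" and ")" wins.
def parse_action(response_text: str) -> str:
    for line in response_text.strip().split("\n"):
        line = line.strip()
        if "(" in line and ")" in line:
            return line
    return "noop()"  # fallback when no action-like line is found

print(parse_action("I will click the button.\nclick('13')"))  # -> click('13')
print(parse_action("not sure what to do"))                    # -> noop()
```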
ICL/RL/trl_source/examples/scripts/openenv/echo.py ADDED
@@ -0,0 +1,248 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl[vllm]",
#     "peft",
#     "trackio",
#     "kernels",
#     "openenv-echo-env @ git+https://huggingface.co/spaces/openenv/echo_env",
# ]
# ///


"""
Simple script to run GRPO training with OpenEnv's Echo environment and vLLM. The reward function encourages
longer completions.

Setup (Option A - Install from HF Space, recommended):

```sh
uv pip install git+https://huggingface.co/spaces/openenv/echo_env
```

Setup (Option B - Clone OpenEnv repo, for development):

```sh
git clone https://github.com/meta-pytorch/OpenEnv.git
cd OpenEnv/envs/echo_env
uv pip install -e .
```

# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
```sh
python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode colocate
```

# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)

# Spin up vLLM server (Terminal 1)
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000
```

# Run training (Terminal 2)
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode server --vllm-server-url http://localhost:8000
```

# Option 3: Local + Colocated vLLM (1 GPU required)

# Start the environment only if using --env-mode docker-local
```sh
docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest
```

```sh
python examples/scripts/openenv/echo.py --env-mode docker-local --vllm-mode colocate
```
"""

# ruff: noqa: T201
import argparse
import os
import subprocess
import sys
import time
from pathlib import Path

import requests
from datasets import load_dataset
from echo_env import EchoEnv
from echo_env.models import EchoAction

from trl import GRPOConfig, GRPOTrainer, RichProgressCallback
from trl.experimental.openenv import generate_rollout_completions


def parse_args():
    parser = argparse.ArgumentParser(description="Run GRPO training with Echo environment and vLLM.")

    parser.add_argument("--env-host", type=str, default="0.0.0.0", help="Host for the Echo environment.")
    parser.add_argument("--env-port", type=int, default=8001, help="Port for the Echo environment.")
    parser.add_argument(
        "--env-mode",
        choices=["local", "docker-local", "docker-image", "docker-hub", "space"],
        default="docker-image",
        help="Where to run the Echo environment: 'local' to launch it, 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2.5-0.5B-Instruct",
        help="Model to use for training.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="trl-lib/ultrafeedback-prompt",
        help="Dataset to use for training.",
    )
    parser.add_argument(
        "--env-image", type=str, default="echo-env:latest", help="Docker image for the Echo environment."
    )
    parser.add_argument(
        "--vllm-mode",
        choices=["colocate", "server"],
        default="colocate",
        help="vLLM execution mode: 'colocate' or 'server'.",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000",
        help="URL for the vLLM server (only used when --vllm-mode=server).",
    )

    return parser.parse_args()


def start_env_server(env_host: str, env_port: int):
    """Launch the Echo environment server locally."""
    env_url = f"http://{env_host}:{env_port}"
    print(f"⚡ Starting FastAPI server for Echo Environment on {env_url}...")

    work_dir = str(Path.cwd().parent.absolute())
    process = subprocess.Popen(
        [sys.executable, "-m", "uvicorn", "echo_env.server.app:app", "--host", env_host, "--port", str(env_port)],
        env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=work_dir,
    )

    print("⏳ Waiting for server to start...")
    time.sleep(5)

    try:
        requests.get(f"{env_url}/health", timeout=2)
        print("\n✅ Echo Environment server is running!")
    except Exception as e:
        print(f"\n❌ Server failed to start: {e}")
        if process.stderr:
            print(process.stderr.read())
        raise

    return process


def reward_from_env(completions, **kwargs):
    """Extract environment rewards for training."""
    env_rewards = kwargs.get("env_reward", [])
    return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)


def main():
    args = parse_args()

    # Select environment mode
    if args.env_mode == "local":
        env_url = f"http://{args.env_host}:{args.env_port}"
        server_process = start_env_server(args.env_host, args.env_port)
    elif args.env_mode == "docker-local":
        env_url = f"http://{args.env_host}:{args.env_port}"
        server_process = None
        print(f"🌍 Using existing Echo Environment (Docker) at: {env_url}")
    elif args.env_mode == "docker-image":
        client = EchoEnv.from_docker_image(args.env_image)
        server_process = None
        print("🌍 Using Echo Environment (Docker) from local Image")
    elif args.env_mode == "docker-hub":
        client = EchoEnv.from_hub(args.env_image)
        server_process = None
        print("🌍 Using existing Echo Environment (Docker) from Hub Image")
    elif args.env_mode == "space":
        env_url = args.env_host
        server_process = None
        print(f"🌍 Using Hugging Face Space environment at: {env_url}")
    else:
        raise ValueError(f"Unknown environment mode: {args.env_mode}")

    if args.env_mode != "docker-hub" and args.env_mode != "docker-image":
        client = EchoEnv(base_url=env_url)
    dataset = load_dataset(args.dataset, split="train[:1000]")

    training_args = GRPOConfig(
        output_dir=f"{args.model.split('/')[-1]}-GRPO-Rollout",
        use_vllm=True,
        vllm_mode=args.vllm_mode,
        vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
        logging_steps=1,
        report_to="trackio",
        trackio_space_id=f"{args.model.split('/')[-1]}-GRPO-Rollout",
        num_train_epochs=1,
        max_completion_length=2048,
        gradient_accumulation_steps=4,
    )

    def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
        outputs = generate_rollout_completions(trainer, prompts)
        tokenizer = trainer.processing_class

        completions_text = [tokenizer.decode(output["completion_ids"], skip_special_tokens=True) for output in outputs]

        env_result = client.reset()
        env_rewards: list[float] = []
        for message in completions_text:
            env_result = client.step(EchoAction(message=message))
            env_rewards.append(env_result.reward)

        return {
            "prompt_ids": [output["prompt_ids"] for output in outputs],
            "completion_ids": [output["completion_ids"] for output in outputs],
            "logprobs": [output["logprobs"] for output in outputs],
            "env_reward": env_rewards,
        }

    trainer = GRPOTrainer(
        model=args.model,
        reward_funcs=reward_from_env,
        args=training_args,
        train_dataset=dataset,
        rollout_func=rollout_func,
        callbacks=[RichProgressCallback()],
    )

    trainer.train()
    time.sleep(5)

    if server_process:
        print("🛑 Terminating Echo Environment server...")
        server_process.terminate()


if __name__ == "__main__":
    main()
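Both OpenEnv scripts rely on the same rollout contract: `rollout_func` returns per-prompt token id lists plus any extra columns (here `env_reward`), which GRPOTrainer forwards to reward functions through `**kwargs`. A schematic of the return value, with made-up token ids, assuming only the keys used above:

```python
# Schematic rollout_func result (values are illustrative, keys match the scripts above).
rollout_result = {
    "prompt_ids": [[101, 102], [101, 103]],          # token ids of each prompt
    "completion_ids": [[7, 8, 9], [7, 10]],          # token ids of each completion
    "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5]],  # per-token logprobs
    "env_reward": [1.0, 0.0],                        # extra column -> reward funcs via **kwargs
}
```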
ICL/RL/trl_source/examples/scripts/openenv/wordle.py ADDED
@@ -0,0 +1,607 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "trl[vllm]",
#     "peft",
#     "trackio",
#     "kernels",
#     "openenv-textarena @ git+https://huggingface.co/spaces/sergiopaniego/wordle",
# ]
# ///


"""
Simple script to run GRPO training with OpenEnv's Wordle environment and vLLM.

Setup (Option A - Install from HF Space, recommended):

```sh
uv pip install git+https://huggingface.co/spaces/sergiopaniego/wordle
```

# Option 1: HF Spaces + Colocated vLLM (1 GPU required)
```sh
python examples/scripts/openenv/wordle.py --vllm-mode colocate
```

# Option 2: HF Spaces + Separate vLLM server (2 GPUs required)

# Spin up vLLM server (Terminal 1)
```sh
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000
```

# Run training (Terminal 2)
```sh
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000
```

# Option 3: Local Environment + Colocated vLLM (1 GPU required)

To run the Wordle environment locally, you have several options:

## Option 3a: Using Docker Image (Recommended)

First, build the Docker image from the textarena_env directory:
```sh
cd 3rd_party/OpenEnv/envs/textarena_env
docker build -t textarena-env:latest -f server/Dockerfile .
```

Then run the environment server:
```sh
docker run -d -p 8001:8001 textarena-env:latest
```

Finally, run training pointing to local server:
```sh
python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001
```

## Option 3b: Running Server Directly

From the textarena_env directory:
```sh
cd 3rd_party/OpenEnv/envs/textarena_env
uv venv && source .venv/bin/activate
uv pip install -e .
python -m uvicorn server.app:app --reload --port 8001
```

Then in another terminal, run training:
```sh
python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001
```

## Option 3c: Using Pre-built HF Space Image

```sh
docker run -d -p 8001:8001 registry.hf.space/burtenshaw-wordle:latest
python examples/scripts/openenv/wordle.py --vllm-mode colocate --env-url http://localhost:8001
```
"""

import argparse
import re
import sys
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path

from datasets import Dataset
from transformers import AutoTokenizer

from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions


# Ensure src/ is on the path
sys.path.insert(0, str(Path(__file__).parent / "src"))

from textarena_env import TextArenaAction, TextArenaEnv
from textarena_env.models import TextArenaMessage
from textarena_env.rewards import extract_feedback_counts, extract_guess, extract_wordle_feedback


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run GRPO training for Wordle using the TextArena OpenEnv environment."
    )
    parser.add_argument(
        "--tokenizer-id",
        default="Qwen/Qwen3-1.7B",
        help="Model identifier used to load the tokenizer.",
    )
    parser.add_argument(
        "--model-id",
        default="Qwen/Qwen3-1.7B",
        help="Model identifier passed to GRPOTrainer for fine-tuning.",
    )
    parser.add_argument(
        "--env-url", type=str, default="https://sergiopaniego-wordle.hf.space", help="URL for the environment server."
    )
    parser.add_argument(
        "--system-prompt-path",
        default="wordle_prompt.txt",
        help="Path to the file containing the system prompt.",
    )
    parser.add_argument(
        "--dataset-prompt",
        default="Play Wordle like an expert.",
        help="Prompt text used to seed the training dataset.",
    )
    parser.add_argument(
        "--dataset-size",
        type=int,
        default=3000,
        help="Number of entries to include in the synthetic training dataset.",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=6,
        help="Maximum number of turns to play in the Wordle environment per episode.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8,
        help="Maximum number of new tokens to request from vLLM for each guess.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature used during rollout generation.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=10,
        help="Top-k sampling parameter forwarded to vLLM.",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Optional top-p sampling parameter forwarded to vLLM.",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-6,
        help="Learning rate for GRPO training.",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0,
        help="Weight decay applied during optimization.",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=64,
        help="Gradient accumulation steps for GRPO training.",
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=10,
        help="Warmup steps for the scheduler.",
    )
    parser.add_argument(
        "--per-device-batch-size",
        type=int,
        default=1,
        help="Per-device train batch size.",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=4,
        help="Number of rollout generations per dataset prompt.",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
221
+ default=1,
222
+ help="Number of training epochs.",
223
+ )
224
+ parser.add_argument(
225
+ "--save-interval",
226
+ type=int,
227
+ default=10,
228
+ help="Interval (in steps) between checkpoint saves.",
229
+ )
230
+ parser.add_argument(
231
+ "--save-total-limit",
232
+ type=int,
233
+ default=None,
234
+ help="Maximum number of checkpoints to keep.",
235
+ )
236
+ parser.add_argument(
237
+ "--output-dir",
238
+ default=None,
239
+ help="Directory where training outputs and checkpoints are stored.",
240
+ )
241
+ parser.add_argument(
242
+ "--run-name",
243
+ default=None,
244
+ help="Optional run name for logging systems.",
245
+ )
246
+ parser.add_argument(
247
+ "--project",
248
+ default=None,
249
+ help="Optional project identifier for logging systems.",
250
+ )
251
+ parser.add_argument(
252
+ "--trackio-space-id",
253
+ default="Wordle-GRPO",
254
+ help="TrackIO space identifier.",
255
+ )
256
+ parser.add_argument(
257
+ "--vllm-mode",
258
+ choices=("colocate", "server"),
259
+ default="colocate",
260
+ help="vLLM execution mode: 'colocate' or 'server'.",
261
+ )
262
+ parser.add_argument(
263
+ "--vllm-server-url",
264
+ type=str,
265
+ default="http://localhost:8000",
266
+ help="URL for the vLLM server (only used when --vllm-mode=server).",
267
+ )
268
+ parser.add_argument(
269
+ "--logging-steps",
270
+ type=int,
271
+ default=1,
272
+ help="Frequency of logging steps for GRPO training.",
273
+ )
274
+ return parser.parse_args()
275
+
276
+
277
+ def resolve_system_prompt(path: str) -> str:
278
+ prompt_path = Path(path)
279
+ if not prompt_path.is_file():
280
+ prompt_path = Path(__file__).parent / path
281
+ return prompt_path.read_text()
282
+
283
+
284
+ def sanitize_name(name: str) -> str:
285
+ return name.replace("/", "-")
286
+
287
+
288
+ # ---------------------------------------------------------------------------
289
+ # Helpers
290
+ # ---------------------------------------------------------------------------
291
+
292
+
293
+ def format_history(messages: Iterable[TextArenaMessage]) -> str:
294
+ lines: list[str] = []
295
+ for message in messages:
296
+ tag = message.category or "MESSAGE"
297
+ content = message.content.strip()
298
+ if not content:
299
+ continue
300
+ lines.append(f"[{tag}] {content}")
301
+ return "\n".join(lines)
302
+
303
+
304
+ def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> str:
305
+ history = format_history(messages)
306
+ # Only use messages for conversation history - the prompt is already included as the first message
307
+ history_section = history if history else "[PROMPT] Awaiting first feedback."
308
+ return f"Conversation so far:\n{history_section}\n\nReply with your next guess enclosed in square brackets."
309
+
310
+
311
+ def rollout_once(
312
+ trainer: GRPOTrainer,
313
+ env: TextArenaEnv,
314
+ tokenizer: AutoTokenizer,
315
+ dataset_prompt: str,
316
+ system_prompt: str,
317
+ max_turns: int,
318
+ max_new_tokens: int = 16,
319
+ ) -> dict[str, list]:
320
+ result = env.reset()
321
+ observation = result.observation
322
+
323
+ prompt_ids: list[int] = []
324
+ completion_ids: list[int] = []
325
+ logprobs: list[float] = []
326
+ env_mask: list[int] = [] # 1 for model-generated tokens, 0 for environment tokens
327
+ model_outputs: list[str] = []
328
+ raw_rewards: list[float] = []
329
+ position_scores: list[float] = []
330
+ correct_scores: list[float] = []
331
+ prev_env_output_len: int = 0 # Track length to only add NEW portion each turn
332
+
333
+ accumulated_messages: list[dict[str, str]] = [{"role": "system", "content": system_prompt}]
334
+ # Build initial prompt (only once, at the start)
335
+ # The initial env messages are included in the prompt, not completion
336
+ base_prompt = observation.prompt or dataset_prompt
337
+ initial_user_prompt = make_user_prompt(base_prompt, observation.messages)
338
+ # Track initial env output length so we don't add it again
339
+ initial_env_output = format_history(observation.messages) if observation.messages else ""
340
+ prev_env_output_len = len(initial_env_output)
341
+ initial_messages = accumulated_messages + [{"role": "user", "content": initial_user_prompt}]
342
+ initial_prompt_text = tokenizer.apply_chat_template(
343
+ initial_messages,
344
+ add_generation_prompt=True,
345
+ tokenize=False,
346
+ enable_thinking=False,
347
+ )
348
+ # Tokenize initial prompt once - this is the base prompt for the entire episode.
349
+ # GRPO expects one prompt-completion pair per episode, where:
350
+ # - prompt_ids = the initial/base prompt (what the model sees at episode start)
351
+ # - completion_ids = all model responses + env feedback from all turns concatenated
352
+ # Note: The actual prompts used for generation in each turn are longer (include conversation history),
353
+ # but we only count the initial prompt tokens here.
354
+ initial_prompt_ids = tokenizer.encode(initial_prompt_text, add_special_tokens=False)
355
+ prompt_ids.extend(initial_prompt_ids)
356
+
357
+ for _turn in range(max_turns):
358
+ if result.done:
359
+ break
360
+
361
+ base_prompt = observation.prompt or dataset_prompt
362
+ user_prompt = make_user_prompt(base_prompt, observation.messages)
363
+ messages = accumulated_messages + [{"role": "user", "content": user_prompt}]
364
+ prompt_text = tokenizer.apply_chat_template(
365
+ messages,
366
+ add_generation_prompt=True,
367
+ tokenize=False,
368
+ enable_thinking=False,
369
+ )
370
+
371
+ rollout_outputs = generate_rollout_completions(
372
+ trainer, [prompt_text], generation_overrides={"max_tokens": max_new_tokens}
373
+ )[0]
374
+ # Add model-generated completion tokens and logprobs with newlines for readability
375
+ newline_tokens = tokenizer.encode("\n", add_special_tokens=False)
376
+ completion_ids.extend(newline_tokens) # newline before guess
377
+ logprobs.extend([0.0] * len(newline_tokens))
378
+ env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format
379
+
380
+ completion_ids.extend(rollout_outputs["completion_ids"])
381
+ logprobs.extend(rollout_outputs["logprobs"])
382
+ env_mask.extend([1] * len(rollout_outputs["completion_ids"])) # model-generated tokens
383
+
384
+ completion_ids.extend(newline_tokens) # newline after guess
385
+ logprobs.extend([0.0] * len(newline_tokens))
386
+ env_mask.extend([1] * len(newline_tokens)) # newlines are part of model output format
387
+ completion_text = rollout_outputs.get("text") or tokenizer.decode(
388
+ rollout_outputs["completion_ids"], skip_special_tokens=True
389
+ )
390
+ guess = extract_guess(completion_text)
391
+ model_outputs.append(completion_text.strip()) # Store raw model output for format reward
392
+
393
+ result = env.step(TextArenaAction(message=guess))
394
+
395
+ raw_rewards.append(float(result.reward or 0.0))
396
+ observation = result.observation
397
+ correct_score = float(result.reward or 0.0)
398
+ feedback = extract_wordle_feedback(observation)
399
+
400
+ full_env_output = format_history(observation.messages) if observation.messages else ""
401
+ new_env_output = full_env_output[prev_env_output_len:].lstrip("\n")
402
+ prev_env_output_len = len(full_env_output)
403
+
404
+ if new_env_output:
405
+ env_output_tokens = tokenizer.encode(new_env_output, add_special_tokens=False)
406
+ completion_ids.extend(env_output_tokens) # Add to completion_ids
407
+ logprobs.extend([0.0] * len(env_output_tokens)) # Placeholder (ignored via env_mask=0)
408
+ env_mask.extend([0] * len(env_output_tokens)) # Environment tokens - mask out from loss
409
+ completion_with_env = completion_text + "\n" + new_env_output
410
+ else:
411
+ completion_with_env = completion_text
412
+
413
+ accumulated_messages.append({"role": "user", "content": user_prompt})
414
+ accumulated_messages.append({"role": "assistant", "content": completion_with_env})
415
+
416
+ if not feedback:
417
+ position_score = 0.0
418
+ else:
419
+ green_count, yellow_count = extract_feedback_counts(feedback)
420
+ position_score = (green_count + 0.5 * yellow_count) / 5.0
421
+
422
+ position_scores.append(position_score)
423
+ correct_scores.append(correct_score)
424
+
425
+ # Use the final correct reward (win/lose is binary at end)
426
+ correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)
427
+
428
+ # Position reward as shaping signal:
429
+ # - If model WINS: position_reward = 1.0 (no penalty for winning fast)
430
+ # - If model LOSES: position_reward = last attempt (where it ended up)
431
+ if correct_reward_value >= 1.0:
432
+ final_position_reward = 1.0
433
+ else:
434
+ final_position_reward = position_scores[-1] if position_scores else 0.0
435
+
436
+ return {
437
+ "prompt_ids": prompt_ids,
438
+ "completion_ids": completion_ids,
439
+ "logprobs": logprobs,
440
+ "env_mask": env_mask,
441
+ "raw_rewards": raw_rewards,
442
+ "correct_reward": correct_reward_value,
443
+ "position_reward": final_position_reward,
444
+ "model_outputs": model_outputs,
445
+ }
446
+
447
+
448
+ # ---------------------------------------------------------------------------
449
+ # Rewards
450
+ # ---------------------------------------------------------------------------
451
+
452
+
453
+ def reward_correct(completions: list[str], **kwargs) -> list[float]:
454
+ """Reward from environment (correct answer)."""
455
+ rewards = kwargs.get("correct_reward") if kwargs else None
456
+ if rewards is None:
457
+ return [0.0 for _ in completions]
458
+ return [float(r) for r in rewards]
459
+
460
+
461
+ def reward_position(completions: list[str], **kwargs) -> list[float]:
462
+ """Position reward: green worth 1.0, yellow worth 0.5, normalized by 5."""
463
+ rewards = kwargs.get("position_reward") if kwargs else None
464
+ if rewards is None:
465
+ return [0.0 for _ in completions]
466
+ return [float(r) for r in rewards]
467
+
468
+
469
+ def compute_format_reward(model_outputs: list[str]) -> float:
470
+ """Compute format reward from a list of model outputs (one per turn).
471
+
472
+ Each output should be exactly [5 letters] with optional whitespace.
473
+ Returns proportion of correctly formatted outputs.
474
+ """
475
+ if not model_outputs:
476
+ return 0.0
477
+
478
+ exact_pattern = re.compile(r"^\s*\[[A-Za-z]{5}\]\s*$")
479
+ correct_count = sum(1 for output in model_outputs if exact_pattern.match(output))
480
+
481
+ return correct_count / len(model_outputs)
482
+
483
+
484
+ def reward_format_strict(completions: list[str], **kwargs) -> list[float]:
485
+ """Format reward - pre-computed in rollout_func."""
486
+ rewards = kwargs.get("format_reward") if kwargs else None
487
+ if rewards is None:
488
+ return [0.0 for _ in completions]
489
+ return [float(r) for r in rewards]
490
+
491
+
492
+ # ---------------------------------------------------------------------------
493
+ # Main entrypoint
494
+ # ---------------------------------------------------------------------------
495
+
496
+
497
+ def main() -> None:
498
+ args = parse_args()
499
+
500
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
501
+ tokenizer.pad_token = tokenizer.eos_token
502
+
503
+ client = TextArenaEnv(base_url=args.env_url)
504
+
505
+ system_prompt = resolve_system_prompt(args.system_prompt_path)
506
+
507
+ dataset = Dataset.from_dict({"prompt": [args.dataset_prompt] * args.dataset_size})
508
+
509
+ timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
510
+ default_output_dir = Path("outputs") / f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
511
+ output_dir = Path(args.output_dir or default_output_dir)
512
+
513
+ grpo_config = GRPOConfig(
514
+ use_vllm=True,
515
+ vllm_mode=args.vllm_mode,
516
+ vllm_server_base_url=args.vllm_server_url if args.vllm_mode == "server" else None,
517
+ output_dir=str(output_dir),
518
+ num_train_epochs=args.num_epochs,
519
+ learning_rate=args.learning_rate,
520
+ weight_decay=args.weight_decay,
521
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
522
+ per_device_train_batch_size=args.per_device_batch_size,
523
+ warmup_steps=args.warmup_steps,
524
+ num_generations=args.num_generations,
525
+ max_completion_length=1024, # Full episode length, not per-turn
526
+ logging_steps=args.logging_steps,
527
+ log_completions=True,
528
+ report_to="trackio",
529
+ trackio_space_id=f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}",
530
+ save_strategy="steps",
531
+ save_steps=args.save_interval,
532
+ save_total_limit=args.save_total_limit,
533
+ temperature=args.temperature,
534
+ top_k=args.top_k,
535
+ top_p=args.top_p,
536
+ vllm_gpu_memory_utilization=0.25,
537
+ vllm_max_model_length=8192,
538
+ vllm_importance_sampling_mode="token_truncate", # Less aggressive than default sequence_mask
539
+ optim="adamw_torch",
540
+ max_grad_norm=1.0, # Clip gradients to prevent explosion
541
+ )
542
+
543
+ grpo_config.run_name = args.run_name or f"run-{timestamp}"
544
+ grpo_config.project = args.project or f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}"
545
+ grpo_config.trackio_space_id = args.trackio_space_id
546
+
547
+ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
548
+ episode_prompt_ids: list[list[int]] = []
549
+ episode_completion_ids: list[list[int]] = []
550
+ episode_logprobs: list[list[float]] = []
551
+ episode_env_masks: list[list[int]] = []
552
+ correctness_rewards: list[float] = []
553
+ position_rewards: list[float] = []
554
+ format_rewards: list[float] = []
555
+
556
+ for prompt_text in prompts:
557
+ episode = rollout_once(
558
+ trainer=trainer,
559
+ env=client,
560
+ tokenizer=tokenizer,
561
+ dataset_prompt=prompt_text,
562
+ system_prompt=system_prompt,
563
+ max_turns=args.max_turns,
564
+ max_new_tokens=args.max_new_tokens,
565
+ )
566
+ episode_prompt_ids.append(episode["prompt_ids"])
567
+ episode_completion_ids.append(episode["completion_ids"])
568
+ episode_logprobs.append(episode["logprobs"])
569
+ episode_env_masks.append(episode["env_mask"])
570
+ correctness_rewards.append(episode["correct_reward"])
571
+ position_rewards.append(episode["position_reward"])
572
+ format_rewards.append(compute_format_reward(episode["model_outputs"]))
573
+
574
+ return {
575
+ "prompt_ids": episode_prompt_ids,
576
+ "completion_ids": episode_completion_ids,
577
+ "logprobs": episode_logprobs,
578
+ "env_mask": episode_env_masks,
579
+ "correct_reward": correctness_rewards,
580
+ "position_reward": position_rewards,
581
+ "format_reward": format_rewards,
582
+ }
583
+
584
+ trainer = GRPOTrainer(
585
+ model=args.model_id,
586
+ processing_class=tokenizer,
587
+ reward_funcs=[
588
+ reward_correct,
589
+ reward_position,
590
+ reward_format_strict,
591
+ ],
592
+ train_dataset=dataset,
593
+ args=grpo_config,
594
+ rollout_func=rollout_func,
595
+ )
596
+
597
+ print("Starting GRPO training with Wordle environment...")
598
+ print(f"Using {args.num_generations} rollouts per dataset prompt")
599
+
600
+ try:
601
+ trainer.train()
602
+ finally:
603
+ client.close()
604
+
605
+
606
+ if __name__ == "__main__":
607
+ main()
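A minimal sketch (with invented token ids, not output from a real tokenizer) of the episode-level bookkeeping that `rollout_once` produces: model-generated tokens get `env_mask=1` and real logprobs, while environment feedback tokens are appended with `env_mask=0` and placeholder logprobs so they are excluded from the GRPO loss.

```python
# Illustrative only: these token ids are made up.
completion_ids = [11, 12, 13, 901, 902, 903, 14, 15]          # guess, env feedback, next guess
logprobs = [-0.3, -0.9, -0.1, 0.0, 0.0, 0.0, -0.5, -0.2]      # 0.0 placeholders for env tokens
env_mask = [1, 1, 1, 0, 0, 0, 1, 1]                           # 0 = environment token, masked out

assert len(completion_ids) == len(logprobs) == len(env_mask)
trainable = [t for t, m in zip(completion_ids, env_mask) if m == 1]
print(trainable)  # [11, 12, 13, 14, 15] -- only these positions contribute to the loss
```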
ICL/RL/trl_source/examples/scripts/openenv/wordle_prompt.txt ADDED
@@ -0,0 +1,105 @@
1
+ You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies.
2
+
3
+ ## GAME RULES
4
+
5
+ 1. The target is a 5-letter English word
6
+ 2. You have 6 attempts to guess the correct word
7
+ 3. After each guess, you receive color-coded feedback:
8
+ - GREEN: Letter is correct and in the correct position
9
+ - YELLOW: Letter is in the word but in the wrong position
10
+ - GRAY: Letter is not in the word at all
11
+ 4. All guesses must be valid 5-letter English words
12
+ 5. You cannot reuse a word you've already guessed
13
+
14
+ ## RESPONSE FORMAT
15
+
16
+ Only respond with your next guess in square brackets, e.g., [crane].
17
+
18
+ Format:
19
+ ```
20
+ [guess]
21
+ ```
22
+
23
+
24
+ ## STRATEGIC APPROACH
25
+
26
+ Do not repeat the same guess twice.
27
+
28
+ ### Opening Strategy
29
+ - Start with words rich in common vowels (A, E, I, O, U) and consonants (R, S, T, L, N)
30
+ - Optimal starters: CRANE, SLATE, STARE, AROSE, IRATE
31
+ - Prioritize words that test the most common letters in different positions
32
+
33
+ ### Mid-Game Strategy
34
+ - Use confirmed GREEN letters in their correct positions
35
+ - Place YELLOW letters in different positions than where they appeared
36
+ - Eliminate GRAY letters entirely from consideration
37
+ - If multiple letters are unknown, prioritize common letter combinations (TH, CH, ST, ER, etc.)
38
+ - Consider letter frequency: E is most common, followed by A, R, I, O, T, N, S
39
+
40
+ ### Vowel Placement
41
+ - Most 5-letter words have 2 vowels
42
+ - Common patterns: consonant-consonant-vowel-consonant-vowel (like CRANE) or consonant-vowel-vowel-consonant-vowel (like HOUSE)
43
+ - If you have 1-2 vowels confirmed, consider where the others might be
44
+
45
+ ### Advanced Tactics
46
+ - Use "sacrificial" guesses to test multiple new letters if you have attempts to spare
47
+ - Avoid repeating letter patterns unless you're certain (e.g., SPEED has two E's)
48
+ - Think about word endings: -ER, -LY, -ED, -ING are common but may not fit the 5-letter constraint
49
+ - Consider less common letters (Q, X, Z, J) only when you've eliminated most common options
50
+
51
+ ### Common Pitfalls to Avoid
52
+ - Don't reuse eliminated GRAY letters
53
+ - Don't place YELLOW letters in the same position where they already appeared
54
+ - Don't ignore confirmed GREEN letters
55
+ - Don't guess words that contradict known information
56
+
57
+ ## EXAMPLES
58
+
59
+ ### Example 1: Opening Guess
60
+ "Starting with a word that tests common vowels and consonants in varied positions."
61
+ [crane]
62
+
63
+ ### Example 2: After Receiving Feedback
64
+ Previous guess: CRANE
65
+ Feedback: C=gray, R=yellow, A=green, N=gray, E=yellow
66
+
67
+ "A is confirmed in position 2. R and E are in the word but need different positions. C and N are eliminated. I'll try a word with A in position 2, and test R and E in new positions along with common letters like S and T."
68
+ [spare]
69
+
70
+ ### Example 3: Narrowing Down
71
+ Previous guess: SPARE (S=gray, P=gray, A=green, R=green, E=green)
72
+ Feedback summary: __ARE with A in position 3, R in position 4, E in position 5
73
+
74
+ "I have __ARE confirmed. Positions 1 and 2 are unknown, and S and P are eliminated. Common letters to try: F, L, G, B, N. Testing with FLARE."
75
+ [flare]
76
+
77
+ ### Example 4: Final Deduction
78
+ Previous feedback shows: _ARED with position 1 unknown and all common consonants tested
79
+
80
+ "Only position 1 remains. I've eliminated S, P, C, N. Common starting consonants left are B, F, G, H. BARED is a common word."
81
+ [bared]
82
+
83
+ ## LETTER FREQUENCY REFERENCE
84
+
85
+ Most common letters in 5-letter words (in order):
86
+ S, E, A, O, R, I, L, T, N, U, D, Y, C, P, M, H, G, B, K, F
87
+
88
+ Most common starting letters:
89
+ S, C, B, T, P, A, F, G, D, M
90
+
91
+ Most common ending letters:
92
+ E, Y, T, S, R, L, N, D
93
+
94
+ ## IMPORTANT CONSTRAINTS
95
+
96
+ - Use lowercase only
97
+ - One guess per response
98
+ - Must be exactly 5 letters
99
+ - Must be a real English word from standard dictionaries
100
+ - Never repeat a previous guess
101
+ - Always end your response with your guess in square brackets
102
+
103
+ ## YOUR GOAL
104
+
105
+ Solve the Wordle in as few guesses as possible by strategically using feedback to eliminate impossible words and narrow down the solution space efficiently.
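The bracketed-guess protocol described above is what the training script's strict format reward checks. A standalone sketch of that check (the regex mirrors `compute_format_reward` in wordle.py):

```python
import re

# Accepts exactly one bracketed 5-letter guess per turn, nothing else.
exact_pattern = re.compile(r"^\s*\[[A-Za-z]{5}\]\s*$")

outputs = ["[crane]", "  [SLATE] ", "my guess: [crane]", "[cat]"]
matches = [bool(exact_pattern.match(o)) for o in outputs]
print(matches)                      # [True, True, False, False]
print(sum(matches) / len(outputs))  # 0.5 -> per-episode format reward
```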
ICL/RL/trl_source/examples/scripts/ppo/ppo.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "peft",
19
+ # "trackio",
20
+ # "kernels",
21
+ # ]
22
+ # ///
23
+
24
+ import os
25
+ import shutil
26
+
27
+ import torch
28
+ from accelerate import PartialState
29
+ from datasets import load_dataset
30
+ from transformers import (
31
+ AutoModelForCausalLM,
32
+ AutoModelForSequenceClassification,
33
+ AutoTokenizer,
34
+ HfArgumentParser,
35
+ )
36
+
37
+ from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
38
+ from trl.experimental.ppo import PPOConfig, PPOTrainer
39
+
40
+
41
+ # Enable logging in a Hugging Face Space
42
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
43
+
44
+
45
+ """
46
+ python -i examples/scripts/ppo/ppo.py \
47
+ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
48
+ --dataset_train_split descriptiveness \
49
+ --learning_rate 3e-6 \
50
+ --output_dir pythia-1b-deduped-descriptiveness-sentiment-trl-style-ppo \
51
+ --per_device_train_batch_size 64 \
52
+ --gradient_accumulation_steps 1 \
53
+ --total_episodes 10000 \
54
+ --model_name_or_path EleutherAI/pythia-1b-deduped \
55
+ --missing_eos_penalty 1.0
56
+
57
+ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
58
+ examples/scripts/ppo/ppo.py \
59
+ --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
60
+ --dataset_train_split descriptiveness \
61
+ --output_dir pythia-1b-deduped-descriptiveness-sentiment-trl-style-ppo \
62
+ --num_ppo_epochs 1 \
63
+ --num_mini_batches 1 \
64
+ --learning_rate 3e-6 \
65
+ --per_device_train_batch_size 1 \
66
+ --gradient_accumulation_steps 16 \
67
+ --total_episodes 10000 \
68
+ --model_name_or_path EleutherAI/pythia-1b-deduped \
69
+ --sft_model_path EleutherAI/pythia-1b-deduped \
70
+ --reward_model_path EleutherAI/pythia-1b-deduped \
71
+ --local_rollout_forward_batch_size 1 \
72
+ --missing_eos_penalty 1.0
73
+ """
74
+
75
+
76
+ if __name__ == "__main__":
77
+ parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
78
+ script_args, training_args, model_args = parser.parse_args_into_dataclasses()
79
+ # remove output_dir if exists
80
+ shutil.rmtree(training_args.output_dir, ignore_errors=True)
81
+
82
+ ################
83
+ # Model & Tokenizer
84
+ ################
85
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
86
+ model_kwargs = dict(
87
+ revision=model_args.model_revision,
88
+ attn_implementation=model_args.attn_implementation,
89
+ dtype=dtype,
90
+ )
91
+ quantization_config = get_quantization_config(model_args)
92
+ if quantization_config is not None:
93
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
94
+ model_kwargs["device_map"] = get_kbit_device_map()
95
+ model_kwargs["quantization_config"] = quantization_config
96
+
97
+ tokenizer = AutoTokenizer.from_pretrained(
98
+ model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
99
+ )
100
+ tokenizer.add_special_tokens({"pad_token": "[PAD]"})
101
+ value_model = AutoModelForSequenceClassification.from_pretrained(
102
+ training_args.reward_model_path,
103
+ trust_remote_code=model_args.trust_remote_code,
104
+ num_labels=1,
105
+ **model_kwargs,
106
+ )
107
+ reward_model = AutoModelForSequenceClassification.from_pretrained(
108
+ training_args.reward_model_path,
109
+ trust_remote_code=model_args.trust_remote_code,
110
+ num_labels=1,
111
+ **model_kwargs,
112
+ )
113
+ policy = AutoModelForCausalLM.from_pretrained(
114
+ training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
115
+ )
116
+
117
+ peft_config = get_peft_config(model_args)
118
+ if peft_config is None:
119
+ ref_policy = AutoModelForCausalLM.from_pretrained(
120
+ training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
121
+ )
122
+ else:
123
+ ref_policy = None
124
+
125
+ ################
126
+ # Dataset
127
+ ################
128
+ dataset = load_dataset(
129
+ script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split
130
+ )
131
+ eval_samples = 100
132
+ train_dataset = dataset.select(range(len(dataset) - eval_samples))
133
+ eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
134
+ dataset_text_field = "prompt"
135
+
136
+ def prepare_dataset(dataset, tokenizer):
137
+ """pre-tokenize the dataset before training; only collate during training"""
138
+
139
+ def tokenize(element):
140
+ outputs = tokenizer(
141
+ element[dataset_text_field],
142
+ padding=False,
143
+ )
144
+ return {"input_ids": outputs["input_ids"]}
145
+
146
+ return dataset.map(
147
+ tokenize,
148
+ batched=True,
149
+ remove_columns=dataset.column_names,
150
+ num_proc=training_args.dataset_num_proc,
151
+ )
152
+
153
+ # Compute that only on the main process for faster data processing.
154
+ # see: https://github.com/huggingface/trl/pull/1255
155
+ with PartialState().local_main_process_first():
156
+ train_dataset = prepare_dataset(train_dataset, tokenizer)
157
+ eval_dataset = prepare_dataset(eval_dataset, tokenizer)
158
+
159
+ ################
160
+ # Training
161
+ ################
162
+ trainer = PPOTrainer(
163
+ args=training_args,
164
+ processing_class=tokenizer,
165
+ model=policy,
166
+ ref_model=ref_policy,
167
+ reward_model=reward_model,
168
+ value_model=value_model,
169
+ train_dataset=train_dataset,
170
+ eval_dataset=eval_dataset,
171
+ peft_config=peft_config,
172
+ )
173
+ trainer.train()
174
+
175
+ # Save and push to hub
176
+ trainer.save_model(training_args.output_dir)
177
+ if training_args.push_to_hub:
178
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
179
+
180
+ trainer.generate_completions()
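A standalone sketch of the pre-tokenization contract above: `PPOTrainer` consumes only `input_ids`, produced once via `Dataset.map`. The tiny `EleutherAI/pythia-14m` checkpoint is just an illustrative stand-in here.

```python
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m", padding_side="left")
ds = Dataset.from_dict({"prompt": ["The movie was", "I felt that the plot"]})

def tokenize(element):
    # Mirrors prepare_dataset above: no padding here; collation pads later.
    return {"input_ids": tokenizer(element["prompt"], padding=False)["input_ids"]}

ds = ds.map(tokenize, batched=True, remove_columns=ds.column_names)
print(ds[0])  # {'input_ids': [...]}
```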
ICL/RL/trl_source/examples/scripts/reward_modeling.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "trackio",
19
+ # "kernels",
20
+ # ]
21
+ # ///
22
+
23
+ """
24
+ Full training:
25
+ python examples/scripts/reward_modeling.py \
26
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
27
+ --dataset_name trl-lib/ultrafeedback_binarized \
28
+ --output_dir Qwen2-0.5B-Reward \
29
+ --per_device_train_batch_size 8 \
30
+ --num_train_epochs 1 \
31
+ --learning_rate 1.0e-5 \
32
+ --eval_strategy steps \
33
+ --eval_steps 50 \
34
+ --max_length 2048
35
+
36
+ LoRA:
37
+ python examples/scripts/reward_modeling.py \
38
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
39
+ --dataset_name trl-lib/ultrafeedback_binarized \
40
+ --output_dir Qwen2-0.5B-Reward-LoRA \
41
+ --per_device_train_batch_size 8 \
42
+ --num_train_epochs 1 \
43
+ --learning_rate 1.0e-4 \
44
+ --eval_strategy steps \
45
+ --eval_steps 50 \
46
+ --max_length 2048 \
47
+ --use_peft \
48
+ --lora_task_type SEQ_CLS \
49
+ --lora_r 32 \
50
+ --lora_alpha 16
51
+ """
52
+
53
+ import os
54
+
55
+ import torch
56
+ from accelerate import logging
57
+ from datasets import load_dataset
58
+ from transformers import AutoModelForSequenceClassification, HfArgumentParser
59
+
60
+ from trl import (
61
+ ModelConfig,
62
+ RewardConfig,
63
+ RewardTrainer,
64
+ ScriptArguments,
65
+ get_kbit_device_map,
66
+ get_peft_config,
67
+ get_quantization_config,
68
+ )
69
+
70
+
71
+ logger = logging.get_logger(__name__)
72
+
73
+ # Enable logging in a Hugging Face Space
74
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig))
79
+ script_args, training_args, model_args = parser.parse_args_into_dataclasses()
80
+
81
+ ################
82
+ # Model & Tokenizer
83
+ ################
84
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
85
+ model_kwargs = dict(
86
+ revision=model_args.model_revision,
87
+ use_cache=False if training_args.gradient_checkpointing else True,
88
+ dtype=dtype,
89
+ )
90
+ quantization_config = get_quantization_config(model_args)
91
+ if quantization_config is not None:
92
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
93
+ model_kwargs["device_map"] = get_kbit_device_map()
94
+ model_kwargs["quantization_config"] = quantization_config
95
+
96
+ model = AutoModelForSequenceClassification.from_pretrained(
97
+ model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs
98
+ )
99
+
100
+ if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS":
101
+ logger.warning(
102
+ "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
103
+ " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
104
+ )
105
+
106
+ ##############
107
+ # Load dataset
108
+ ##############
109
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
110
+
111
+ ##########
112
+ # Training
113
+ ##########
114
+ trainer = RewardTrainer(
115
+ model=model,
116
+ args=training_args,
117
+ train_dataset=dataset[script_args.dataset_train_split],
118
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
119
+ peft_config=get_peft_config(model_args),
120
+ )
121
+ trainer.train()
122
+
123
+ ############################
124
+ # Save model and push to Hub
125
+ ############################
126
+ trainer.save_model(training_args.output_dir)
127
+
128
+ if training_args.eval_strategy != "no":
129
+ metrics = trainer.evaluate()
130
+ trainer.log_metrics("eval", metrics)
131
+ trainer.save_metrics("eval", metrics)
132
+
133
+ # Save and push to hub
134
+ trainer.save_model(training_args.output_dir)
135
+ if training_args.push_to_hub:
136
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
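For reference, a minimal sketch of the implicit-prompt preference format this script trains on; `trl-lib/ultrafeedback_binarized` follows this `chosen`/`rejected` conversational schema.

```python
from datasets import Dataset

pair = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}
dataset = Dataset.from_list([pair])
print(dataset.column_names)  # ['chosen', 'rejected']
```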
ICL/RL/trl_source/examples/scripts/sft_vlm_gemma3.py ADDED
@@ -0,0 +1,194 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl",
18
+ # "Pillow>=9.4.0",
19
+ # "peft",
20
+ # "trackio",
21
+ # "kernels",
22
+ # ]
23
+ # ///
24
+
25
+ """
26
+ Train Gemma 3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
27
+
28
+ accelerate launch \
29
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
30
+ examples/scripts/sft_vlm_gemma3.py \
31
+ --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
32
+ --model_name_or_path google/gemma-3-4b-it \
33
+ --per_device_train_batch_size 1 \
34
+ --output_dir Gemma-3-4B-SFT-MMIU \
35
+ --dtype bfloat16 \
36
+ --use_peft \
37
+ --lora_target_modules all-linear \
38
+ --attn_implementation eager
39
+
40
+ Train Gemma 3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
41
+
42
+ accelerate launch \
43
+ --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
44
+ examples/scripts/sft_vlm_gemma3.py \
45
+ --dataset_name FanqingM/MMIU-Benchmark \
46
+ --dataset_train_split test \
47
+ --model_name_or_path google/gemma-3-4b-it \
48
+ --per_device_train_batch_size 1 \
49
+ --output_dir Gemma-3-4B-SFT-MMIU \
50
+ --dtype bfloat16 \
51
+ --use_peft \
52
+ --lora_target_modules all-linear \
53
+ --attn_implementation eager
54
+ """
55
+
56
+ import io
57
+ import os
58
+ import zipfile
59
+
60
+ import torch
61
+ from datasets import DatasetDict, load_dataset
62
+ from huggingface_hub import hf_hub_download, list_repo_files
63
+ from PIL import Image
64
+ from transformers import AutoModelForImageTextToText
65
+
66
+ from trl import (
67
+ ModelConfig,
68
+ ScriptArguments,
69
+ SFTConfig,
70
+ SFTTrainer,
71
+ TrlParser,
72
+ get_kbit_device_map,
73
+ get_peft_config,
74
+ get_quantization_config,
75
+ )
76
+
77
+
78
+ # Enable logging in a Hugging Face Space
79
+ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
80
+
81
+
82
+ # For multi-image example
83
+ def process_vision_info(messages: list[dict]) -> list[Image.Image]:
84
+ image_inputs = []
85
+ for msg in messages:
86
+ content = msg.get("content", [])
87
+ if not isinstance(content, list):
88
+ content = [content]
89
+
90
+ for element in content:
91
+ if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
92
+ if "image" in element:
93
+ image = element["image"]
94
+ else:
95
+ image = element
96
+ if image is not None:
97
+ image = Image.open(io.BytesIO(image["bytes"]))
98
+ image_inputs.append(image.convert("RGB"))
99
+ return image_inputs
100
+
101
+
102
+ def format_data(samples: dict[str, any]) -> dict[str, list]:
103
+ formatted_samples = {"messages": []}
104
+ for cont in range(len(samples["question"])):
105
+ images = []
106
+ for img_path in samples["input_image_path"][cont]:
107
+ try:
108
+ with open(img_path, "rb") as f:
109
+ img_bytes = f.read()
110
+ image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
111
+ images.append({"type": "image", "image": image})
112
+ except Exception as e:
113
+ print(f"Error processing image {img_path}: {e}")
114
+ continue
115
+
116
+ formatted_samples["messages"].append(
117
+ [
118
+ {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
119
+ {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
120
+ {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
121
+ ]
122
+ )
123
+ return formatted_samples
124
+
125
+
126
+ # For multi-image example
127
+ def prepare_dataset(dataset: DatasetDict, dataset_name: str) -> DatasetDict:
128
+ all_files = list_repo_files(dataset_name, repo_type="dataset")
129
+ zip_files = [f for f in all_files if f.endswith(".zip")]
130
+
131
+ for zip_filename in zip_files:
132
+ zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
133
+ extract_folder = zip_filename.replace(".zip", "")
134
+ os.makedirs(extract_folder, exist_ok=True)
135
+
136
+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
137
+ zip_ref.extractall(extract_folder)
138
+
139
+ dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
140
+ return dataset
141
+
142
+
143
+ def main():
144
+ parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
145
+ script_args, training_args, model_args = parser.parse_args_and_config()
146
+ training_args.max_length = None
147
+
148
+ ################
149
+ # Model
150
+ ################
151
+ dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
152
+ model_kwargs = dict(
153
+ revision=model_args.model_revision,
154
+ attn_implementation=model_args.attn_implementation,
155
+ dtype=dtype,
156
+ )
157
+ quantization_config = get_quantization_config(model_args)
158
+ if quantization_config is not None:
159
+ # Passing None would not be treated the same as omitting the argument, so we include it only when valid.
160
+ model_kwargs["device_map"] = get_kbit_device_map()
161
+ model_kwargs["quantization_config"] = quantization_config
162
+
163
+ model = AutoModelForImageTextToText.from_pretrained(
164
+ model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
165
+ )
166
+
167
+ ################
168
+ # Dataset
169
+ ################
170
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
171
+ if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
172
+ dataset = prepare_dataset(dataset, script_args.dataset_name)
173
+
174
+ ################
175
+ # Training
176
+ ################
177
+ trainer = SFTTrainer(
178
+ model=model,
179
+ args=training_args,
180
+ train_dataset=dataset[script_args.dataset_train_split],
181
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
182
+ peft_config=get_peft_config(model_args),
183
+ )
184
+
185
+ trainer.train()
186
+
187
+ # Save and push to hub
188
+ trainer.save_model(training_args.output_dir)
189
+ if training_args.push_to_hub:
190
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()
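A hypothetical usage sketch for `process_vision_info` above: each image entry must carry raw bytes under the `"bytes"` key, since the function rebuilds the image with `Image.open(io.BytesIO(...))`.

```python
import io

from PIL import Image

# Build an in-memory PNG to stand in for real image bytes.
buf = io.BytesIO()
Image.new("RGB", (8, 8), "red").save(buf, format="PNG")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": {"bytes": buf.getvalue()}},
            {"type": "text", "text": "Describe the image."},
        ],
    },
]
# process_vision_info(messages) would return [<8x8 RGB PIL image>] here.
```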
ICL/RL/trl_source/trl/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (2.3 kB)
ICL/RL/trl_source/trl/__pycache__/_compat.cpython-313.pyc ADDED
Binary file (11.1 kB)
ICL/RL/trl_source/trl/__pycache__/chat_template_utils.cpython-313.pyc ADDED
Binary file (22.3 kB)
ICL/RL/trl_source/trl/__pycache__/data_utils.cpython-313.pyc ADDED
Binary file (43.1 kB)
ICL/RL/trl_source/trl/__pycache__/import_utils.cpython-313.pyc ADDED
Binary file (8.1 kB)
ICL/RL/trl_source/trl/accelerate_configs/fsdp1.yaml ADDED
@@ -0,0 +1,28 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: FSDP
4
+ downcast_bf16: 'no'
5
+ enable_cpu_affinity: false
6
+ fsdp_config:
7
+ fsdp_activation_checkpointing: false
8
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9
+ fsdp_backward_prefetch: BACKWARD_PRE
10
+ fsdp_cpu_ram_efficient_loading: true
11
+ fsdp_forward_prefetch: true
12
+ fsdp_offload_params: false
13
+ fsdp_reshard_after_forward: FULL_SHARD
14
+ fsdp_state_dict_type: FULL_STATE_DICT
15
+ fsdp_sync_module_states: true
16
+ fsdp_use_orig_params: true
17
+ fsdp_version: 1
18
+ machine_rank: 0
19
+ main_training_function: main
20
+ mixed_precision: bf16
21
+ num_machines: 1
22
+ num_processes: 8
23
+ rdzv_backend: static
24
+ same_network: true
25
+ tpu_env: []
26
+ tpu_use_cluster: false
27
+ tpu_use_sudo: false
28
+ use_cpu: false
ICL/RL/trl_source/trl/accelerate_configs/fsdp2.yaml ADDED
@@ -0,0 +1,25 @@
1
+ # Requires accelerate 1.7.0 or higher
2
+ compute_environment: LOCAL_MACHINE
3
+ debug: false
4
+ distributed_type: FSDP
5
+ downcast_bf16: 'no'
6
+ enable_cpu_affinity: false
7
+ fsdp_config:
8
+ fsdp_activation_checkpointing: false
9
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
10
+ fsdp_cpu_ram_efficient_loading: true
11
+ fsdp_offload_params: false
12
+ fsdp_reshard_after_forward: true
13
+ fsdp_state_dict_type: FULL_STATE_DICT
14
+ fsdp_version: 2
15
+ machine_rank: 0
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 1
19
+ num_processes: 8
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
ICL/RL/trl_source/trl/accelerate_configs/multi_gpu.yaml ADDED
@@ -0,0 +1,16 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: 'bf16'
9
+ num_machines: 1
10
+ num_processes: 8
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
ICL/RL/trl_source/trl/accelerate_configs/single_gpu.yaml ADDED
@@ -0,0 +1,16 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: "NO"
4
+ downcast_bf16: 'no'
5
+ gpu_ids: all
6
+ machine_rank: 0
7
+ main_training_function: main
8
+ mixed_precision: 'bf16'
9
+ num_machines: 1
10
+ num_processes: 8
11
+ rdzv_backend: static
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
ICL/RL/trl_source/trl/accelerate_configs/zero1.yaml ADDED
@@ -0,0 +1,20 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ gradient_accumulation_steps: 1
6
+ zero3_init_flag: false
7
+ zero_stage: 1
8
+ distributed_type: DEEPSPEED
9
+ downcast_bf16: 'no'
10
+ machine_rank: 0
11
+ main_training_function: main
12
+ mixed_precision: 'bf16'
13
+ num_machines: 1
14
+ num_processes: 8
15
+ rdzv_backend: static
16
+ same_network: true
17
+ tpu_env: []
18
+ tpu_use_cluster: false
19
+ tpu_use_sudo: false
20
+ use_cpu: false
ICL/RL/trl_source/trl/accelerate_configs/zero2.yaml ADDED
@@ -0,0 +1,21 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: false
8
+ zero_stage: 2
9
+ distributed_type: DEEPSPEED
10
+ downcast_bf16: 'no'
11
+ machine_rank: 0
12
+ main_training_function: main
13
+ mixed_precision: 'bf16'
14
+ num_machines: 1
15
+ num_processes: 8
16
+ rdzv_backend: static
17
+ same_network: true
18
+ tpu_env: []
19
+ tpu_use_cluster: false
20
+ tpu_use_sudo: false
21
+ use_cpu: false
ICL/RL/trl_source/trl/accelerate_configs/zero3.yaml ADDED
@@ -0,0 +1,22 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ deepspeed_multinode_launcher: standard
5
+ offload_optimizer_device: none
6
+ offload_param_device: none
7
+ zero3_init_flag: true
8
+ zero3_save_16bit_model: true
9
+ zero_stage: 3
10
+ distributed_type: DEEPSPEED
11
+ downcast_bf16: 'no'
12
+ machine_rank: 0
13
+ main_training_function: main
14
+ mixed_precision: bf16
15
+ num_machines: 1
16
+ num_processes: 8
17
+ rdzv_backend: static
18
+ same_network: true
19
+ tpu_env: []
20
+ tpu_use_cluster: false
21
+ tpu_use_sudo: false
22
+ use_cpu: false
ICL/RL/trl_source/trl/experimental/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Experimental submodule for TRL.
17
+
18
+ This submodule contains unstable or incubating features. Anything here may change (or be removed) in any release
19
+ without deprecation. Use at your own risk.
20
+
21
+ To silence this notice, set the environment variable TRL_EXPERIMENTAL_SILENCE=1.
22
+ """
23
+
24
+ import os
25
+ import warnings
26
+
27
+ from ..import_utils import TRLExperimentalWarning
28
+
29
+
30
+ if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
31
+ warnings.warn(
32
+ "You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without "
33
+ "notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.",
34
+ TRLExperimentalWarning,
35
+ stacklevel=2,
36
+ )
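The guard above means the warning can be silenced by setting the variable before the first import, for example:

```python
import os

os.environ["TRL_EXPERIMENTAL_SILENCE"] = "1"

import trl.experimental  # no TRLExperimentalWarning is emitted
```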
ICL/RL/trl_source/trl/experimental/bco/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .bco_config import BCOConfig
16
+ from .bco_trainer import BCOTrainer
ICL/RL/trl_source/trl/experimental/bema_for_ref_model/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .callback import BEMACallback
16
+ from .dpo_trainer import DPOTrainer
ICL/RL/trl_source/trl/experimental/bema_for_ref_model/dpo_trainer.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...trainer.dpo_trainer import DPOTrainer as _DPOTrainer
16
+ from .callback import CallbackHandlerWithRefModel
17
+
18
+
19
+ class DPOTrainer(_DPOTrainer):
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+ # Replace with a new one that calls the events with the reference model
23
+ self.callback_handler = CallbackHandlerWithRefModel(
24
+ self.callback_handler.callbacks,
25
+ self.model,
26
+ self.ref_model,
27
+ self.processing_class,
28
+ self.optimizer,
29
+ self.lr_scheduler,
30
+ )
ICL/RL/trl_source/trl/experimental/cpo/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .cpo_config import CPOConfig
16
+ from .cpo_trainer import CPOTrainer
17
+
18
+
19
+ __all__ = ["CPOConfig", "CPOTrainer"]
ICL/RL/trl_source/trl/experimental/cpo/cpo_config.py ADDED
@@ -0,0 +1,207 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Any
17
+
18
+ from transformers import TrainingArguments
19
+
20
+
21
+ @dataclass
22
+ class CPOConfig(TrainingArguments):
23
+ r"""
24
+ Configuration class for the [`experimental.cpo.CPOTrainer`].
25
+
26
+ This class includes only the parameters that are specific to CPO training. For a full list of training arguments,
27
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
28
+ differ from those in [`~transformers.TrainingArguments`].
29
+
30
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
31
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
32
+ command line.
33
+
34
+ Parameters:
35
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
36
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
37
+ to use the default data collator.
38
+ max_completion_length (`int`, *optional*):
39
+ Maximum length of the completion. This argument is required if you want to use the default data collator
40
+ and your model is an encoder-decoder.
41
+ beta (`float`, *optional*, defaults to `0.1`):
42
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
43
+ reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
44
+ the [paper](https://huggingface.co/papers/2310.12036).
45
+ label_smoothing (`float`, *optional*, defaults to `0.0`):
46
+ Label smoothing factor. This argument is required if you want to use the default data collator.
47
+ loss_type (`str`, *optional*, defaults to `"sigmoid"`):
48
+ Type of loss to use. Possible values are:
49
+
50
+ - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
51
+ - `"hinge"`: hinge loss on the normalized likelihood from the
52
+ [SLiC](https://huggingface.co/papers/2305.10425) paper.
53
+ - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
54
+ - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
55
+ - `"alphapo"`: AlphaPO loss from the [AlphaPO](https://huggingface.co/papers/2501.03884) paper. This
56
+ automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`.
57
+
58
+ disable_dropout (`bool`, *optional*, defaults to `True`):
59
+ Whether to disable dropout in the model.
60
+ cpo_alpha (`float`, *optional*, defaults to `1.0`):
61
+ Weight of the BC regularizer in CPO training.
62
+ simpo_gamma (`float`, *optional*, defaults to `0.5`):
63
+ Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
64
+ alpha (`float`, *optional*, defaults to `0.0`):
65
+ Alpha parameter that controls reward function shape across all loss types. When alpha=0 (default), uses
66
+ standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: `r = (1 - p^(-alpha))
67
+ / alpha` from the [AlphaPO paper](https://huggingface.co/papers/2501.03884). This parameter works with all
68
+ loss types.
69
+ truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
70
+ Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
71
+ This argument is required if you want to use the default data collator.
72
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
73
+ If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
74
+ is_encoder_decoder (`bool`, *optional*):
75
+ When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
76
+ you need to specify if the model returned by the callable is an encoder-decoder model.
77
+ model_init_kwargs (`dict[str, Any]`, *optional*):
78
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
79
+ string.
80
+ dataset_num_proc (`int`, *optional*):
81
+ Number of processes to use for processing the dataset.
82
+ """
83
+
84
+ _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
85
+
86
+ # Parameters whose default values are overridden from TrainingArguments
87
+ learning_rate: float = field(
88
+ default=1e-6,
89
+ metadata={"help": "The initial learning rate for AdamW."},
90
+ )
91
+ logging_steps: float = field(
92
+ default=10,
93
+ metadata={
94
+ "help": "Log every X update steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
95
+ "will be interpreted as a ratio of total training steps."
96
+ },
97
+ )
98
+ gradient_checkpointing: bool = field(
99
+ default=True,
100
+ metadata={
101
+ "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
102
+ },
103
+ )
104
+ bf16: bool | None = field(
105
+ default=None,
106
+ metadata={
107
+ "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires an Ampere or newer NVIDIA "
108
+ "GPU, an Intel XPU, an Ascend NPU, or CPU training (`use_cpu`). If not set, it defaults to `True` if "
109
+ "`fp16` is not set."
110
+ },
111
+ )
112
+ # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue
113
+ # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary
114
+ # workaround here, which can be removed once we drop support for versions older than 4.57.5.
115
+ lr_scheduler_kwargs: dict | str | None = field(
116
+ default=None,
117
+ metadata={
118
+ "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard "
119
+ "restarts."
120
+ },
121
+ )
122
+
123
+ max_length: int | None = field(
124
+ default=1024,
125
+ metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
126
+ )
127
+ max_completion_length: int | None = field(
128
+ default=None,
129
+ metadata={
130
+ "help": "Maximum length of the completion. This argument is required if you want to use the default data "
131
+ "collator and your model is an encoder-decoder."
132
+ },
133
+ )
134
+ beta: float = field(
135
+ default=0.1,
136
+ metadata={
137
+ "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
138
+ "the reference model."
139
+ },
140
+ )
141
+ label_smoothing: float = field(
142
+ default=0.0,
143
+ metadata={"help": "Label smoothing factor."},
144
+ )
145
+ loss_type: str = field(
146
+ default="sigmoid",
147
+ metadata={
148
+ "help": "Type of loss to use.",
149
+ "choices": ["sigmoid", "hinge", "ipo", "simpo", "alphapo"],
150
+ },
151
+ )
152
+ disable_dropout: bool = field(
153
+ default=True,
154
+ metadata={"help": "Whether to disable dropout in the model."},
155
+ )
156
+ cpo_alpha: float = field(
157
+ default=1.0,
158
+ metadata={"help": "Weight of the BC regularizer in CPO training."},
159
+ )
160
+ simpo_gamma: float = field(
161
+ default=0.5,
162
+ metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."},
163
+ )
164
+ alpha: float = field(
165
+ default=0.0,
166
+ metadata={
167
+ "help": "Alpha parameter that controls reward function shape across all loss types. When alpha=0 "
168
+ "(default), uses standard log probability rewards. When `alpha != 0`, applies AlphaPO transformation: "
169
+ "`r = (1 - p^(-alpha)) / alpha` from the AlphaPO paper. This parameter works with all loss types."
170
+ },
171
+ )
172
+ truncation_mode: str = field(
173
+ default="keep_end",
174
+ metadata={
175
+ "help": "Truncation mode to use when the prompt is too long.",
176
+ "choices": ["keep_end", "keep_start"],
177
+ },
178
+ )
179
+ generate_during_eval: bool = field(
180
+ default=False,
181
+ metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
182
+ )
183
+ is_encoder_decoder: bool | None = field(
184
+ default=None,
185
+ metadata={"help": "Whether the model is an encoder-decoder model."},
186
+ )
187
+ model_init_kwargs: dict[str, Any] | None = field(
188
+ default=None,
189
+ metadata={
190
+ "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
191
+ "from a string."
192
+ },
193
+ )
194
+ dataset_num_proc: int | None = field(
195
+ default=None,
196
+ metadata={"help": "Number of processes to use for processing the dataset."},
197
+ )
198
+
199
+ def __post_init__(self):
200
+ self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
201
+
202
+ # Syntactic sugar for AlphaPO: set loss_type to "simpo" and cpo_alpha to 0.0
203
+ if self.loss_type == "alphapo":
204
+ self.loss_type = "simpo"
205
+ self.cpo_alpha = 0.0
206
+
207
+ super().__post_init__()
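A minimal usage sketch for the config above (illustrative values only; `output_dir` is any writable path, and `bf16=False` is passed explicitly because `__post_init__` would otherwise enable bf16 on hardware that may not support it):

    from trl.experimental.cpo import CPOConfig

    # "alphapo" is syntactic sugar: __post_init__ rewrites it to loss_type="simpo"
    # with cpo_alpha=0.0, leaving `alpha` to reshape the reward.
    config = CPOConfig(output_dir="cpo-model", loss_type="alphapo", alpha=0.1, bf16=False)
    assert config.loss_type == "simpo" and config.cpo_alpha == 0.0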
ICL/RL/trl_source/trl/experimental/cpo/cpo_trainer.py ADDED
@@ -0,0 +1,1057 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import random
17
+ import textwrap
18
+ from collections import defaultdict
19
+ from collections.abc import Callable
20
+ from contextlib import nullcontext
21
+ from pathlib import Path
22
+ from typing import Any, Literal
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import transformers
30
+ from accelerate import PartialState, logging
31
+ from datasets import Dataset
32
+ from packaging.version import Version
33
+ from torch import autocast
34
+ from torch.utils.data import DataLoader
35
+ from transformers import (
36
+ AutoModelForCausalLM,
37
+ BaseImageProcessor,
38
+ DataCollator,
39
+ FeatureExtractionMixin,
40
+ PreTrainedModel,
41
+ PreTrainedTokenizerBase,
42
+ ProcessorMixin,
43
+ TrainerCallback,
44
+ is_comet_available,
45
+ is_wandb_available,
46
+ )
47
+ from transformers.trainer_utils import EvalLoopOutput
48
+ from transformers.utils import is_peft_available, is_torch_fx_proxy
49
+
50
+ from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt
51
+ from ...models.utils import peft_module_casting_to_bf16
52
+ from ...trainer.base_trainer import BaseTrainer
53
+ from ...trainer.utils import (
54
+ disable_dropout_in_model,
55
+ log_table_to_comet_experiment,
56
+ pad_to_length,
57
+ selective_log_softmax,
58
+ )
59
+ from ..utils import DPODataCollatorWithPadding, add_bos_token_if_needed, add_eos_token_if_needed
60
+ from .cpo_config import CPOConfig
61
+
62
+
63
+ if is_peft_available():
64
+ from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
65
+
66
+
67
+ if is_wandb_available():
68
+ import wandb
69
+
70
+
71
+ logger = logging.get_logger(__name__)
72
+
73
+
74
+ class CPOTrainer(BaseTrainer):
75
+ r"""
76
+ Initialize CPOTrainer.
77
+
78
+ Args:
79
+ model ([`~transformers.PreTrainedModel`]):
80
+ The model to train, preferably an [`~transformers.AutoModelForCausalLM`].
81
+ args ([`experimental.cpo.CPOConfig`]):
82
+ The CPO config arguments to use for training.
83
+ data_collator ([`~transformers.DataCollator`]):
84
+ The data collator to use for training. If None is specified, the default data collator
85
+ ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the
86
+ maximum length of the sequences in the batch, given a dataset of paired sequences.
87
+ train_dataset ([`~datasets.Dataset`]):
88
+ The dataset to use for training.
89
+ eval_dataset ([`~datasets.Dataset`]):
90
+ The dataset to use for evaluation.
91
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
92
+ Processing class used to process the data. If provided, will be used to automatically process the inputs
93
+ for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or
94
+ reuse the fine-tuned model.
95
+ model_init (`Callable[[], transformers.PreTrainedModel]`):
96
+ The model initializer to use for training. If None is specified, the default model initializer will be
97
+ used.
98
+ callbacks (`list[transformers.TrainerCallback]`):
99
+ The callbacks to use for training.
100
+ optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
101
+ The optimizer and scheduler to use for training.
102
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
103
+ The function to use to preprocess the logits before computing the metrics.
104
+ peft_config (`dict`, defaults to `None`):
105
+ The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
106
+ a PEFT model.
107
+ compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
108
+ The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping
109
+ strings to metric values.
110
+ """
111
+
112
+ _tag_names = ["trl", "cpo"]
113
+ _name = "CPO"
114
+ _paper = {
115
+ "title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
116
+ "id": "2401.08417",
117
+ # docstyle-ignore
118
+ "citation": textwrap.dedent("""\
119
+ @inproceedings{xu2024contrastive,
120
+ title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
121
+ author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
122
+ year = 2024,
123
+ booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
124
+ publisher = {OpenReview.net},
125
+ url = {https://openreview.net/forum?id=51iwkioZpn}
126
+ }"""),
127
+ }
128
+
129
+ def __init__(
130
+ self,
131
+ model: PreTrainedModel | nn.Module | str | None = None,
132
+ args: CPOConfig | None = None,
133
+ data_collator: DataCollator | None = None,
134
+ train_dataset: Dataset | None = None,
135
+ eval_dataset: Dataset | dict[str, Dataset] | None = None,
136
+ processing_class: PreTrainedTokenizerBase
137
+ | BaseImageProcessor
138
+ | FeatureExtractionMixin
139
+ | ProcessorMixin
140
+ | None = None,
141
+ model_init: Callable[[], PreTrainedModel] | None = None,
142
+ callbacks: list[TrainerCallback] | None = None,
143
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
144
+ preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
145
+ peft_config: dict | None = None,
146
+ compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
147
+ ):
148
+ if args.model_init_kwargs is None:
149
+ model_init_kwargs = {}
150
+ elif not isinstance(model, str):
151
+ raise ValueError("You passed `model_init_kwargs` to the CPOConfig, but your model is already instantiated.")
152
+ else:
153
+ model_init_kwargs = args.model_init_kwargs
154
+ dtype = model_init_kwargs.get("dtype", "auto")
155
+ if dtype is not None:
156
+ # Convert to `torch.dtype` if an str is passed
157
+ if isinstance(dtype, str) and dtype != "auto":
158
+ dtype = getattr(torch, dtype)
159
+ if dtype != "auto" and not isinstance(dtype, torch.dtype):
160
+ raise ValueError(
161
+ f"Invalid `dtype` passed to the CPOConfig. Expected a `torch.dtype`, a string representation of one, or 'auto', but got {dtype}."
162
+ )
163
+ model_init_kwargs["dtype"] = dtype
164
+ model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
165
+
166
+ if isinstance(model, str):
167
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
168
+
169
+ # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
170
+ # has been called in order to properly call autocast if needed.
171
+ self._peft_has_been_casted_to_bf16 = False
172
+
173
+ if not is_peft_available() and peft_config is not None:
174
+ raise ValueError(
175
+ "PEFT is not installed but you passed a `peft_config` to the trainer. Please install PEFT to use PEFT models."
176
+ )
177
+ elif is_peft_available() and peft_config is not None:
178
+ if isinstance(model, PeftModel):
179
+ raise ValueError(
180
+ "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
181
+ "merge and unload the existing adapter, save the resulting base model, and then pass that base "
182
+ "model along with the new `peft_config` to the trainer."
183
+ )
184
+
185
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
186
+ _support_gc_kwargs = hasattr(
187
+ args, "gradient_checkpointing_kwargs"
188
+ ) and "gradient_checkpointing_kwargs" in list(
189
+ inspect.signature(prepare_model_for_kbit_training).parameters
190
+ )
191
+
192
+ prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
193
+
194
+ if _support_gc_kwargs:
195
+ prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
196
+
197
+ model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
198
+ elif args.gradient_checkpointing:
199
+ # For backward compatibility with older versions of transformers
200
+ if hasattr(model, "enable_input_require_grads"):
201
+ model.enable_input_require_grads()
202
+ else:
203
+
204
+ def make_inputs_require_grad(module, input, output):
205
+ output.requires_grad_(True)
206
+
207
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
208
+
209
+ # get peft model with the given config
210
+ model = get_peft_model(model, peft_config)
211
+ if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
212
+ peft_module_casting_to_bf16(model)
213
+ # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
214
+ self._peft_has_been_casted_to_bf16 = True
215
+
216
+ # For models that use gradient_checkpointing, we need to attach a hook that enables input
217
+ # to explicitly have `requires_grad=True`; otherwise training will either fail
218
+ # silently or crash outright.
219
+ elif args.gradient_checkpointing:
220
+ # For backward compatibility with older versions of transformers
221
+ if hasattr(model, "enable_input_require_grads"):
222
+ model.enable_input_require_grads()
223
+ else:
224
+
225
+ def make_inputs_require_grad(module, input, output):
226
+ output.requires_grad_(True)
227
+
228
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
229
+
230
+ if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
231
+ raise ValueError(
232
+ "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
233
+ " Please install `wandb` or `comet-ml` to resolve."
234
+ )
235
+
236
+ if model is not None:
237
+ self.is_encoder_decoder = model.config.is_encoder_decoder
238
+ elif args.is_encoder_decoder is None:
239
+ raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
240
+ else:
241
+ self.is_encoder_decoder = args.is_encoder_decoder
242
+
243
+ if self.is_encoder_decoder:
244
+ self.decoder_start_token_id = model.config.decoder_start_token_id
245
+ self.pad_token_id = model.config.pad_token_id
246
+
247
+ if processing_class is None:
248
+ raise ValueError("processing_class must be specified to tokenize a CPO dataset.")
249
+ if args.max_length is None:
250
+ logger.warning(
251
+ "`max_length` is not set in the CPOConfig's init;"
252
+ " it will default to `512`, but you should set it yourself in the future.",
253
+ )
254
+ max_length = 512
255
+ else:
256
+ max_length = args.max_length
257
+
258
+ if args.max_completion_length is None and self.is_encoder_decoder:
259
+ logger.warning(
260
+ "When using an encoder-decoder architecture, you should set `max_completion_length` in the CPOConfig's init;"
261
+ " it will default to `128`, but you should set it yourself in the future.",
262
+ )
263
+ max_completion_length = 128
264
+ else:
265
+ max_completion_length = args.max_completion_length
266
+
267
+ if data_collator is None:
268
+ data_collator = DPODataCollatorWithPadding(
269
+ pad_token_id=processing_class.pad_token_id,
270
+ is_encoder_decoder=self.is_encoder_decoder,
271
+ )
272
+
273
+ if args.remove_unused_columns:
274
+ args.remove_unused_columns = False
275
+ # warn users
276
+ logger.warning(
277
+ "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments;"
278
+ " we have set it for you, but you should set it yourself in the future.",
279
+ )
280
+
281
+ self.use_dpo_data_collator = True
282
+ else:
283
+ self.use_dpo_data_collator = False
284
+
285
+ # Disable dropout in the model
286
+ if args.disable_dropout:
287
+ disable_dropout_in_model(model)
288
+
289
+ self.max_length = max_length
290
+ self.generate_during_eval = args.generate_during_eval
291
+ self.truncation_mode = args.truncation_mode
292
+ self.max_completion_length = max_completion_length
293
+ self.processing_class = processing_class
294
+
295
+ if processing_class.pad_token is None:
296
+ processing_class.pad_token = processing_class.eos_token
297
+ self.pad_token_id = processing_class.pad_token_id
298
+
299
+ if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0:
300
+ logger.warning(
301
+ f"You are using the {args.loss_type} loss type, which does not support label smoothing. The "
302
+ "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.",
303
+ )
304
+ if args.loss_type == "kto_pair":
305
+ raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.")
306
+
307
+ self.beta = args.beta
308
+ self.label_smoothing = args.label_smoothing
309
+ self.loss_type = args.loss_type
310
+ self.cpo_alpha = args.cpo_alpha
311
+ self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
312
+ self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
313
+ if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
314
+ logger.warning(
315
+ "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
316
+ "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
317
+ "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
318
+ "loss.",
319
+ )
320
+
321
+ if args.loss_type == "simpo":
322
+ self.simpo_gamma = args.simpo_gamma
323
+
324
+ # AlphaPO parameter for reward shaping
325
+ self.alpha = args.alpha
326
+
327
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
328
+
329
+ # Compute that only on the main process for faster data processing.
330
+ # see: https://github.com/huggingface/trl/pull/1255
331
+ with PartialState().main_process_first():
332
+ # Extract the prompt if needed, and apply the chat template if needed
333
+ train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
334
+ train_dataset = train_dataset.map(
335
+ maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc
336
+ )
337
+ if eval_dataset is not None:
338
+ eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc)
339
+ eval_dataset = eval_dataset.map(
340
+ maybe_apply_chat_template,
341
+ fn_kwargs={"tokenizer": processing_class},
342
+ num_proc=args.dataset_num_proc,
343
+ )
344
+
345
+ # tokenize the dataset
346
+ train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
347
+ if eval_dataset is not None:
348
+ eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
349
+
350
+ # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
351
+ # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
352
+ # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
353
+ # default to the recommended non-reentrant behavior here, while preserving any user-provided value.
354
+ if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
355
+ args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
356
+ args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
357
+
358
+ super().__init__(
359
+ model=model,
360
+ args=args,
361
+ data_collator=data_collator,
362
+ train_dataset=train_dataset,
363
+ eval_dataset=eval_dataset,
364
+ processing_class=processing_class,
365
+ model_init=model_init,
366
+ compute_metrics=compute_metrics,
367
+ callbacks=callbacks,
368
+ optimizers=optimizers,
369
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
370
+ )
371
+
372
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
373
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
374
+ # self.model_accepts_loss_kwargs to False to enable scaling.
375
+ self.model_accepts_loss_kwargs = False
376
+
377
+ # Add tags for models that have been loaded with the correct transformers version
378
+ if hasattr(self.model, "add_model_tags"):
379
+ self.model.add_model_tags(self._tag_names)
380
+
381
+ if not hasattr(self, "accelerator"):
382
+ raise AttributeError(
383
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
384
+ )
385
+
386
+ def build_tokenized_answer(self, prompt, answer):
387
+ """
388
+ Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`. It does, however, ensure `enc(a + b) = enc(a) + enc(a +
389
+ b)[len(enc(a)):]`. Reference:
390
+ https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
391
+ """
392
+
393
+ full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False)
394
+ prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"]
395
+
396
+ answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
397
+ answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
398
+
399
+ # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
400
+ full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
401
+
402
+ # Prepare input tokens for token by token comparison
403
+ full_input_ids = np.array(full_tokenized["input_ids"])
404
+
405
+ if len(full_input_ids) != len(full_concat_input_ids):
406
+ raise ValueError("The full tokenized sequence and the concatenation of the tokenized prompt and extracted answer should have the same length.")
407
+
408
+ # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
409
+ # can be merged together when tokenizing prompt+answer. This could result
410
+ # in the last token from the prompt being different when tokenized on its own
411
+ # vs when done as prompt+answer.
412
+ response_token_ids_start_idx = len(prompt_input_ids)
413
+
414
+ # If the prompt tokenized on its own differs from the prompt portion of prompt+answer, then the
415
+ # last token has changed due to merging.
416
+ if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
417
+ response_token_ids_start_idx -= 1
418
+
419
+ prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
420
+ prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
421
+
422
+ if len(prompt_input_ids) != len(prompt_attention_mask):
423
+ raise ValueError("Prompt input ids and attention mask should have the same length.")
424
+
425
+ answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
426
+ answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
427
+
428
+ return dict(
429
+ prompt_input_ids=prompt_input_ids,
430
+ prompt_attention_mask=prompt_attention_mask,
431
+ input_ids=answer_input_ids,
432
+ attention_mask=answer_attention_mask,
433
+ )
434
+
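A minimal sketch of the merge behavior this helper guards against (the checkpoint is an assumption; any SentencePiece-style tokenizer such as Llama's exhibits it):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")  # assumed checkpoint
    prompt, answer = "The answer is", " 42"
    full = tok(prompt + answer, add_special_tokens=False)["input_ids"]
    parts = (tok(prompt, add_special_tokens=False)["input_ids"]
             + tok(answer, add_special_tokens=False)["input_ids"])
    # enc(prompt + answer) may differ from enc(prompt) + enc(answer) because
    # tokens can merge at the boundary; the suffix identity still holds.
    print(full == parts)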
435
+ def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict:
436
+ """Tokenize a single row from a CPO specific dataset.
437
+
438
+ At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
439
+ chosen or prompt + rejected response is too long. First we truncate the prompt; if it's still too long,
440
+ we truncate the chosen/rejected.
441
+
442
+ We also create the labels for the chosen/rejected responses, which are of length equal to the sum of the length
443
+ of the prompt and the chosen/rejected response, with `-100` for the prompt tokens.
444
+ """
445
+ batch = {}
446
+ prompt = feature["prompt"]
447
+ chosen = feature["chosen"]
448
+ rejected = feature["rejected"]
449
+
450
+ if not self.is_encoder_decoder:
451
+ # Check issues below for more details
452
+ # 1. https://github.com/huggingface/trl/issues/907
453
+ # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
454
+ # 3. https://github.com/LianjiaTech/BELLE/issues/337
455
+
456
+ if not isinstance(prompt, str):
457
+ raise ValueError(f"prompt should be a str but got {type(prompt)}")
458
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=False)
459
+ prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
460
+
461
+ if not isinstance(chosen, str):
462
+ raise ValueError(f"chosen should be a str but got {type(chosen)}")
463
+ chosen_tokens = self.build_tokenized_answer(prompt, chosen)
464
+
465
+ if not isinstance(rejected, str):
466
+ raise ValueError(f"rejected should be a str but got {type(rejected)}")
467
+ rejected_tokens = self.build_tokenized_answer(prompt, rejected)
468
+
469
+ # Last prompt token might get merged by tokenizer and
470
+ # it should not be included for generation if that happens
471
+ prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
472
+
473
+ chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
474
+ rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
475
+ prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
476
+
477
+ for k, v in prompt_tokens.items():
478
+ prompt_tokens[k] = v[:prompt_len_input_ids]
479
+
480
+ # Make sure the prompts differ by at most one token
481
+ # and that their lengths differ by at most 1
482
+ num_diff_tokens = sum(
483
+ a != b
484
+ for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
485
+ )
486
+ num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
487
+ if num_diff_tokens > 1 or num_diff_len > 1:
488
+ raise ValueError(
489
+ "Chosen and rejected prompt_input_ids might only differ on the "
490
+ "last token due to tokenizer merge ops."
491
+ )
492
+
493
+ # add BOS token to head of prompt. Avoid adding if it's already there
494
+ prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed(
495
+ self.processing_class.bos_token_id,
496
+ prompt_len_input_ids,
497
+ prompt_tokens,
498
+ chosen_prompt_len_input_ids,
499
+ chosen_tokens,
500
+ rejected_prompt_len_input_ids,
501
+ rejected_tokens,
502
+ )
503
+
504
+ # add EOS token to end of answer. Avoid adding if it's already there
505
+ chosen_tokens, rejected_tokens = add_eos_token_if_needed(
506
+ self.processing_class.eos_token_id, chosen_tokens, rejected_tokens
507
+ )
508
+
509
+ longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
510
+
511
+ # if combined sequence is too long, truncate the response
512
+ for answer_tokens in [chosen_tokens, rejected_tokens]:
513
+ if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
514
+ for k in ["input_ids", "attention_mask"]:
515
+ answer_tokens[k] = answer_tokens[k][: self.max_length - longer_response_length]
516
+
517
+ # Create labels
518
+ chosen_sequence_tokens = {
519
+ k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
520
+ }
521
+ rejected_sequence_tokens = {
522
+ k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
523
+ }
524
+ chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
525
+ chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [-100] * len(
526
+ chosen_tokens["prompt_input_ids"]
527
+ )
528
+ rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
529
+ rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [-100] * len(
530
+ rejected_tokens["prompt_input_ids"]
531
+ )
532
+
533
+ for k, toks in {
534
+ "chosen_": chosen_sequence_tokens,
535
+ "rejected_": rejected_sequence_tokens,
536
+ "": prompt_tokens,
537
+ }.items():
538
+ for type_key, tokens in toks.items():
539
+ if type_key == "token_type_ids":
540
+ continue
541
+ batch[f"{k}{type_key}"] = tokens
542
+
543
+ else:
544
+ chosen_tokens = self.processing_class(
545
+ chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
546
+ )
547
+ rejected_tokens = self.processing_class(
548
+ rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
549
+ )
550
+ prompt_tokens = self.processing_class(prompt, add_special_tokens=True)
551
+
552
+ batch["chosen_labels"] = chosen_tokens["input_ids"]
553
+ batch["rejected_labels"] = rejected_tokens["input_ids"]
554
+ batch["prompt_input_ids"] = prompt_tokens["input_ids"]
555
+ batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
556
+
557
+ if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
558
+ batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
559
+ labels=torch.tensor(batch["rejected_labels"])
560
+ )
561
+ batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
562
+ labels=torch.tensor(batch["chosen_labels"])
563
+ )
564
+
565
+ return batch
566
+
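The label construction above can be sanity-checked in isolation; a minimal standalone sketch (made-up token ids):

    # Prompt tokens are masked with -100 so the NLL/CPO loss only covers the completion.
    prompt_ids = [101, 102, 103]
    answer_ids = [7, 8, 9, 2]
    labels = (prompt_ids + answer_ids)[:]
    labels[: len(prompt_ids)] = [-100] * len(prompt_ids)
    assert labels == [-100, -100, -100, 7, 8, 9, 2]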
567
+ @staticmethod
568
+ def concatenated_inputs(
569
+ batch: dict[str, list | torch.LongTensor],
570
+ is_encoder_decoder: bool = False,
571
+ padding_value: int = 0,
572
+ device: torch.device | None = None,
573
+ ) -> dict[str, torch.LongTensor]:
574
+ """Concatenate the chosen and rejected inputs into a single tensor.
575
+
576
+ Args:
577
+ batch:
578
+ A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors
579
+ of shape (batch_size, sequence_length).
580
+ is_encoder_decoder:
581
+ Whether the model is an encoder-decoder model.
582
+ padding_value:
583
+ The padding value to use for the concatenated inputs_ids.
584
+ device:
585
+ The device for the concatenated inputs.
586
+
587
+ Returns:
588
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
589
+ """
590
+ concatenated_batch = {}
591
+
592
+ if is_encoder_decoder:
593
+ max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
594
+ else:
595
+ max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
596
+
597
+ for k in batch:
598
+ if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
599
+ if "labels" in k or is_encoder_decoder:
600
+ pad_value = -100
601
+ elif k.endswith("_input_ids"):
602
+ pad_value = padding_value
603
+ elif k.endswith("_attention_mask"):
604
+ pad_value = 0
605
+ concatenated_key = k.replace("chosen", "concatenated")
606
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
607
+ for k in batch:
608
+ if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
609
+ if "labels" in k or is_encoder_decoder:
610
+ pad_value = -100
611
+ elif k.endswith("_input_ids"):
612
+ pad_value = padding_value
613
+ elif k.endswith("_attention_mask"):
614
+ pad_value = 0
615
+ concatenated_key = k.replace("rejected", "concatenated")
616
+ concatenated_batch[concatenated_key] = torch.cat(
617
+ (
618
+ concatenated_batch[concatenated_key],
619
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
620
+ ),
621
+ dim=0,
622
+ ).to(device=device)
623
+
624
+ if is_encoder_decoder:
625
+ concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
626
+ concatenated_batch["concatenated_attention_mask"] = (
627
+ batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
628
+ )
629
+
630
+ return concatenated_batch
631
+
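A minimal sketch of the pad-then-stack pattern used above, approximating `pad_to_length` with `torch.nn.functional.pad` (the real helper lives in `trl.trainer.utils`):

    import torch
    import torch.nn.functional as F

    chosen = torch.tensor([[1, 2, 3]])
    rejected = torch.tensor([[4, 5]])
    max_len = max(chosen.shape[1], rejected.shape[1])
    # Right-pad both to the same length, then stack along the batch dim so a
    # single forward pass can score chosen and rejected completions together.
    batch = torch.cat(
        [
            F.pad(chosen, (0, max_len - chosen.shape[1]), value=0),
            F.pad(rejected, (0, max_len - rejected.shape[1]), value=0),
        ],
        dim=0,
    )
    assert batch.tolist() == [[1, 2, 3], [4, 5, 0]]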
632
+ def cpo_loss(
633
+ self,
634
+ policy_chosen_logps: torch.FloatTensor,
635
+ policy_rejected_logps: torch.FloatTensor,
636
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
637
+ """Compute the CPO loss for a batch of policy and reference model log probabilities.
638
+
639
+ Args:
640
+ policy_chosen_logps:
641
+ Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
642
+ policy_rejected_logps:
643
+ Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
644
+
645
+ Returns:
646
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). The losses tensor contains the CPO
647
+ loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
648
+ the chosen and rejected responses, respectively.
649
+ """
650
+ # Apply AlphaPO reward transformation if alpha != 0
651
+ if self.alpha != 0.0:
652
+ # Compute probabilities
653
+ chosen_probs = torch.exp(policy_chosen_logps)
654
+ rejected_probs = torch.exp(policy_rejected_logps)
655
+
656
+ # Apply AlphaPO transformation: r = (1 - p^(-alpha)) / alpha
657
+ policy_chosen_rewards = (1 - chosen_probs.pow(-self.alpha)) / self.alpha
658
+ policy_rejected_rewards = (1 - rejected_probs.pow(-self.alpha)) / self.alpha
659
+
660
+ logits = (policy_chosen_rewards - policy_rejected_rewards).to(self.accelerator.device)
661
+ else:
662
+ # Standard log probability rewards when alpha = 0
663
+ logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
664
+
665
+ # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
666
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
667
+ # calculates a conservative CPO loss.
668
+
669
+ if self.loss_type == "simpo":
670
+ gamma_logratios = self.simpo_gamma / self.beta
671
+ logits = logits - gamma_logratios
672
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
673
+ losses = (
674
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
675
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
676
+ )
677
+ elif self.loss_type == "sigmoid":
678
+ # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
679
+ losses = (
680
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
681
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
682
+ )
683
+ elif self.loss_type == "hinge":
684
+ losses = torch.relu(1 - self.beta * logits)
685
+ elif self.loss_type == "ipo":
686
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
687
+ losses = (logits - 1 / (2 * self.beta)) ** 2
688
+ else:
689
+ raise ValueError(
690
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']"
691
+ )
692
+
693
+ # Calculate rewards for logging
694
+ if self.alpha != 0.0:
695
+ # When using AlphaPO transformation, use the transformed rewards
696
+ chosen_rewards = self.beta * policy_chosen_rewards.to(self.accelerator.device).detach()
697
+ rejected_rewards = self.beta * policy_rejected_rewards.to(self.accelerator.device).detach()
698
+ else:
699
+ # Standard log probability rewards
700
+ chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
701
+ rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
702
+
703
+ return losses, chosen_rewards, rejected_rewards
704
+
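The sigmoid branch above reduces to a few lines of tensor math; a minimal standalone sketch with made-up log probabilities:

    import torch
    import torch.nn.functional as F

    beta, label_smoothing = 0.1, 0.0  # illustrative values
    chosen_logps = torch.tensor([-12.0, -8.0])
    rejected_logps = torch.tensor([-14.0, -9.0])
    logits = chosen_logps - rejected_logps
    # Conservative DPO-style loss; reduces to -logsigmoid(beta * logits) when
    # label_smoothing is 0.
    losses = (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )
    print(losses)  # lower where the chosen response outscores the rejected one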
705
+ @staticmethod
706
+ def get_batch_logps(
707
+ logits: torch.FloatTensor,
708
+ labels: torch.LongTensor,
709
+ average_log_prob: bool = False,
710
+ is_encoder_decoder: bool = False,
711
+ ) -> torch.FloatTensor:
712
+ """Compute the log probabilities of the given labels under the given logits.
713
+
714
+ Args:
715
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
716
+ labels:
717
+ Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored.
718
+ Shape: (batch_size, sequence_length)
719
+ average_log_prob:
720
+ If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
721
+ log probabilities of the (non-masked) tokens.
722
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
723
+
724
+ Returns:
725
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
726
+ given logits.
727
+ """
728
+ if logits.shape[:-1] != labels.shape:
729
+ raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
730
+
731
+ if not is_encoder_decoder:
732
+ labels = labels[:, 1:].clone()
733
+ logits = logits[:, :-1, :]
734
+ loss_mask = labels != -100
735
+
736
+ # dummy token; we'll ignore the losses on these tokens later
737
+ labels[labels == -100] = 0
738
+
739
+ per_token_logps = selective_log_softmax(logits, labels)
740
+
741
+ if average_log_prob:
742
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
743
+ else:
744
+ return (per_token_logps * loss_mask).sum(-1)
745
+
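A minimal sketch of the masked sum/average computed above, using `log_softmax` + `gather` in place of `selective_log_softmax` (a memory-efficient equivalent from `trl.trainer.utils`):

    import torch

    logits = torch.randn(1, 4, 10)               # (batch, seq, vocab)
    labels = torch.tensor([[-100, -100, 5, 1]])  # -100 marks prompt tokens
    labels = labels[:, 1:].clone()               # shift: tokens < n predict n
    logits = logits[:, :-1, :]
    loss_mask = labels != -100
    labels[labels == -100] = 0                   # dummy ids, masked out below
    per_token = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    seq_logp = (per_token * loss_mask).sum(-1)   # or / loss_mask.sum(-1) for the average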
746
+ def concatenated_forward(
747
+ self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
748
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
749
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
750
+
751
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
752
+ """
753
+ concatenated_batch = self.concatenated_inputs(
754
+ batch,
755
+ is_encoder_decoder=self.is_encoder_decoder,
756
+ padding_value=self.pad_token_id,
757
+ device=self.accelerator.device,
758
+ )
759
+ len_chosen = batch["chosen_labels"].shape[0]
760
+
761
+ model_kwargs = (
762
+ {
763
+ "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
764
+ }
765
+ if self.is_encoder_decoder
766
+ else {}
767
+ )
768
+
769
+ if self.aux_loss_enabled:
770
+ model_kwargs["output_router_logits"] = True
771
+
772
+ outputs = model(
773
+ concatenated_batch["concatenated_input_ids"],
774
+ attention_mask=concatenated_batch["concatenated_attention_mask"],
775
+ use_cache=False,
776
+ **model_kwargs,
777
+ )
778
+ all_logits = outputs.logits
779
+
780
+ def cross_entropy_loss(logits, labels):
781
+ if not self.is_encoder_decoder:
782
+ # Shift so that tokens < n predict n
783
+ logits = logits[..., :-1, :].contiguous()
784
+ labels = labels[..., 1:].contiguous()
785
+ # Flatten the tokens
786
+ loss_fct = nn.CrossEntropyLoss()
787
+ logits = logits.view(-1, logits.shape[-1])
788
+ labels = labels.view(-1)
789
+ # Enable model parallelism
790
+ labels = labels.to(logits.device)
791
+ loss = loss_fct(logits, labels)
792
+ return loss
793
+
794
+ labels = concatenated_batch["concatenated_labels"].clone()
795
+
796
+ if self.cpo_alpha == 0:
797
+ nll_loss = torch.tensor(0.0).to(self.accelerator.device)
798
+ else:
799
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
800
+
801
+ all_logps = self.get_batch_logps(
802
+ all_logits,
803
+ concatenated_batch["concatenated_labels"],
804
+ average_log_prob=self.loss_type in ["ipo", "simpo"],
805
+ is_encoder_decoder=self.is_encoder_decoder,
806
+ )
807
+
808
+ chosen_logps = all_logps[:len_chosen]
809
+ rejected_logps = all_logps[len_chosen:]
810
+
811
+ chosen_logits = all_logits[:len_chosen]
812
+ rejected_logits = all_logits[len_chosen:]
813
+
814
+ if self.aux_loss_enabled:
815
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
816
+
817
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
818
+
819
+ def get_batch_loss_metrics(
820
+ self,
821
+ model,
822
+ batch: dict[str, list | torch.LongTensor],
823
+ train_eval: Literal["train", "eval"] = "train",
824
+ ):
825
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
826
+ metrics = {}
827
+
828
+ forward_output = self.concatenated_forward(model, batch)
829
+ (
830
+ policy_chosen_logps,
831
+ policy_rejected_logps,
832
+ policy_chosen_logits,
833
+ policy_rejected_logits,
834
+ policy_nll_loss,
835
+ ) = forward_output[:5]
836
+ if self.aux_loss_enabled:
837
+ aux_loss = forward_output[5]
838
+
839
+ losses, chosen_rewards, rejected_rewards = self.cpo_loss(
840
+ policy_chosen_logps,
841
+ policy_rejected_logps,
842
+ )
843
+
844
+ loss = losses.mean() + self.cpo_alpha * policy_nll_loss
845
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
846
+
847
+ prefix = "eval_" if train_eval == "eval" else ""
848
+ metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
849
+ metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
850
+ metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
851
+ metrics[f"{prefix}rewards/margins"] = (
852
+ self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
853
+ )
854
+ metrics[f"{prefix}logps/rejected"] = (
855
+ self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item()
856
+ )
857
+ metrics[f"{prefix}logps/chosen"] = (
858
+ self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item()
859
+ )
860
+ metrics[f"{prefix}logits/rejected"] = (
861
+ self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item()
862
+ )
863
+ metrics[f"{prefix}logits/chosen"] = (
864
+ self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item()
865
+ )
866
+ metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
867
+
868
+ if self.aux_loss_enabled:
869
+ loss += self.aux_loss_coef * aux_loss
870
+
871
+ return loss, metrics
872
+
873
+ def compute_loss(
874
+ self,
875
+ model: PreTrainedModel | nn.Module,
876
+ inputs: dict[str, torch.Tensor | Any],
877
+ return_outputs=False,
878
+ num_items_in_batch=None,
879
+ ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
880
+ compute_loss_context_manager = (
881
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
882
+ )
883
+
884
+ with compute_loss_context_manager:
885
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
886
+
887
+ # force log the metrics
888
+ self.store_metrics(metrics, train_eval="train")
889
+
890
+ if return_outputs:
891
+ return (loss, metrics)
892
+ return loss
893
+
894
+ def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str:
895
+ """Generate samples from the model for the given batch of inputs."""
896
+
897
+ # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
898
+ # the torch amp context manager as some hidden states are silently casted to full precision.
899
+ generate_context_manager = (
900
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
901
+ )
902
+
903
+ with generate_context_manager:
904
+ policy_output = model.generate(
905
+ input_ids=batch["prompt_input_ids"],
906
+ attention_mask=batch["prompt_attention_mask"],
907
+ max_length=self.max_length,
908
+ do_sample=True,
909
+ pad_token_id=self.processing_class.pad_token_id,
910
+ )
911
+
912
+ policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
913
+ policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)
914
+
915
+ return policy_output_decoded
916
+
917
+ def prediction_step(
918
+ self,
919
+ model: PreTrainedModel | nn.Module,
920
+ inputs: dict[str, torch.Tensor | Any],
921
+ prediction_loss_only: bool,
922
+ ignore_keys: list[str] | None = None,
923
+ ):
924
+ if ignore_keys is None:
925
+ if hasattr(model, "config"):
926
+ ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
927
+ else:
928
+ ignore_keys = []
929
+
930
+ prediction_context_manager = (
931
+ autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
932
+ )
933
+
934
+ with torch.no_grad(), prediction_context_manager:
935
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
936
+
937
+ # force log the metrics
938
+ self.store_metrics(metrics, train_eval="eval")
939
+
940
+ if prediction_loss_only:
941
+ return (loss.detach(), None, None)
942
+
943
+ # logits for the chosen and rejected samples from model
944
+ logits_dict = {
945
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
946
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
947
+ }
948
+ logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
949
+ logits = torch.tensor(logits, device=self.accelerator.device)
950
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
951
+
952
+ return (loss.detach(), logits, labels)
953
+
954
+ def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
955
+ for key, value in metrics.items():
956
+ self._stored_metrics[train_eval][key].append(value)
957
+
958
+ def evaluation_loop(
959
+ self,
960
+ dataloader: DataLoader,
961
+ description: str,
962
+ prediction_loss_only: bool | None = None,
963
+ ignore_keys: list[str] | None = None,
964
+ metric_key_prefix: str = "eval",
965
+ ) -> EvalLoopOutput:
966
+ """
967
+ Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
968
+ `Trainer.evaluate()` and `Trainer.predict()`.
969
+
970
+ Works both with or without labels.
971
+ """
972
+
973
+ # Sample and save to game log if requested (for one batch to save time)
974
+ if self.generate_during_eval:
975
+ # Generate random indices within the range of the total number of samples
976
+ num_samples = len(dataloader.dataset)
977
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
978
+
979
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
980
+ random_batch_dataset = dataloader.dataset.select(random_indices)
981
+ random_batch = self.data_collator(random_batch_dataset)
982
+ random_batch = self._prepare_inputs(random_batch)
983
+
984
+ policy_output_decoded = self.generate_from_model(self.model, random_batch)
985
+
986
+ table = pd.DataFrame(
987
+ columns=["Prompt", "Policy"],
988
+ data=[
989
+ [prompt, pol[len(prompt) :]]
990
+ for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True)
991
+ ],
992
+ )
993
+ if "wandb" in self.args.report_to:
994
+ wandb.log({"game_log": wandb.Table(data=table)})
995
+
996
+ if "comet_ml" in self.args.report_to:
997
+ log_table_to_comet_experiment(
998
+ name="game_log.csv",
999
+ table=table,
1000
+ )
1001
+
1002
+ # Base evaluation
1003
+ initial_output = super().evaluation_loop(
1004
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1005
+ )
1006
+
1007
+ return initial_output
1008
+
1009
+ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
1010
+ """
1011
+ Log `logs` on the various objects watching training, including stored metrics.
1012
+
1013
+ Args:
1014
+ logs (`dict[str, float]`):
1015
+ The values to log.
1016
+ start_time (`float`, *optional*):
1017
+ Start time of the training.
1018
+ """
1019
+ # logs either has 'loss' or 'eval_loss'
1020
+ train_eval = "train" if "loss" in logs else "eval"
1021
+ # Add averaged stored metrics to logs
1022
+ for key, metrics in self._stored_metrics[train_eval].items():
1023
+ logs[key] = torch.tensor(metrics).mean().item()
1024
+ del self._stored_metrics[train_eval]
1025
+ return super().log(logs, start_time)
1026
+
1027
+ def _shift_right(self, input_ids):
1028
+ if self.decoder_start_token_id is None:
1029
+ raise ValueError(
1030
+ "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
1031
+ )
1032
+
1033
+ # shift inputs to the right
1034
+ if is_torch_fx_proxy(input_ids):
1035
+ # Item assignment is not supported natively for proxies.
1036
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
1037
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
1038
+ else:
1039
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
1040
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
1041
+ shifted_input_ids[..., 0] = self.decoder_start_token_id
1042
+
1043
+ if self.pad_token_id is None:
1044
+ raise ValueError("model.config.pad_token_id has to be defined.")
1045
+ # replace possible -100 values in labels by `pad_token_id`
1046
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
1047
+
1048
+ return shifted_input_ids
1049
+
1050
+ # Ensure the model card is saved along with the checkpoint
1051
+ def _save_checkpoint(self, model, trial):
1052
+ if self.args.hub_model_id is None:
1053
+ model_name = Path(self.args.output_dir).name
1054
+ else:
1055
+ model_name = self.args.hub_model_id.split("/")[-1]
1056
+ self.create_model_card(model_name=model_name)
1057
+ super()._save_checkpoint(model, trial)
ICL/RL/trl_source/trl/experimental/gfpo/gfpo_config.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from ...trainer.grpo_config import GRPOConfig as _GRPOConfig
18
+
19
+
20
+ @dataclass
21
+ class GFPOConfig(_GRPOConfig):
22
+ num_remains_in_group: int | None = field(
23
+ default=None,
24
+ metadata={
25
+ "help": "Number of completions that remain in each group after the group filter function is applied. If given, `num_remains_in_group` must be >= 2."
26
+ },
27
+ )
28
+
29
+ def __post_init__(self):
30
+ super().__post_init__()
31
+
32
+ if self.num_remains_in_group is not None and self.num_remains_in_group >= self.num_generations:
33
+ raise ValueError(
34
+ f"`num_remains_in_group` ({self.num_remains_in_group}) must be less than `num_generations` ({self.num_generations})."
35
+ )
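A minimal sketch of the constraint enforced in `__post_init__` above (the import path is an assumption based on this diff's layout; values are illustrative):

    from trl.experimental.gfpo import GFPOConfig

    # Keep the top 4 of the 8 completions generated per prompt; a value >= 8
    # would raise the ValueError above.
    config = GFPOConfig(output_dir="gfpo-model", num_generations=8, num_remains_in_group=4)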
ICL/RL/trl_source/trl/experimental/gkd/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .gkd_config import GKDConfig
16
+ from .gkd_trainer import GKDTrainer
17
+
18
+
19
+ __all__ = ["GKDConfig", "GKDTrainer"]
ICL/RL/trl_source/trl/experimental/gkd/gkd_config.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Any
17
+
18
+ from transformers import TrainingArguments
19
+
20
+ from ...trainer.sft_config import SFTConfig
21
+
22
+
23
+ @dataclass
24
+ class GKDConfig(SFTConfig):
25
+ """
26
+ Configuration class for [`experimental.gkd.GKDTrainer`].
27
+
28
+ This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
29
+ please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
30
+
31
+ Args:
32
+ temperature (`float`, *optional*, defaults to `0.9`):
33
+ Temperature for sampling. The higher the temperature, the more random the completions.
34
+ lmbda (`float`, *optional*, defaults to `0.5`):
35
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
36
+ student-generated outputs).
37
+ beta (`float`, *optional*, defaults to `0.5`):
38
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
39
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
40
+ max_new_tokens (`int`, *optional*, defaults to `128`):
41
+ Maximum number of tokens to generate per completion.
42
+ teacher_model_name_or_path (`str`, *optional*):
43
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
44
+ trained.
45
+ teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
46
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
47
+ from a string.
48
+ disable_dropout (`bool`, *optional*, defaults to `True`):
49
+ Whether to disable dropout in the model.
50
+ seq_kd (`bool`, *optional*, defaults to `False`):
51
+ Whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on
52
+ teacher-generated output).
53
+ """
54
+
55
+ _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
56
+
57
+ temperature: float = field(
58
+ default=0.9,
59
+ metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
60
+ )
61
+ lmbda: float = field(
62
+ default=0.5,
63
+ metadata={
64
+ "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
65
+ "student-generated outputs)."
66
+ },
67
+ )
68
+ beta: float = field(
69
+ default=0.5,
70
+ metadata={
71
+ "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
72
+ "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
73
+ "Divergence."
74
+ },
75
+ )
76
+ max_new_tokens: int = field(
77
+ default=128,
78
+ metadata={"help": "Maximum number of tokens to generate per completion."},
79
+ )
80
+ teacher_model_name_or_path: str | None = field(
81
+ default=None,
82
+ metadata={
83
+ "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
84
+ "model being trained."
85
+ },
86
+ )
87
+ teacher_model_init_kwargs: dict[str, Any] | None = field(
88
+ default=None,
89
+ metadata={
90
+ "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
91
+ "teacher model from a string."
92
+ },
93
+ )
94
+ disable_dropout: bool = field(
95
+ default=True,
96
+ metadata={"help": "Whether to disable dropouts in `model`."},
97
+ )
98
+ seq_kd: bool = field(
99
+ default=False,
100
+ metadata={
101
+ "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
102
+ "FT on teacher-generated output)."
103
+ },
104
+ )
105
+
106
+ def __post_init__(self):
107
+ super().__post_init__()
108
+ # check lmbda and beta are in the range [0, 1]
109
+ if self.lmbda < 0.0 or self.lmbda > 1.0:
110
+ raise ValueError("lmbda must be in the range [0.0, 1.0].")
111
+ if self.beta < 0.0 or self.beta > 1.0:
112
+ raise ValueError("beta must be in the range [0.0, 1.0].")
ICL/RL/trl_source/trl/experimental/gold/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .gold_config import GOLDConfig
16
+ from .gold_trainer import GOLDTrainer
17
+
18
+
19
+ __all__ = ["GOLDConfig", "GOLDTrainer"]
ICL/RL/trl_source/trl/experimental/gold/gold.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # /// script
16
+ # dependencies = [
17
+ # "trl @ git+https://github.com/huggingface/trl.git",
18
+ # "peft",
19
+ # "trackio",
20
+ # ]
21
+ # ///
22
+
23
+ # docstyle-ignore
24
+ """
25
+ # Full training:
26
+ python trl/experimental/gold/gold.py \
27
+ --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
28
+ --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
29
+ --dataset_name trl-lib/chatbot_arena_completions \
30
+ --learning_rate 2e-5 \
31
+ --per_device_train_batch_size 4 \
32
+ --gradient_accumulation_steps 8 \
33
+ --output_dir gold-model \
34
+ --num_train_epochs 1 \
35
+ --push_to_hub
36
+
37
+ # LoRA:
38
+ python trl/experimental/gold/gold.py \
39
+ --model_name_or_path meta-llama/Llama-3.2-1B-Instruct \
40
+ --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
41
+ --dataset_name trl-lib/chatbot_arena_completions \
42
+ --learning_rate 2e-4 \
43
+ --per_device_train_batch_size 4 \
44
+ --gradient_accumulation_steps 8 \
45
+ --output_dir gold-model \
46
+ --num_train_epochs 1 \
47
+ --push_to_hub \
48
+ --use_peft \
49
+ --lora_r 64 \
50
+ --lora_alpha 16
51
+ """
52
+
53
+ import logging
54
+
55
+ from datasets import load_dataset
56
+ from transformers import AutoTokenizer, GenerationConfig
57
+
58
+ from trl import (
59
+ LogCompletionsCallback,
60
+ ModelConfig,
61
+ ScriptArguments,
62
+ TrlParser,
63
+ get_kbit_device_map,
64
+ get_peft_config,
65
+ get_quantization_config,
66
+ )
67
+ from trl.experimental.gold.gold_config import GOLDConfig
68
+ from trl.experimental.gold.gold_trainer import GOLDTrainer
69
+
70
+
71
+ logger = logging.getLogger(__name__)
72
+
73
+
74
+ if __name__ == "__main__":
75
+ parser = TrlParser((ScriptArguments, GOLDConfig, ModelConfig))
76
+ script_args, training_args, model_args = parser.parse_args_and_config()
77
+
78
+ ################
79
+ # Model & Tokenizer
80
+ ################
81
+ quantization_config = get_quantization_config(model_args)
82
+ model_kwargs = dict(
83
+ revision=training_args.student_model_revision,
84
+ trust_remote_code=model_args.trust_remote_code,
85
+ attn_implementation=model_args.attn_implementation,
86
+ torch_dtype=model_args.dtype,
87
+ use_cache=False if training_args.gradient_checkpointing else True,
88
+ device_map=get_kbit_device_map() if quantization_config is not None else None,
89
+ quantization_config=quantization_config,
90
+ )
91
+ training_args.model_init_kwargs = model_kwargs
92
+
93
+ if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss:
94
+ training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path
95
+ teacher_model_kwargs = dict(
96
+ revision=model_args.model_revision,
97
+ trust_remote_code=model_args.trust_remote_code,
98
+ attn_implementation=model_args.attn_implementation,
99
+ torch_dtype=model_args.dtype,
100
+ use_cache=True,
101
+ device_map=get_kbit_device_map() if quantization_config is not None else None,
102
+ quantization_config=quantization_config,
103
+ )
104
+ training_args.teacher_model_init_kwargs = teacher_model_kwargs
105
+
106
+ tokenizer = AutoTokenizer.from_pretrained(
107
+ model_args.model_name_or_path,
108
+ revision=model_args.model_revision,
109
+ trust_remote_code=model_args.trust_remote_code,
110
+ padding_side="left",
111
+ )
112
+ if tokenizer.pad_token is None:
113
+ tokenizer.pad_token = tokenizer.eos_token
114
+
115
+ ################
116
+ # Dataset
117
+ ################
118
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
119
+
120
+ ################
121
+ # Training
122
+ ################
123
+ # Handle eval dataset - check if test split exists, fallback to validation or None
124
+ eval_dataset = None
125
+ if training_args.eval_strategy != "no":
126
+ if script_args.dataset_test_split in dataset:
127
+ eval_dataset = dataset[script_args.dataset_test_split]
128
+ elif "validation" in dataset:
129
+ eval_dataset = dataset["validation"]
130
+ elif "dev" in dataset:
131
+ eval_dataset = dataset["dev"]
132
+
133
+ trainer = GOLDTrainer(
134
+ model=model_args.model_name_or_path,
135
+ teacher_model=training_args.teacher_model_name_or_path,
136
+ args=training_args,
137
+ train_dataset=dataset[script_args.dataset_train_split],
138
+ eval_dataset=eval_dataset,
139
+ processing_class=tokenizer,
140
+ peft_config=get_peft_config(model_args),
141
+ )
142
+
143
+ if training_args.eval_strategy != "no":
144
+ generation_config = GenerationConfig(
145
+ max_new_tokens=training_args.max_completion_length, do_sample=True, temperature=training_args.temperature
146
+ )
147
+ completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
148
+ trainer.add_callback(completions_callback)
149
+
150
+ trainer.train()
151
+
152
+ # Save and push to hub
153
+ trainer.save_model(training_args.output_dir)
154
+ if training_args.push_to_hub:
155
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
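The same run can be driven programmatically instead of via the CLI; a minimal sketch reusing the names from the usage docstring above (the full script additionally wires the tokenizer, PEFT, and model init kwargs):

from datasets import load_dataset
from trl.experimental.gold import GOLDConfig, GOLDTrainer

args = GOLDConfig(
    output_dir="gold-model",
    teacher_model_name_or_path="Qwen/Qwen2-1.5B-Instruct",
    lmbda=0.5,                 # fraction of on-policy student-generated data
    beta=0.5,                  # generalized JSD interpolation coefficient
    max_completion_length=128,
)
dataset = load_dataset("trl-lib/chatbot_arena_completions")
trainer = GOLDTrainer(
    model="meta-llama/Llama-3.2-1B-Instruct",
    teacher_model=args.teacher_model_name_or_path,
    args=args,
    train_dataset=dataset["train"],
)
trainer.train()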
ICL/RL/trl_source/trl/experimental/gold/gold_config.py ADDED
@@ -0,0 +1,419 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Any
17
+
18
+ from transformers import TrainingArguments
19
+
20
+ from ...trainer.sft_config import SFTConfig
21
+
22
+
23
+ @dataclass
24
+ class GOLDConfig(SFTConfig):
25
+ r"""
26
+ Configuration class for [`GOLDTrainer`].
27
+
28
+ This class includes only the parameters that are specific to GOLD training. For a full list of training arguments,
29
+ please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.
30
+
31
+ Args:
32
+ temperature (`float`, *optional*, defaults to `0.9`):
33
+ Temperature for sampling. The higher the temperature, the more random the completions.
34
+ lmbda (`float`, *optional*, defaults to `0.5`):
35
+ Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
36
+ student-generated outputs).
37
+ beta (`float`, *optional*, defaults to `0.5`):
38
+ Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
39
+ beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
40
+ max_completion_length (`int`, *optional*, defaults to `128`):
41
+ Maximum number of tokens to generate per completion.
42
+ teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
43
+ Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
44
+ trained.
45
+ teacher_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
46
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
47
+ from a string.
48
+ teacher_tokenizer_name_or_path (`str` or `None`, *optional*, defaults to `None`):
49
+ Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same tokenizer as
50
+ the student model (not recommended for cross-tokenizer distillation).
51
+ disable_dropout (`bool`, *optional*, defaults to `True`):
52
+ Whether to disable dropout in the model.
53
+ seq_kd (`bool`, *optional*, defaults to `False`):
54
+ Whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on
55
+ teacher-generated output).
56
+ use_uld_loss (`bool`, *optional*, defaults to `False`):
57
+ Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence
58
+ loss.
59
+ uld_crossentropy_weight (`float`, *optional*, defaults to `0.0`):
60
+ Weight for the cross-entropy loss component in ULD loss. If 0, only ULD distillation loss is used.
61
+ uld_distillation_weight (`float`, *optional*, defaults to `1.0`):
62
+ Weight for the distillation loss component in ULD loss.
63
+ uld_student_temperature (`float`, *optional*, defaults to `1.0`):
64
+ Temperature for student logits in ULD loss computation.
65
+ uld_teacher_temperature (`float`, *optional*, defaults to `1.0`):
66
+ Temperature for teacher logits in ULD loss computation.
67
+ uld_skip_student_eos (`bool`, *optional*, defaults to `True`):
68
+ Whether to skip EOS token for student in ULD loss computation.
69
+ uld_skip_teacher_eos (`bool`, *optional*, defaults to `True`):
70
+ Whether to skip EOS token for teacher in ULD loss computation.
71
+ use_vllm (`bool`, *optional*, defaults to `False`):
72
+ Whether to use vLLM for generating completions from the student model. Requires `vllm` to be installed.
73
+ vllm_mode (`str`, *optional*, defaults to `"server"`):
74
+ Mode for student vLLM integration. Either `"server"` (connect to a running TRL vLLM server) or `"colocate"`
75
+ (run vLLM in the same process).
76
+ vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
77
+ Host of the vLLM server for the student model (if `vllm_mode="server"`).
78
+ vllm_server_port (`int`, *optional*, defaults to `8001`):
79
+ Port of the vLLM server for the student model (if `vllm_mode="server"`).
80
+ vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
81
+ Timeout for connecting to the student vLLM server (if `vllm_mode="server"`).
82
+ vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
83
+ GPU memory utilization for the colocated student vLLM engine (if `vllm_mode="colocate"`). It is recommended
84
+ to set this to a low value if the student and teacher models share the same GPU.
85
+ vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
86
+ Tensor parallel size for the colocated student vLLM engine (if `vllm_mode="colocate"`).
87
+ vllm_structured_outputs_regex (`str` or `None`, *optional*, defaults to `None`):
88
+ Regex for vLLM structured outputs for the student model.
89
+ vllm_sync_frequency (`int`, *optional*, defaults to `1`):
90
+ Frequency (in training steps) to synchronize student model weights to vLLM engine. Set to 1 to sync after
91
+ every step.
92
+ vllm_enable_sleep_mode (`bool`, *optional*, defaults to `False`):
93
+ Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU memory usage
94
+ low, but waking the engine adds host–device transfer latency.
95
+ """
96
+
97
+ _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]
98
+
99
+ # Parameters whose default values are overridden from TrainingArguments
100
+ learning_rate: float = field(
101
+ default=1e-7,
102
+ metadata={"help": "The initial learning rate for AdamW."},
103
+ )
104
+
105
+ # GOLD-specific parameters
106
+ temperature: float = field(
107
+ default=0.9,
108
+ metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
109
+ )
110
+ top_p: float = field(
111
+ default=0.95,
112
+ metadata={
113
+ "help": "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to "
114
+ "`top_p` or higher are kept for generation."
115
+ },
116
+ )
117
+ top_k: int = field(
118
+ default=0,
119
+ metadata={
120
+ "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, "
121
+ "top-k-filtering is disabled and all tokens are considered."
122
+ },
123
+ )
124
+ lmbda: float = field(
125
+ default=0.5,
126
+ metadata={
127
+ "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
128
+ "student-generated outputs)."
129
+ },
130
+ )
131
+ beta: float = field(
132
+ default=0.5,
133
+ metadata={
134
+ "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
135
+ "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
136
+ "Divergence."
137
+ },
138
+ )
139
+ max_completion_length: int = field(
140
+ default=128,
141
+ metadata={"help": "Maximum number of tokens to generate per completion."},
142
+ )
143
+ student_model_revision: str = field(
144
+ default="main",
145
+ metadata={
146
+ "help": "Revision of the student model to use. If not specified, the default revision of the model will be used."
147
+ },
148
+ )
149
+ teacher_model_name_or_path: str | None = field(
150
+ default=None,
151
+ metadata={
152
+ "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
153
+ "model being trained."
154
+ },
155
+ )
156
+ teacher_model_init_kwargs: dict[str, Any] | None = field(
157
+ default=None,
158
+ metadata={
159
+ "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
160
+ "teacher model from a string."
161
+ },
162
+ )
163
+ teacher_tokenizer_name_or_path: str | None = field(
164
+ default=None,
165
+ metadata={
166
+ "help": "Tokenizer name or path for the teacher model. If None when using ULD loss, will use the same "
167
+ "tokenizer as the student model (not recommended for cross-tokenizer distillation)."
168
+ },
169
+ )
170
+ disable_dropout: bool = field(
171
+ default=True,
172
+ metadata={"help": "Whether to disable dropouts in `model`."},
173
+ )
174
+ seq_kd: bool = field(
175
+ default=False,
176
+ metadata={
177
+ "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
178
+ "FT on teacher-generated output)."
179
+ },
180
+ )
181
+ steps_per_generation: int | None = field(
182
+ default=None,
183
+ metadata={
184
+ "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps."
185
+ },
186
+ )
187
+
188
+ # ULD Loss parameters
189
+ use_uld_loss: bool = field(
190
+ default=False,
191
+ metadata={
192
+ "help": "Whether to use Universal Logit Distillation (ULD) loss instead of Generalized Jensen-Shannon Divergence loss."
193
+ },
194
+ )
195
+ use_extended_uld: bool = field(
196
+ default=True,
197
+ metadata={
198
+ "help": (
199
+ "Whether to enable extended ULD alignment that uses tokenizers to align and merge token "
200
+ "probabilities across student and teacher tokenizations. When True, the trainer will compute "
201
+ "token mappings and merge probabilities for split tokens; when False, ULD will use simple "
202
+ "positional truncation like in the original ULD paper."
203
+ )
204
+ },
205
+ )
206
+ uld_use_hybrid_loss: bool = field(
207
+ default=False,
208
+ metadata={
209
+ "help": (
210
+ "Whether to use a hybrid loss that combines ULD loss and JSD loss. When True, the final loss is a "
211
+ "a combination of JSD for known token mappings and ULD for unknown token mappings."
212
+ )
213
+ },
214
+ )
215
+ uld_hybrid_matched_weight: float | None = field(
216
+ default=None,
217
+ metadata={
218
+ "help": (
219
+ "Weight for the matched token loss component when using hybrid ULD + JSD loss. This weight scales "
220
+ "the JSD loss computed over tokens that have a direct mapping between student and teacher "
221
+ "tokenizations. If None, uses adaptive weighting based on vocabulary overlap. Must be set together "
222
+ "with uld_hybrid_unmatched_weight (both None or both float)."
223
+ )
224
+ },
225
+ )
226
+ uld_hybrid_unmatched_weight: float | None = field(
227
+ default=None,
228
+ metadata={
229
+ "help": (
230
+ "Weight for the unmatched token loss component when using hybrid ULD + JSD loss. This weight scales "
231
+ "the ULD loss computed over tokens that do not have a direct mapping between student and teacher "
232
+ "tokenizations. If None, uses adaptive weighting based on vocabulary overlap. Must be set together "
233
+ "with uld_hybrid_matched_weight (both None or both float)."
234
+ )
235
+ },
236
+ )
237
+ uld_crossentropy_weight: float = field(
238
+ default=0.0,
239
+ metadata={"help": "Weight for the cross-entropy loss component in ULD loss."},
240
+ )
241
+ uld_distillation_weight: float = field(
242
+ default=1.0,
243
+ metadata={"help": "Weight for the distillation loss component in ULD loss."},
244
+ )
245
+ uld_student_temperature: float = field(
246
+ default=1.0,
247
+ metadata={"help": "Temperature for student logits in ULD loss computation."},
248
+ )
249
+ uld_teacher_temperature: float = field(
250
+ default=1.0,
251
+ metadata={"help": "Temperature for teacher logits in ULD loss computation."},
252
+ )
253
+
254
+ uld_skip_student_eos: bool = field(
255
+ default=True,
256
+ metadata={"help": "Whether to skip EOS token for student in ULD loss computation."},
257
+ )
258
+ uld_skip_teacher_eos: bool = field(
259
+ default=True,
260
+ metadata={"help": "Whether to skip EOS token for teacher in ULD loss computation."},
261
+ )
262
+
263
+ # transformers paged attention
264
+ use_transformers_paged: bool = field(
265
+ default=False,
266
+ metadata={
267
+ "help": "Whether to use the `transformers` paged implementation for generation. If set to `True`, the "
268
+ "`transformers` paged implementation will be used for generation instead of the default padded "
269
+ "implementation."
270
+ },
271
+ )
272
+
273
+ # vLLM parameters
274
+ use_vllm: bool = field(
275
+ default=False,
276
+ metadata={"help": "Whether to use vLLM for generating completions. Requires `vllm` to be installed."},
277
+ )
278
+ vllm_mode: str = field(
279
+ default="server",
280
+ metadata={
281
+ "help": 'Mode for vLLM integration. Either "server" (connect to a running TRL vLLM server) or "colocate" (run vLLM in the same process).'
282
+ },
283
+ )
284
+ vllm_server_host: str = field(
285
+ default="0.0.0.0",
286
+ metadata={"help": 'Host of the vLLM server when `vllm_mode="server"`.'},
287
+ )
288
+ vllm_server_port: int = field(
289
+ default=8001,
290
+ metadata={"help": 'Port of the vLLM server when `vllm_mode="server"`.'},
291
+ )
292
+ vllm_server_timeout: float = field(
293
+ default=240.0,
294
+ metadata={"help": 'Timeout (in seconds) for connecting to the vLLM server when `vllm_mode="server"`.'},
295
+ )
296
+ vllm_gpu_memory_utilization: float = field(
297
+ default=0.9,
298
+ metadata={
299
+ "help": 'GPU memory utilization for the colocated vLLM engine when `vllm_mode="colocate"`. Lower values reduce contention when sharing a device with the student/teacher models.'
300
+ },
301
+ )
302
+ vllm_tensor_parallel_size: int = field(
303
+ default=1,
304
+ metadata={"help": 'Tensor parallel size for the colocated vLLM engine when `vllm_mode="colocate"`.'},
305
+ )
306
+ vllm_structured_outputs_regex: str | None = field(
307
+ default=None,
308
+ metadata={"help": "Regex pattern used for vLLM structured outputs (optional)."},
309
+ )
310
+ vllm_sync_frequency: int = field(
311
+ default=1,
312
+ metadata={
313
+ "help": "Frequency (in training steps) to synchronize model weights to the vLLM engine. Set to 1 to sync after every step."
314
+ },
315
+ )
316
+ vllm_enable_sleep_mode: bool = field(
317
+ default=False,
318
+ metadata={
319
+ "help": "Enable vLLM sleep mode to offload student weights/cache during the optimizer step. Keeps GPU "
320
+ "memory usage low, but waking the engine adds host–device transfer latency."
321
+ },
322
+ )
323
+ # Parameters that control the logging
324
+ log_completions: bool = field(
325
+ default=False,
326
+ metadata={
327
+ "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
328
+ "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
329
+ },
330
+ )
331
+ log_completions_steps: int = field(
332
+ default=100,
333
+ metadata={
334
+ "help": "Number of steps between logging (prompt, completion) pairs. Only used if `log_completions` is "
335
+ "set to `True`."
336
+ },
337
+ )
338
+ num_completions_to_print: int | None = field(
339
+ default=None,
340
+ metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
341
+ )
342
+ wandb_entity: str | None = field(
343
+ default=None,
344
+ metadata={"help": ("The entity to store runs under.")},
345
+ )
346
+ wandb_project: str | None = field(
347
+ default=None,
348
+ metadata={"help": ("The project to store runs under.")},
349
+ )
350
+ wandb_run_group: str | None = field(
351
+ default=None,
352
+ metadata={"help": ("The group to store runs under.")},
353
+ )
354
+ wandb_log_unique_prompts: bool = field(
355
+ default=True,
356
+ metadata={
357
+ "help": ("Whether to log the unique prompts to wandb. This will create a new run for each unique prompt.")
358
+ },
359
+ )
360
+ callbacks: list[str] = field(
361
+ default_factory=lambda: [],
362
+ metadata={"help": "The callbacks to run during training."},
363
+ )
364
+ hub_model_revision: str | None = field(
365
+ default="main", metadata={"help": "The Hub model branch to push the model to."}
366
+ )
368
+ overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
369
+ push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
370
+ trl_project: str = field(
371
+ default="smollm3",
372
+ metadata={
373
+ "help": "The TRL project to use for evaluation. This is used to determine the path to the evaluation script."
374
+ },
375
+ )
376
+
377
+ def __post_init__(self):
378
+ super().__post_init__()
379
+ # check lmbda and beta are in the range [0, 1]
380
+ if self.lmbda < 0.0 or self.lmbda > 1.0:
381
+ raise ValueError("lmbda must be in the range [0.0, 1.0].")
382
+ if self.beta < 0.0 or self.beta > 1.0:
383
+ raise ValueError("beta must be in the range [0.0, 1.0].")
384
+
385
+ # Validate that max_length is sufficient for max_completion_length
386
+ if self.max_length is not None and self.max_completion_length >= self.max_length:
387
+ raise ValueError(
388
+ f"max_completion_length ({self.max_completion_length}) must be smaller than max_length ({self.max_length}) "
389
+ f"to leave room for the prompt. Consider increasing max_length or reducing max_completion_length."
390
+ )
391
+
392
+ if self.steps_per_generation is None:
393
+ self.steps_per_generation = self.gradient_accumulation_steps
394
+
395
+ # Validate ULD parameters
396
+ if self.use_uld_loss:
397
+ if self.uld_crossentropy_weight < 0.0:
398
+ raise ValueError("uld_crossentropy_weight must be non-negative.")
399
+ if self.uld_distillation_weight < 0.0:
400
+ raise ValueError("uld_distillation_weight must be non-negative.")
401
+ if self.uld_student_temperature <= 0.0:
402
+ raise ValueError("uld_student_temperature must be positive.")
403
+ if self.uld_teacher_temperature <= 0.0:
404
+ raise ValueError("uld_teacher_temperature must be positive.")
405
+
406
+ # Validate hybrid loss weights - both must be None or both must be set
407
+ if self.uld_use_hybrid_loss:
408
+ if (self.uld_hybrid_matched_weight is None) != (self.uld_hybrid_unmatched_weight is None):
409
+ raise ValueError(
410
+ "uld_hybrid_matched_weight and uld_hybrid_unmatched_weight must both be None (for adaptive "
411
+ "weighting) or both be set to numeric values. Got uld_hybrid_matched_weight="
412
+ f"{self.uld_hybrid_matched_weight} and uld_hybrid_unmatched_weight="
413
+ f"{self.uld_hybrid_unmatched_weight}."
414
+ )
415
+ if self.uld_hybrid_matched_weight is not None:
416
+ if self.uld_hybrid_matched_weight < 0.0:
417
+ raise ValueError("uld_hybrid_matched_weight must be non-negative.")
418
+ if self.uld_hybrid_unmatched_weight < 0.0:
419
+ raise ValueError("uld_hybrid_unmatched_weight must be non-negative.")
ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
16
+ from .grpo_with_replay_buffer_trainer import GRPOWithReplayBufferTrainer, ReplayBuffer
ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+
17
+ from ...trainer.grpo_config import GRPOConfig
18
+
19
+
20
+ @dataclass
21
+ class GRPOWithReplayBufferConfig(GRPOConfig):
22
+ """
23
+ New Parameters:
24
+ replay_buffer_size (`int`, *optional*, defaults to `64`):
25
+ A cache that stores the rollouts with the highest advantage scores and variance per group. If a new
26
+ group has 0 variance, it is replaced with a group sampled from the replay buffer.
27
+ """
28
+
29
+ replay_buffer_size: int = field(
30
+ default=64,
31
+ metadata={
32
+ "help": "A cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer."
33
+ },
34
+ )
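A minimal sketch of enabling the buffer; per the trainer's `__init__` in the next file, `replay_buffer_size=0` disables it and recovers plain GRPO behavior:

from trl.experimental.grpo_with_replay_buffer import (
    GRPOWithReplayBufferConfig,
    GRPOWithReplayBufferTrainer,
)

config = GRPOWithReplayBufferConfig(
    output_dir="grpo-replay",
    replay_buffer_size=64,  # 0 would turn the buffer off entirely
)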
ICL/RL/trl_source/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py ADDED
@@ -0,0 +1,731 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import heapq
16
+ from typing import Any
17
+
18
+ import torch
19
+ from accelerate.utils import gather_object
20
+
21
+ from ...data_utils import apply_chat_template, prepare_multimodal_messages
22
+ from ...models.utils import disable_gradient_checkpointing
23
+ from ...trainer.grpo_trainer import GRPOTrainer
24
+ from ...trainer.utils import nanmax, nanmin, nanstd, pad
25
+ from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
26
+
27
+
28
+ class ReplayBuffer:
29
+ """
30
+ A simple replay buffer to store and sample previously seen rollouts.
31
+ """
32
+
33
+ def __init__(self, max_size: int):
34
+ self.max_size = max_size
35
+ self.heap = []  # Min-heap of (score, tiebreaker, data) tuples
36
+
37
+ def add(self, scores: list[float], data: list[dict]):
38
+ for score, datum in zip(scores, data, strict=True):
39
+ if len(self.heap) < self.max_size:
40
+ heapq.heappush(self.heap, (score, id(datum), datum))  # id() breaks score ties so dicts are never compared
41
+ else:
42
+ # Only add if score is better than worst (minimum) item
43
+ if score > self.heap[0][0]:
44
+ heapq.heapreplace(self.heap, (score, id(datum), datum))
45
+
46
+ def sample(self, num_samples: int) -> list[dict[str, torch.Tensor]] | None:
47
+ if not self.heap:
48
+ return None
49
+
50
+ # Sample by normalized scores
51
+ scores = torch.tensor([item[0] for item in self.heap], dtype=torch.float32)
52
+ probabilities = scores / scores.sum()
53
+ replacement = False
54
+ if num_samples > len(self.heap):
55
+ replacement = True
56
+ chosen_indices = torch.multinomial(probabilities, num_samples, replacement=replacement).tolist()
57
+ return [self.heap[i][-1] for i in chosen_indices]
58
+
59
+
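# Standalone sketch of the buffer semantics above (illustrative scores and payloads):
#
#   buf = ReplayBuffer(max_size=2)
#   buf.add(scores=[1.0, 3.0, 2.0], data=[{"id": "a"}, {"id": "b"}, {"id": "c"}])
#   # The min-heap keeps the two highest-scoring entries ("b" and "c"); "a" is evicted.
#   picked = buf.sample(num_samples=2)  # drawn with probability proportional to score
#
# Sampling falls back to drawing with replacement only when more samples are
# requested than the buffer currently holds.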
60
+ class GRPOWithReplayBufferTrainer(GRPOTrainer):
61
+ def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **kwargs):
62
+ super().__init__(args=args, **kwargs)
63
+ self.replay_buffer = ReplayBuffer(args.replay_buffer_size) if args.replay_buffer_size > 0 else None
64
+
65
+ def _generate_and_score_completions(
66
+ self, inputs: list[dict[str, torch.Tensor | Any]]
67
+ ) -> dict[str, torch.Tensor | Any]:
68
+ device = self.accelerator.device
69
+ mode = "train" if self.model.training else "eval"
70
+
71
+ prompts = [x["prompt"] for x in inputs]
72
+
73
+ if "images" in inputs[0]:
74
+ images = [example.get("images") for example in inputs]
75
+ elif "image" in inputs[0]:
76
+ images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
77
+ else:
78
+ images = None
79
+ # Transformers requires at least one image in the batch, otherwise it throws an error
80
+ if images is not None and all(img_list == [] for img_list in images):
81
+ images = None
82
+
83
+ # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
84
+ # [{"role": "user", "content": "What color is the sky?"}] to
85
+ # [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
86
+ if images is not None:
87
+ prompts = [
88
+ prepare_multimodal_messages(prompt, image_list)
89
+ for prompt, image_list in zip(prompts, images, strict=True)
90
+ ]
91
+
92
+ (
93
+ prompt_ids_list,
94
+ completion_ids_list,
95
+ tool_mask_list,
96
+ completions,
97
+ num_items_in_batch,
98
+ sampling_per_token_logps_list,
99
+ extra_fields,
100
+ ) = self._generate(prompts)
101
+
102
+ # Convert lists of token IDs to padded tensors
103
+ prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
104
+ prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
105
+ prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
106
+ prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
107
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
108
+ completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
109
+ completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
110
+ completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
111
+ if sampling_per_token_logps_list is not None:
112
+ sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
113
+ sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
114
+ else:
115
+ sampling_per_token_logps = None
116
+ if self.tools:
117
+ tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list]
118
+ tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere
119
+
120
+ # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
121
+ if self.mask_truncated_completions:
122
+ eos_and_pad = [self.eos_token_id, self.pad_token_id]
123
+ is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
124
+ completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
125
+
126
+ # Concatenate prompt_mask with completion_mask for logit computation
127
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
128
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
129
+
130
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
131
+ batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
132
+
133
+ num_images = [len(img_list) for img_list in images] if images is not None else None
134
+
135
+ # Get forward_kwargs for models with multimodal inputs
136
+ if images is not None:
137
+ prompts_text = [
138
+ apply_chat_template(
139
+ {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs
140
+ )["prompt"]
141
+ for prompt in prompts
142
+ ]
143
+ prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
144
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
145
+ forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
146
+ else:
147
+ forward_kwargs = {}
148
+
149
+ # If token_type_ids are used, extend them with zeros for the completion part
150
+ if "token_type_ids" in forward_kwargs:
151
+ token_type_ids = forward_kwargs["token_type_ids"]
152
+ forward_kwargs["token_type_ids"] = torch.cat(
153
+ [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
154
+ )
155
+
156
+ # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a
157
+ # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
158
+ # Temporarily disable checkpointing to avoid this warning during inference.
159
+ with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
160
+ # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
161
+ # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
162
+ # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
163
+ # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
164
+ # old_per_token_logps to None.
165
+ # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
166
+ # distribution mismatch between vLLM and the training model can be large and harm the training.
167
+ generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
168
+ if self.args.gradient_accumulation_steps % generate_every != 0 or (
169
+ self.use_vllm and self.vllm_importance_sampling_correction
170
+ ):
171
+ old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
172
+ self.model,
173
+ prompt_completion_ids,
174
+ attention_mask,
175
+ logits_to_keep,
176
+ batch_size,
177
+ num_images=num_images,
178
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
179
+ )
180
+ else:
181
+ old_per_token_logps = None
182
+
183
+ # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
184
+ if self.use_vllm and self.vllm_importance_sampling_correction:
185
+ importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps)
186
+ importance_sampling_ratio = torch.clamp(
187
+ importance_sampling_ratio, max=self.vllm_importance_sampling_cap
188
+ )
189
+
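# Worked example (illustrative numbers): if the training model assigns a token
# log-prob of -1.0 while vLLM sampled it with log-prob -1.5, the ratio is
# exp(-1.0 - (-1.5)) = exp(0.5) ~= 1.65. Clamping at vllm_importance_sampling_cap
# bounds how strongly any single token can be upweighted by this correction.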
190
+ # Compute the per-token log probabilities for the reference model
191
+ if self.beta != 0.0:
192
+ if self.ref_model is not None:
193
+ ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
194
+ self.ref_model,
195
+ prompt_completion_ids,
196
+ attention_mask,
197
+ logits_to_keep,
198
+ batch_size=batch_size,
199
+ num_images=num_images,
200
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
201
+ )
202
+ else:
203
+ with self.accelerator.unwrap_model(self.model).disable_adapter():
204
+ ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
205
+ self.model,
206
+ prompt_completion_ids,
207
+ attention_mask,
208
+ logits_to_keep,
209
+ batch_size=batch_size,
210
+ num_images=num_images,
211
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
212
+ )
213
+ else:
214
+ ref_per_token_logps = None
215
+
216
+ # Decode
217
+ prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
218
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
219
+
220
+ # Merge extra_fields from rollout_func into inputs for reward functions
221
+ if extra_fields:
222
+ for i, inp in enumerate(inputs):
223
+ for key, values in extra_fields.items():
224
+ if isinstance(values, list) and i < len(values):
225
+ inp[key] = values[i]
226
+ elif not isinstance(values, list):
227
+ inp[key] = values
228
+
229
+ # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
230
+ # important because rewards will be normalized per group, and completions are distributed. We will later slice
231
+ # rewards_per_func to extract each process's subset.
232
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
233
+
234
+ # Apply weights to each reward function's output and sum
235
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
236
+
237
+ # Compute grouped-wise rewards
238
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
239
+
240
+ # Normalize the rewards to compute the advantages
241
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
242
+ advantages = rewards - mean_grouped_rewards
243
+
244
+ grouped_std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
245
+ grouped_std_rewards = grouped_std_rewards.repeat_interleave(self.num_generations, dim=0)
246
+
247
+ if self.scale_rewards in ["group", "none"]:
248
+ # If self.scale_rewards = "none", we'll still log group level std
249
+ std_rewards = grouped_std_rewards.clone()
250
+ elif self.scale_rewards == "batch":
251
+ # Compute global std
252
+ std_rewards = rewards.std().expand_as(rewards)
253
+ else:
254
+ raise ValueError(
255
+ f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
256
+ )
257
+
258
+ is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
259
+ if self.scale_rewards != "none":
260
+ advantages = advantages / (std_rewards + 1e-4)
261
+
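# Worked example (illustrative): group rewards [1, 0, 1, 0] give mean 0.5 and
# (unbiased) std ~= 0.577, so raw advantages of +-0.5 are scaled to roughly +-0.87.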
262
+ # Slice to keep only the local part of the data
263
+ process_slice = slice(
264
+ self.accelerator.process_index * len(prompts),
265
+ (self.accelerator.process_index + 1) * len(prompts),
266
+ )
267
+ all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
268
+ advantages = advantages[process_slice]
269
+ grouped_std_rewards = grouped_std_rewards[process_slice]
270
+
271
+ # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
272
+ for i, reward_func_name in enumerate(self.reward_func_names):
273
+ mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
274
+ self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
275
+ std_func_rewards = nanstd(rewards_per_func[:, i]).item()
276
+ self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
277
+ self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
278
+ self._metrics[mode]["reward_std"].append(std_rewards.mean().item())
279
+ self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
280
+
281
+ # Log prompt and completion texts
282
+ self._logs["prompt"].extend(gather_object(prompts_text))
283
+ self._logs["completion"].extend(gather_object(completions_text))
284
+ for i, name in enumerate(self.reward_func_names):
285
+ self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
286
+ self._logs["advantages"].extend(all_process_advantages.tolist())
287
+
288
+ if images is not None:
289
+ self._logs["images"].extend(gather_object(images))
290
+
291
+ if self.use_vllm and self.vllm_importance_sampling_correction:
292
+ delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
293
+ mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool()
294
+ delta = delta[mask]
295
+ mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
296
+ max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
297
+ self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
298
+ self.accelerator.gather(mean_delta).mean().item()
299
+ )
300
+ self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
301
+ self.accelerator.gather(max_delta).max().item()
302
+ )
303
+
304
+ flat_is_ratio = importance_sampling_ratio[mask]
305
+ min_importance_sampling_ratio = (
306
+ torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
307
+ )
308
+ mean_importance_sampling_ratio = (
309
+ torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
310
+ )
311
+ max_importance_sampling_ratio = (
312
+ torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
313
+ )
314
+ self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
315
+ nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
316
+ )
317
+ self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
318
+ self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
319
+ )
320
+ self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
321
+ nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
322
+ )
323
+ outputs_after_sampling_buffer = self.update_with_replay_buffer(
324
+ advantages,
325
+ grouped_std_rewards,
326
+ prompt_ids,
327
+ prompt_mask,
328
+ completion_ids,
329
+ completion_mask,
330
+ forward_kwargs,
331
+ num_items_in_batch,
332
+ old_per_token_logps,
333
+ ref_per_token_logps,
334
+ importance_sampling_ratio if self.use_vllm and self.vllm_importance_sampling_correction else None,
335
+ )
336
+ if outputs_after_sampling_buffer is not None:
337
+ return outputs_after_sampling_buffer
338
+ else:
339
+ output = {
340
+ "prompt_ids": prompt_ids,
341
+ "prompt_mask": prompt_mask,
342
+ "completion_ids": completion_ids,
343
+ "completion_mask": completion_mask,
344
+ "advantages": advantages,
345
+ "num_items_in_batch": num_items_in_batch,
346
+ }
347
+ if old_per_token_logps is not None:
348
+ output["old_per_token_logps"] = old_per_token_logps
349
+ if self.use_vllm and self.vllm_importance_sampling_correction:
350
+ output["importance_sampling_ratio"] = importance_sampling_ratio
351
+ if ref_per_token_logps is not None:
352
+ output["ref_per_token_logps"] = ref_per_token_logps
353
+ if "pixel_values" in forward_kwargs:
354
+ output["pixel_values"] = forward_kwargs["pixel_values"]
355
+ if "image_grid_thw" in forward_kwargs:
356
+ output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
357
+ if "pixel_attention_mask" in forward_kwargs:
358
+ output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
359
+ if "image_sizes" in forward_kwargs:
360
+ output["image_sizes"] = forward_kwargs["image_sizes"]
361
+ if "token_type_ids" in forward_kwargs:
362
+ output["token_type_ids"] = forward_kwargs["token_type_ids"]
363
+ if images is not None:
364
+ output["num_images"] = num_images
365
+ if self.tools:
366
+ output["tool_mask"] = tool_mask
367
+ return output
368
+
369
+ def slice_group_data(
370
+ self, data: torch.Tensor, mask: torch.Tensor, group_idx: int
371
+ ) -> tuple[torch.Tensor, torch.Tensor]:
372
+ """
373
+ Slices the input data and mask tensors for a specific group index. Also trims the sequence length to the
374
+ maximum length in the group based on the mask.
375
+
376
+ Args:
377
+ data: Tensor of shape (num_groups * num_generations, seq_length)
378
+ mask: Tensor of shape (num_groups * num_generations, seq_length)
379
+ group_idx: Index of the group to slice
380
+ Returns:
381
+ Tuple of (sliced_data, sliced_mask) for the specified group, with sequence length trimmed to the maximum
382
+ length in the group.
383
+ """
384
+ start_idx = group_idx * self.num_generations
385
+ end_idx = (group_idx + 1) * self.num_generations
386
+ group_data = data[start_idx:end_idx]
387
+ group_mask = mask[start_idx:end_idx]
388
+ group_max_len = group_mask.sum(dim=1).max().item()
389
+ return group_data[:, :group_max_len], group_mask[:, :group_max_len]
390
+
391
+ def update_replay_buffer(
392
+ self,
393
+ groups_with_variance: torch.Tensor,
394
+ group_advantages: torch.Tensor,
395
+ group_std_rewards: torch.Tensor,
396
+ prompt_ids: torch.Tensor,
397
+ prompt_mask: torch.Tensor,
398
+ completion_ids: torch.Tensor,
399
+ completion_mask: torch.Tensor,
400
+ forward_kwargs: dict,
401
+ optional_vision_fields: list[str] = None,
402
+ old_per_token_logps: torch.Tensor | None = None,
403
+ ref_per_token_logps: torch.Tensor | None = None,
404
+ importance_sampling_ratio: float | None = None,
405
+ ) -> None:
406
+ """
407
+ Update the replay buffer with groups that have reward variance (std > 0).
408
+
409
+ Args:
410
+ groups_with_variance: Boolean tensor indicating which groups have reward variance
411
+ group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values
412
+ group_std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group
413
+ prompt_ids: Tensor containing prompt token IDs
414
+ prompt_mask: Tensor containing prompt attention masks
415
+ completion_ids: Tensor containing completion token IDs
416
+ completion_mask: Tensor containing completion attention masks
417
+ forward_kwargs: Dictionary containing additional prompt inputs (vision data, etc.)
418
+ optional_vision_fields: List of optional vision-related fields to include if present in forward_kwargs
419
+ old_per_token_logps: Optional tensor of old per-token log probabilities
420
+ ref_per_token_logps: Optional tensor of reference per-token log probabilities
421
+ importance_sampling_ratio: Optional importance sampling correction ratio
422
+ """
423
+ # Prepare buffered outputs for groups with variance
424
+ buffered_outputs = []
425
+ for group_idx in groups_with_variance.nonzero(as_tuple=True)[0].unique().tolist():
426
+ group_prompt_ids, group_prompt_mask = self.slice_group_data(prompt_ids, prompt_mask, group_idx)
427
+ group_completion_ids, group_completion_mask = self.slice_group_data(
428
+ completion_ids, completion_mask, group_idx
429
+ )
430
+
431
+ # Store unpadded data in the buffer
432
+ buffered_output = {
433
+ "prompt_ids": group_prompt_ids,
434
+ "completion_ids": group_completion_ids,
435
+ "advantages": group_advantages[group_idx].tolist(),
436
+ "prompt_mask": group_prompt_mask,
437
+ "completion_mask": group_completion_mask,
438
+ }
439
+
440
+ # Add optional fields if they exist
441
+ optional_fields = {
442
+ "old_per_token_logps": old_per_token_logps if old_per_token_logps is not None else None,
443
+ "ref_per_token_logps": ref_per_token_logps if ref_per_token_logps is not None else None,
444
+ }
445
+
446
+ for field_name, field_data in optional_fields.items():
447
+ if field_data is not None:
448
+ buffered_output[field_name] = self.slice_group_data(field_data, completion_mask, group_idx)[0]
449
+
450
+ # Add importance sampling if needed
451
+ if self.use_vllm and self.vllm_importance_sampling_correction:
452
+ buffered_output["importance_sampling_ratio"] = importance_sampling_ratio
453
+
454
+ if optional_vision_fields:
455
+ # Add vision-related fields if they exist
456
+ for field_name in optional_vision_fields:
457
+ if field_name in forward_kwargs:
458
+ buffered_output[field_name] = self.slice_group_data(
459
+ forward_kwargs[field_name], prompt_mask, group_idx
460
+ )[0]
461
+
462
+ buffered_outputs.append(buffered_output)
463
+
464
+ if groups_with_variance.any():
465
+ # Calculate replay buffer scores for groups with variance
466
+ replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
467
+ # Add all groups to replay buffer at once (batch operation)
468
+ self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs)
469
+
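The buffer priority used by `update_replay_buffer` is worth seeing in isolation: each group scores the sum over its generations of `|advantage| * reward std`, so groups whose rewards all tie score zero. A minimal sketch with made-up numbers:

```python
import torch

# Priority used when adding groups to the buffer: groups whose rewards all tie
# (std == 0) carry no learning signal and score zero.
group_advantages = torch.tensor([[1.2, -1.2], [0.0, 0.0]])   # (num_groups, num_generations)
group_std_rewards = torch.tensor([[0.5, 0.5], [0.0, 0.0]])

scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)
print(scores.tolist())  # [1.2, 0.0]
```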
470
+ def sample_from_replay_buffer(
471
+ self, num_samples: int, optional_vision_fields: list[str] | None = None, optional_tensor_fields: list[str] | None = None
472
+ ) -> dict[str, list]:
473
+ """
474
+ Sample groups from the replay buffer.
475
+
476
+ Args:
477
+ num_samples: Number of samples to draw from the replay buffer
478
+ optional_vision_fields: List of optional vision-related fields to include if present in sampled data
479
+ optional_tensor_fields: List of optional tensor fields to include if present in sampled data
480
+ Returns:
481
+ Dictionary mapping each field name to a list of per-group data sampled from the replay buffer
482
+ """
483
+ sampled = self.replay_buffer.sample(num_samples=num_samples)
484
+
485
+ # Extract and concatenate sampled data
486
+ sampled_data = {
487
+ "prompt_ids": [],
488
+ "prompt_mask": [],
489
+ "completion_ids": [],
490
+ "completion_mask": [],
491
+ "advantages": [],
492
+ }
493
+
494
+ all_optional_fields = (optional_tensor_fields or []) + (optional_vision_fields or [])
495
+ # Initialize containers for optional fields if they exist in sampled data
496
+ for field in all_optional_fields:
497
+ if sampled and field in sampled[0]:
498
+ sampled_data[field] = []
499
+
500
+ # Extract data from each sampled item
501
+ for item in sampled:
502
+ # Handle core fields
503
+ for key in ["prompt_ids", "prompt_mask", "completion_ids", "completion_mask"]:
504
+ sampled_data[key].append(item[key])
505
+
506
+ # Handle advantages (list, not tensor)
507
+ sampled_data["advantages"].append(item["advantages"])
508
+
509
+ # Handle optional fields
510
+ for field in all_optional_fields:
511
+ if field in item:
512
+ sampled_data[field].append(item[field])
513
+
514
+ return sampled_data
515
+
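Schematically, `sample_from_replay_buffer` converts the buffer's list of per-group dicts into one dict of per-field lists. A sketch with string stand-ins for the tensors (the shapes shown are illustrative):

```python
# Each buffer entry is one group; the output groups values by field name.
sampled = [
    {"prompt_ids": "<tensor (2, 7)>", "completion_ids": "<tensor (2, 5)>", "advantages": [0.8, -0.8]},
    {"prompt_ids": "<tensor (2, 9)>", "completion_ids": "<tensor (2, 3)>", "advantages": [1.1, -1.1]},
]
sampled_data = {key: [item[key] for item in sampled] for key in sampled[0]}
print(sampled_data["advantages"])  # [[0.8, -0.8], [1.1, -1.1]]
```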
516
+ def update_with_replay_buffer(
517
+ self,
518
+ group_advantages: torch.Tensor,
519
+ group_std_rewards: torch.Tensor,
520
+ prompt_ids: torch.Tensor,
521
+ prompt_mask: torch.Tensor,
522
+ completion_ids: torch.Tensor,
523
+ completion_mask: torch.Tensor,
524
+ forward_kwargs: dict,
525
+ num_items_in_batch: int,
526
+ old_per_token_logps: torch.Tensor | None = None,
527
+ ref_per_token_logps: torch.Tensor | None = None,
528
+ importance_sampling_ratio: float | None = None,
529
+ ) -> dict | None:
530
+ """
531
+ Update current batch data with samples from replay buffer.
532
+
533
+ Groups with reward variance (std > 0) are added to the replay buffer, while groups whose rewards all tie
534
+ (std == 0) are replaced with samples drawn from the buffer to improve training stability.
535
+
536
+ Args:
537
+ group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values
538
+ group_std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group
539
+ prompt_ids: Tensor containing prompt token IDs
540
+ prompt_mask: Tensor containing prompt attention masks
541
+ completion_ids: Tensor containing completion token IDs
542
+ completion_mask: Tensor containing completion attention masks
543
+ forward_kwargs: Dictionary containing additional prompt inputs (vision data, etc.)
544
+ num_items_in_batch: Number of items in the current batch
545
+ old_per_token_logps: Optional tensor of old per-token log probabilities
546
+ ref_per_token_logps: Optional tensor of reference per-token log probabilities
547
+ importance_sampling_ratio: Optional importance sampling correction ratio
548
+
+ Returns:
+ Updated batch dictionary after buffer mixing, or `None` if the replay buffer is disabled or no group
+ qualifies for replacement.
+ """
549
+ if self.replay_buffer.max_size <= 0:
550
+ return
551
+
552
+ # Groups to consider for adding to the replay buffer
553
+ groups_with_variance = group_std_rewards.max(dim=0).values > 0
554
+ # Groups to replace from the replay buffer
555
+ groups_without_variance = ~groups_with_variance
556
+
557
+ # Track which optional fields are present in sampled data
558
+ optional_tensor_fields = ["old_per_token_logps", "ref_per_token_logps"]
559
+ vision_fields = ["pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes"]
560
+
561
+ self.update_replay_buffer(
562
+ groups_with_variance,
563
+ group_advantages,
564
+ group_std_rewards,
565
+ prompt_ids,
566
+ prompt_mask,
567
+ completion_ids,
568
+ completion_mask,
569
+ forward_kwargs,
570
+ vision_fields,
571
+ old_per_token_logps,
572
+ ref_per_token_logps,
573
+ importance_sampling_ratio,
574
+ )
575
+
576
+ # Sample from replay buffer to replace groups with variance
577
+ num_groups_to_replace = groups_without_variance.sum().item()
578
+ if not num_groups_to_replace:
579
+ return
580
+
581
+ sampled_data = self.sample_from_replay_buffer(
582
+ num_samples=num_groups_to_replace,
583
+ optional_vision_fields=vision_fields,
584
+ optional_tensor_fields=optional_tensor_fields,
585
+ )
586
+
587
+ # Pad sampled data if they are shorter than the current batch sequences
588
+ # Or pad the current batch if sampled are longer
589
+ current_batch_prompt_seq_len = prompt_ids.size(1)
590
+ current_batch_completion_seq_len = completion_ids.size(1)
591
+
592
+ groups_to_replace_idxs = groups_without_variance.nonzero(as_tuple=True)[0].unique().tolist()
593
+
594
+ # Determine target (max) sequence lengths once
595
+ sampled_prompt_lengths = [t.size(1) for t in sampled_data["prompt_ids"]]
596
+ sampled_completion_lengths = [t.size(1) for t in sampled_data["completion_ids"]]
597
+ target_prompt_len = max([current_batch_prompt_seq_len] + sampled_prompt_lengths)
598
+ target_completion_len = max([current_batch_completion_seq_len] + sampled_completion_lengths)
599
+
600
+ # If any sampled prompt is longer, pad the whole batch prompt tensors once (left padding)
601
+ if target_prompt_len > current_batch_prompt_seq_len:
602
+ prompt_ids = pad(
603
+ list(prompt_ids.unbind(0)),
604
+ padding_value=self.pad_token_id,
605
+ pad_to_multiple_of=target_prompt_len,
606
+ padding_side="left",
607
+ )
608
+ prompt_mask = pad(
609
+ list(prompt_mask.unbind(0)), padding_value=0, pad_to_multiple_of=target_prompt_len, padding_side="left"
610
+ )
611
+ # If any sampled completion is longer, pad the whole batch completion tensors once (right padding)
612
+ if target_completion_len > current_batch_completion_seq_len:
613
+ completion_ids = pad(
614
+ list(completion_ids.unbind(0)),
615
+ padding_value=self.pad_token_id,
616
+ pad_to_multiple_of=target_completion_len,
617
+ padding_side="right",
618
+ )
619
+ completion_mask = pad(
620
+ list(completion_mask.unbind(0)),
621
+ padding_value=0,
622
+ pad_to_multiple_of=target_completion_len,
623
+ padding_side="right",
624
+ )
625
+ if old_per_token_logps is not None:
626
+ old_per_token_logps = pad(
627
+ list(old_per_token_logps.unbind(0)),
628
+ padding_value=0.0,
629
+ pad_to_multiple_of=target_completion_len,
630
+ padding_side="right",
631
+ )
632
+ if ref_per_token_logps is not None:
633
+ ref_per_token_logps = pad(
634
+ list(ref_per_token_logps.unbind(0)),
635
+ padding_value=0.0,
636
+ pad_to_multiple_of=target_completion_len,
637
+ padding_side="right",
638
+ )
639
+
640
+ # Replace per-group data, padding only sampled groups that are shorter than the target
641
+ for i, group_idx in enumerate(groups_to_replace_idxs):
642
+ start_idx = group_idx * self.num_generations
643
+ end_idx = (group_idx + 1) * self.num_generations
644
+ idx_range = slice(start_idx, end_idx)
645
+
646
+ # Pad sampled prompt to target length if needed
647
+ if sampled_data["prompt_ids"][i].size(1) < target_prompt_len:
648
+ sampled_data["prompt_ids"][i] = pad(
649
+ sampled_data["prompt_ids"][i],
650
+ padding_value=self.pad_token_id,
651
+ pad_to_multiple_of=target_prompt_len,
652
+ padding_side="left",
653
+ )
654
+ sampled_data["prompt_mask"][i] = pad(
655
+ sampled_data["prompt_mask"][i],
656
+ padding_value=0,
657
+ pad_to_multiple_of=target_prompt_len,
658
+ padding_side="left",
659
+ )
660
+
661
+ # Pad sampled completion to target length if needed
662
+ if sampled_data["completion_ids"][i].size(1) < target_completion_len:
663
+ sampled_data["completion_ids"][i] = pad(
664
+ sampled_data["completion_ids"][i],
665
+ padding_value=self.pad_token_id,
666
+ pad_to_multiple_of=target_completion_len,
667
+ padding_side="right",
668
+ )
669
+ sampled_data["completion_mask"][i] = pad(
670
+ sampled_data["completion_mask"][i],
671
+ padding_value=0,
672
+ pad_to_multiple_of=target_completion_len,
673
+ padding_side="right",
674
+ )
675
+ if "old_per_token_logps" in sampled_data:
676
+ sampled_data["old_per_token_logps"][i] = pad(
677
+ sampled_data["old_per_token_logps"][i],
678
+ padding_value=0.0,
679
+ pad_to_multiple_of=target_completion_len,
680
+ padding_side="right",
681
+ )
682
+ if "ref_per_token_logps" in sampled_data:
683
+ sampled_data["ref_per_token_logps"][i] = pad(
684
+ sampled_data["ref_per_token_logps"][i],
685
+ padding_value=0.0,
686
+ pad_to_multiple_of=target_completion_len,
687
+ padding_side="right",
688
+ )
689
+
690
+ # Assign (replace) group slice
691
+ prompt_ids[idx_range] = sampled_data["prompt_ids"][i]
692
+ prompt_mask[idx_range] = sampled_data["prompt_mask"][i]
693
+ completion_ids[idx_range] = sampled_data["completion_ids"][i]
694
+ completion_mask[idx_range] = sampled_data["completion_mask"][i]
695
+ group_advantages[group_idx] = sampled_data["advantages"][i]
696
+
697
+ if "old_per_token_logps" in sampled_data:
698
+ old_per_token_logps[idx_range] = sampled_data["old_per_token_logps"][i]
699
+ if "ref_per_token_logps" in sampled_data:
700
+ ref_per_token_logps[idx_range] = sampled_data["ref_per_token_logps"][i]
701
+
702
+ for field in vision_fields:
703
+ if field in sampled_data and field in forward_kwargs:
704
+ forward_kwargs[field][idx_range] = sampled_data[field][i]
705
+
706
+ # Prepare final outputs after sampling and replacement
707
+ outputs_after_sampling_buffer = {
708
+ "prompt_ids": prompt_ids,
709
+ "prompt_mask": prompt_mask,
710
+ "completion_ids": completion_ids,
711
+ "completion_mask": completion_mask,
712
+ "advantages": group_advantages,
713
+ }
714
+
715
+ # Replace optional tensor fields if they exist
716
+ for field in optional_tensor_fields:
717
+ if field in sampled_data:
718
+ outputs_after_sampling_buffer[field] = (
719
+ old_per_token_logps if field == "old_per_token_logps" else ref_per_token_logps
720
+ )
721
+
722
+ # Replace vision fields if they exist
723
+ for field in vision_fields:
724
+ if field in sampled_data and field in forward_kwargs:
725
+ outputs_after_sampling_buffer[field] = forward_kwargs[field]
726
+
727
+ outputs_after_sampling_buffer["num_items_in_batch"] = num_items_in_batch
728
+ if self.use_vllm and self.vllm_importance_sampling_correction:
729
+ outputs_after_sampling_buffer["importance_sampling_ratio"] = importance_sampling_ratio
730
+
731
+ return outputs_after_sampling_buffer
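The padding sides chosen above matter: prompts are left-padded and completions right-padded so that, after concatenation, the prompt tail stays adjacent to the first completion token. A minimal sketch using plain `torch.nn.functional.pad` (the trainer uses its own `pad` helper; values here are illustrative):

```python
import torch
import torch.nn.functional as F

# Left-pad the prompt, right-pad the completion: after concatenation the
# position of the first generated token is preserved.
prompt = torch.tensor([[5, 6, 7]])
completion = torch.tensor([[8, 9]])
pad_id = 0

prompt_padded = F.pad(prompt, (2, 0), value=pad_id)          # left-pad to length 5
completion_padded = F.pad(completion, (0, 3), value=pad_id)  # right-pad to length 5
print(torch.cat([prompt_padded, completion_padded], dim=1).tolist())
# [[0, 0, 5, 6, 7, 8, 9, 0, 0, 0]]
```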
ICL/RL/trl_source/trl/experimental/gspo_token/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .grpo_trainer import GRPOTrainer
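Downstream, this is intended as a drop-in swap for the standard trainer import. A usage sketch assuming an installed `trl` with this experimental module; the config value shown is the extra level this subclass handles:

```python
from trl import GRPOConfig
from trl.experimental.gspo_token import GRPOTrainer  # overrides only the loss

config = GRPOConfig(output_dir="out", importance_sampling_level="sequence_token")
```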
ICL/RL/trl_source/trl/experimental/gspo_token/grpo_trainer.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
18
+ from ...trainer.utils import nanmax, nanmin
19
+
20
+
21
+ class GRPOTrainer(_GRPOTrainer):
22
+ def _compute_loss(self, model, inputs):
23
+ # Compute the per-token log probabilities for the model
24
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
25
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
26
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
27
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
28
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
29
+
30
+ # Compute the per_token_logps and the entropy at each position in the completion
31
+ per_token_logps, entropies = self._get_per_token_logps_and_entropies(
32
+ model,
33
+ input_ids,
34
+ attention_mask,
35
+ logits_to_keep,
36
+ compute_entropy=True,
37
+ pixel_values=inputs.get("pixel_values"),
38
+ image_grid_thw=inputs.get("image_grid_thw"),
39
+ num_images=inputs.get("num_images"),
40
+ pixel_attention_mask=inputs.get("pixel_attention_mask"),
41
+ image_sizes=inputs.get("image_sizes"),
42
+ token_type_ids=inputs.get("token_type_ids"),
43
+ )
44
+
45
+ if self.top_entropy_quantile < 1.0:
46
+ entropy_mask = self.get_high_entropy_mask(entropies, completion_mask, 1 - self.top_entropy_quantile)
47
+ else:
48
+ entropy_mask = None
49
+
50
+ # Compute the KL divergence between the model and the reference model
51
+ if self.beta != 0.0:
52
+ ref_per_token_logps = inputs["ref_per_token_logps"]
53
+ per_token_kl = (
54
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
55
+ )
56
+
57
+ # Compute the loss
58
+ advantages = inputs["advantages"]
59
+ # When num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps,
60
+ # old_per_token_logps == per_token_logps. In this case we can skip its computation
61
+ # (see _generate_and_score_completions) and instead use per_token_logps.detach().
62
+ # The exception is when using vLLM, where we always compute old_per_token_logps
63
+ # for importance sampling
64
+ old_per_token_logps = inputs.get("old_per_token_logps")
65
+ old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps
66
+
67
+ log_ratio = per_token_logps - old_per_token_logps
68
+ if self.importance_sampling_level == "token":
69
+ log_importance_weights = log_ratio
70
+ elif self.importance_sampling_level == "sequence":
71
+ log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
72
+ log_importance_weights = log_importance_weights.unsqueeze(-1)
73
+ elif self.importance_sampling_level == "sequence_token":
74
+ # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)]
75
+ seq_level_log_weight = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
76
+ seq_level_log_weight = seq_level_log_weight.detach().unsqueeze(-1) # Stop gradient
77
+ log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
78
+ else:
79
+ raise ValueError(
80
+ f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
81
+ "and 'sequence'."
82
+ )
83
+ # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
84
+ # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
85
+
86
+ coef_1 = torch.exp(log_importance_weights)
87
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
88
+
89
+ # Two-sided clipping
90
+ if self.args.delta is not None:
91
+ coef_1 = torch.clamp(coef_1, max=self.args.delta)
92
+
93
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
94
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
95
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
96
+ if entropy_mask is not None:
97
+ per_token_loss = per_token_loss * entropy_mask
98
+
99
+ if self.use_vllm and self.vllm_importance_sampling_correction:
100
+ per_token_loss = per_token_loss * inputs["importance_sampling_ratio"]
101
+
102
+ if self.beta != 0.0:
103
+ per_token_loss = per_token_loss + self.beta * per_token_kl
104
+
105
+ mode = "train" if self.model.training else "eval"
106
+ if self.loss_type == "grpo":
107
+ loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
108
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
109
+ loss = loss / normalizer
110
+ elif self.loss_type == "bnpo":
111
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
112
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
113
+ loss = loss / normalizer
114
+ elif self.loss_type == "dr_grpo":
115
+ loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
116
+ normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval
117
+ loss = loss / normalizer
118
+ elif self.loss_type == "dapo":
119
+ normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
120
+ loss = (per_token_loss * completion_mask).sum() / normalizer
121
+ else:
122
+ raise ValueError(f"Unknown loss type: {self.loss_type}")
123
+
124
+ # Log the metrics
125
+ completion_token_count = completion_mask.sum().clamp(min=1.0)
126
+
127
+ def masked_batch_mean(x):
128
+ if x.shape[1] == 1: # when importance_sampling_level == "sequence"
129
+ return x.mean()
130
+ else:
131
+ return (x * completion_mask).sum() / completion_token_count
132
+
133
+ if self.beta != 0.0:
134
+ mean_kl = masked_batch_mean(per_token_kl)
135
+ self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
136
+
137
+ mean_entropy = masked_batch_mean(entropies)
138
+ self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
139
+
140
+ # Compute the clipped probability ratios
141
+ is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
142
+ is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
143
+ is_region_clipped = is_low_clipped | is_high_clipped
144
+
145
+ low_clip = masked_batch_mean(is_low_clipped.float())
146
+ high_clip = masked_batch_mean(is_high_clipped.float())
147
+ clip_ratio = masked_batch_mean(is_region_clipped.float())
148
+
149
+ gathered_low_clip = self.accelerator.gather(low_clip)
150
+ self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
151
+ self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
152
+ gathered_high_clip = self.accelerator.gather(high_clip)
153
+ self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
154
+ self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
155
+ gathered_clip_ratio = self.accelerator.gather(clip_ratio)
156
+ self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
157
+ return loss
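The `sequence_token` branch is the core of this file; a small sketch of the stop-gradient identity it relies on (random toy tensors, not trainer state): the *value* of the weight equals the sequence-level ratio, while the *gradient* flows per token.

```python
import torch

# log_w = (log pi - sg[log pi]) + sg[seq-level log ratio]: zero in value for the
# first term, so exp(log_w) numerically equals the sequence-level ratio.
per_token_logps = torch.randn(2, 4, requires_grad=True)
old_per_token_logps = per_token_logps.detach() + 0.1 * torch.randn(2, 4)
mask = torch.ones(2, 4)

log_ratio = per_token_logps - old_per_token_logps
seq_log_w = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
log_w = per_token_logps - per_token_logps.detach() + seq_log_w.detach().unsqueeze(-1)
coef = torch.exp(log_w)
print(torch.allclose(coef, torch.exp(seq_log_w).unsqueeze(-1).expand_as(coef)))  # True
```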
ICL/RL/trl_source/trl/experimental/judges/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .judges import (
16
+ AllTrueJudge,
17
+ BaseBinaryJudge,
18
+ BaseJudge,
19
+ BasePairwiseJudge,
20
+ BaseRankJudge,
21
+ HfPairwiseJudge,
22
+ OpenAIPairwiseJudge,
23
+ PairRMJudge,
24
+ )
25
+
26
+
27
+ __all__ = [
28
+ "AllTrueJudge",
29
+ "BaseBinaryJudge",
30
+ "BaseJudge",
31
+ "BasePairwiseJudge",
32
+ "BaseRankJudge",
33
+ "HfPairwiseJudge",
34
+ "OpenAIPairwiseJudge",
35
+ "PairRMJudge",
36
+ ]
ICL/RL/trl_source/trl/experimental/judges/judges.py ADDED
@@ -0,0 +1,482 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import concurrent.futures
16
+ import logging
17
+ from abc import ABC, abstractmethod
18
+
19
+ import numpy as np
20
+ from accelerate import Accelerator
21
+ from huggingface_hub import InferenceClient
22
+ from packaging.version import Version
23
+ from transformers.utils import is_openai_available
24
+
25
+ from ...import_utils import is_llm_blender_available
26
+
27
+
28
+ DEFAULT_PAIRWISE_SYSTEM_PROMPT = '''I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.
29
+
30
+ ## Instruction
31
+
32
+ {{
33
+ "instruction": """{prompt}""",
34
+ }}
35
+
36
+ ## Model Outputs
37
+
38
+ Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.
39
+
40
+ {{
41
+ {{
42
+ "model_identifier": "0",
43
+ "output": """{response0}"""
44
+ }},
45
+ {{
46
+ "model_identifier": "1",
47
+ "output": """{response1}"""
48
+ }}
49
+ }}
50
+
51
+ ## Task
52
+
53
+ Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).
54
+ '''
55
+
56
+
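Note the doubled braces in the template above: they escape `str.format`, so the JSON-style braces survive while only `{prompt}`, `{response0}`, and `{response1}` are substituted. A minimal check:

```python
# Doubled braces render as literal braces; single-brace fields are filled in.
template = '{{"instruction": """{prompt}"""}}'
print(template.format(prompt="The capital of France is"))
# {"instruction": """The capital of France is"""}
```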
57
+ def _ensure_llm_blender_importable() -> None:
58
+ """
59
+ Pre-import shim to work around a known `llm-blender` issue.
60
+
61
+ As of `llm-blender` v0.0.2 (see upstream issue: https://github.com/yuchenlin/LLM-Blender/issues/33), importing
62
+ `llm_blender` may fail on `transformers` >= 5.0.0 because it unconditionally accesses
63
+ `transformers.utils.hub.TRANSFORMERS_CACHE`.
64
+
65
+ We set this attribute to a dummy value before importing `llm_blender` so that the import succeeds. This helper is
66
+ intentionally a no-op on older `transformers` versions.
67
+
68
+ This shim can be removed once the upstream issue is fixed and the minimum required `llm-blender` version includes
69
+ that fix.
70
+ """
71
+ import transformers.utils.hub
72
+
73
+ if Version(transformers.__version__) >= Version("5.0.0"):
74
+ transformers.utils.hub.TRANSFORMERS_CACHE = None # unused; just needs to exist
75
+
76
+
77
+ class BaseJudge(ABC):
78
+ """
79
+ Base class for judges. The subclasses of this class should implement the `judge` method.
80
+ """
81
+
82
+ @abstractmethod
83
+ def judge(self, prompts: list[str], completions: list[str], shuffle_order: bool = True) -> list:
84
+ raise NotImplementedError("Judge subclasses must implement the `judge` method.")
85
+
86
+
87
+ class BaseRankJudge(ABC):
88
+ """
89
+ Base class for LLM ranking judges.
90
+
91
+ **Example**:
92
+ ```python
93
+ class MyRankJudge(BaseRankJudge):
94
+ def judge(self, prompts, completions, shuffle_order=True):
95
+ return ... # Your ranking logic here
96
+
97
+
98
+ judge = MyRankJudge()
99
+ judge.judge(
100
+ prompts=["The capital of France is", "The capital of Germany is"],
101
+ completions=[[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]],
102
+ ) # [[0, 1, 2], [1, 0]]
103
+ ```
104
+ """
105
+
106
+ @abstractmethod
107
+ def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[list[int]]:
108
+ """
109
+ Judge the completion for the given prompts and return the ranks of each completion.
110
+
111
+ Args:
112
+ prompts (`list[str]`):
113
+ List of prompts.
114
+ completions (`list[list[str]]`):
115
+ List of completions list, where each element is a list of completions for the corresponding prompt.
116
+ shuffle_order (`bool`, *optional*, defaults to `True`):
117
+ Whether to shuffle the order of the completions to avoid positional bias.
118
+
119
+ Returns:
120
+ `list[list[int]]`:
121
+ List of lists of idxs, where each list contains the ranks of the completions for the corresponding
122
+ prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the
123
+ third, and then the first.
124
+ """
125
+ raise NotImplementedError("Judge subclasses must implement the `judge` method.")
126
+
127
+
128
+ class BasePairwiseJudge(BaseJudge):
129
+ """
130
+ Base class for pairwise judges.
131
+ """
132
+
133
+ @abstractmethod
134
+ def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
135
+ """
136
+ Judge the completion pairs for the given prompts.
137
+
138
+ Args:
139
+ prompts (`list[str]`):
140
+ List of prompts.
141
+ completions (`list[list[str]]`):
142
+ List of completions pairs, where each element is a pair of completions for the corresponding prompt.
143
+ shuffle_order (`bool`, *optional*, defaults to `True`):
144
+ Whether to shuffle the order of the completions to avoid positional bias.
145
+
146
+ Returns:
147
+ `list[int]`:
148
+ List of idxs, where each idx is the rank of the best completion for the corresponding prompt. E.g., `1`
149
+ means that the second completion (`idx=1`) is the best.
150
+
151
+ Note:
152
+ If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the
153
+ preference has failed. For instance, this could occur if the underlying language model returned an invalid
154
+ answer. In such cases, the caller should handle these invalid indices appropriately, possibly by
155
+ implementing fallback logic or error handling.
156
+ """
157
+ raise NotImplementedError("Judge subclasses must implement the `judge` method.")
158
+
159
+
160
+ class BaseBinaryJudge(BaseJudge):
161
+ """
162
+ Base class for binary judges.
163
+ """
164
+
165
+ @abstractmethod
166
+ def judge(
167
+ self,
168
+ prompts: list[str],
169
+ completions: list[str],
170
+ gold_completions: list[str] | None = None,
171
+ shuffle_order: bool = True,
172
+ ) -> list[int]:
173
+ """
174
+ Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.
175
+
176
+ This base class should be used to implement binary evaluations as done in section 4.1.4 of the [CGPO
177
+ paper](https://huggingface.co/papers/2409.20370). It is relevant for assessing whether a prompt-completion pair
178
+ satisfies a specific constraint.
179
+
180
+ Args:
181
+ prompts (`list[str]`): List of prompts.
182
+ completions (`list[str]`): List of completions.
183
+ gold_completions (`list[str]`, *optional*): List of gold completions if they exist.
184
+ shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.
185
+
186
+ Returns:
187
+ list[int]: A list of binary labels:
188
+ - 1 indicates that the completion satisfies the evaluated constraint.
189
+ - 0 indicates that the completion does not satisfy the evaluated constraint.
190
+
191
+ Note:
192
+ If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference
193
+ has failed. For instance, this could occur if the underlying language model or rule based constraint
194
+ returned an invalid answer. In such cases, the caller should handle these invalid indices appropriately,
195
+ possibly by implementing fallback logic or error handling.
196
+ """
197
+ raise NotImplementedError("Judge subclasses must implement the `judge` method.")
198
+
199
+
200
+ class PairRMJudge(BasePairwiseJudge):
201
+ # docstyle-ignore
202
+ """
203
+ LLM judge based on the PairRM model from AllenAI.
204
+
205
+ This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise
206
+ comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the
207
+ default Accelerator device.
208
+
209
+ **Attributes**:
210
+
211
+ blender (`llm_blender.Blender`):
212
+ An instance of the Blender class from llm-blender.
213
+
214
+ **Example**:
215
+ ```python
216
+ >>> pairrm_judge = PairRMJudge()
217
+ >>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
218
+ >>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
219
+ >>> results = pairrm_judge.judge(prompts, completions)
220
+ >>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second for the second)
221
+ ```
222
+
223
+ > [!TIP]
224
+ > This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`.
225
+ """
226
+
227
+ def __init__(self):
228
+ if not is_llm_blender_available():
229
+ raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.")
230
+ import transformers
231
+
232
+ if Version(transformers.__version__) >= Version("5.0.0"):
233
+ raise RuntimeError(
234
+ "llm-blender currently supports transformers < 5.0.0. Please install a compatible version: `pip install 'transformers<5.0.0'`. Check the issue tracker for updates: https://github.com/huggingface/trl/issues/4918"
235
+ )
236
+ _ensure_llm_blender_importable()
237
+ import llm_blender
238
+
239
+ self.blender = llm_blender.Blender()
240
+ self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device)
241
+
242
+ def judge(
243
+ self,
244
+ prompts: list[str],
245
+ completions: list[list[str]],
246
+ shuffle_order: bool = True,
247
+ return_scores: bool = False,
248
+ temperature: float = 1.0,
249
+ ) -> list[int | float]:
250
+ """
251
+ Judge the completion pairs for the given prompts using the PairRM model.
252
+
253
+ Args:
254
+ prompts (`list[str]`):
255
+ List of prompts to judge.
256
+ completions (`list[list[str]]`):
257
+ List of completion pairs for each prompt.
258
+ shuffle_order (`bool`, *optional*, defaults to `True`):
259
+ Whether to shuffle the order of the completions to avoid positional bias.
260
+ return_scores (`bool`, *optional*, defaults to `False`):
261
+ If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*).
262
+ temperature (`float`, *optional*, defaults to `1.0`):
263
+ Temperature for scaling logits if `return_scores` is True.
264
+
265
+ Returns:
266
+ `list[int | float]`:
267
+ If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which
268
+ completion is preferred. If `return_scores` is `True`, returns softmax probabilities for the first
269
+ completion.
270
+
271
+ Raises:
272
+ `ValueError`:
273
+ If the number of completions per prompt is not exactly 2.
274
+
275
+ Note:
276
+ Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred).
277
+ """
278
+
279
+ if len(completions[0]) != 2:
280
+ raise ValueError("PairRM judge requires exactly 2 completions per prompt.")
281
+
282
+ # Shuffle the order of the completions to avoid positional bias
283
+ if shuffle_order:
284
+ flip_mask = np.random.choice([True, False], size=len(prompts))
285
+ completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
286
+
287
+ # Rank the completions
288
+ ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True)
289
+ if not return_scores:
290
+ ranks -= 1 # PairRM rank is 1-indexed, so we subtract 1 to make it 0-indexed
291
+ else:
292
+ # scale the logits by temperature
293
+ ranks /= temperature
294
+
295
+ # Flip back the ranks or scores to the original order if needed
296
+ if shuffle_order:
297
+ ranks[flip_mask] = ranks[flip_mask][:, ::-1]
298
+
299
+ # Return the ranks or score probability
300
+ if return_scores:
301
+ logit_max = np.amax(ranks, axis=-1, keepdims=True)
302
+ exp_logit_shifted = np.exp(ranks - logit_max)
303
+ probs = exp_logit_shifted / np.sum(exp_logit_shifted, axis=-1, keepdims=True)
304
+ return probs[:, 0].tolist()
305
+ else:
306
+ return ranks[:, 0].tolist()
307
+
308
+
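The `return_scores` path above is a temperature-scaled, numerically stable softmax over the pair logits. A NumPy sketch with made-up logits:

```python
import numpy as np

# Logits for (completion_0, completion_1) per prompt, temperature-scaled, then
# a max-shifted softmax gives P(first completion wins).
logits = np.array([[2.0, 1.0], [0.5, 3.0]])
temperature = 2.0
scaled = logits / temperature
shifted = np.exp(scaled - np.amax(scaled, axis=-1, keepdims=True))
probs = shifted / shifted.sum(axis=-1, keepdims=True)
print(probs[:, 0].round(3).tolist())  # [0.622, 0.223]
```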
309
+ class HfPairwiseJudge(BasePairwiseJudge):
310
+ """
311
+ Pairwise judge based on the Hugging Face API with chat completion.
312
+
313
+ This judge is relevant for assessing the quality of chat models, where the completion is a response to a given prompt.
314
+
315
+ Args:
316
+ model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`):
317
+ Model to use for the judge.
318
+ token (`str`, *optional*):
319
+ Hugging Face API token to use for the [`huggingface_hub.InferenceClient`].
320
+ system_prompt (`str`, *optional*):
321
+ The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
322
+ prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
323
+ inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
324
+ response.
325
+ """
326
+
327
+ def __init__(
328
+ self,
329
+ model="meta-llama/Meta-Llama-3-70B-Instruct",
330
+ token: str | None = None,
331
+ system_prompt: str | None = None,
332
+ ):
333
+ self.client = InferenceClient(model=model, token=token)
334
+ self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
335
+
336
+ def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
337
+ # Shuffle the order of the completions to avoid positional bias
338
+ if shuffle_order:
339
+ flip_mask = np.random.choice([True, False], size=len(prompts))
340
+ completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
341
+
342
+ # Define a function to get the rank for a single prompt, will be called concurrently
343
+ def get_rank(prompt, candidates):
344
+ content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1])
345
+ completion = self.client.chat_completion(messages=[{"role": "user", "content": content}], max_tokens=1)
346
+ response = completion.choices[0].message.content
347
+ if response in ["0", "1"]:
348
+ return int(response)
349
+ else:
350
+ logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.")
351
+ return -1
352
+
353
+ # Call the completions concurrently
354
+ with concurrent.futures.ThreadPoolExecutor() as executor:
355
+ ranks = list(executor.map(get_rank, prompts, completions))
356
+
357
+ # Flip back the ranks to the original order if needed
358
+ if shuffle_order:
359
+ ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)]
360
+
361
+ # Return the ranks
362
+ return ranks
363
+
364
+
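The shuffle/flip-back round trip used by the API judges is easy to get wrong; a sketch of the mapping (note that a failure code of `-1` would map to `2` under `1 - r`, so callers should filter failures before trusting flipped ranks):

```python
import numpy as np

# Pairs are randomly swapped before judging; flipped verdicts map back via 1 - r.
flip_mask = np.array([False, True, True])
raw_ranks = [0, 0, 1]  # judged on the (possibly swapped) pairs
ranks = [r if not flip else 1 - r for r, flip in zip(raw_ranks, flip_mask)]
print(ranks)  # [0, 1, 0]
```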
365
+ class OpenAIPairwiseJudge(BasePairwiseJudge):
366
+ """
367
+ Judge based on the OpenAI API.
368
+
369
+ This judge is relevant for assessing the quality of chat models, where the completion is a response to a given prompt.
370
+
371
+ Args:
372
+ model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`):
373
+ Model to use for the judge.
374
+ system_prompt (`str`, *optional*):
375
+ System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
376
+ prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
377
+ inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
378
+ response.
379
+ max_requests (`int` or `None`, *optional*, defaults to `1000`):
380
+ Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit.
381
+ """
382
+
383
+ def __init__(
384
+ self, model="gpt-4-turbo-preview", system_prompt: str | None = None, max_requests: int | None = 1_000
385
+ ):
386
+ if not is_openai_available():
387
+ raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.")
388
+ from openai import OpenAI
389
+
390
+ self.client = OpenAI()
391
+ self.model = model
392
+ self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
393
+ self.max_requests = max_requests
394
+ self.num_requests = 0
395
+ self._warned = False
396
+
397
+ def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
398
+ # Check if the limit of requests is reached, if so, use random choice instead
399
+ if self.max_requests is not None and self.num_requests >= self.max_requests:
400
+ if not self._warned: # Print the warning only once
401
+ logging.warning(
402
+ f"Reached the maximum number of requests ({self.max_requests}). From now on, returning -1 instead. "
403
+ " To increase the limit, set `max_requests` to a higher value, or to `None` for no limit."
404
+ )
405
+ self._warned = True
406
+ return [-1] * len(prompts)
407
+
408
+ # Shuffle the order of the completions to avoid positional bias
409
+ if shuffle_order:
410
+ flip_mask = np.random.choice([True, False], size=len(prompts))
411
+ completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
412
+
413
+ # Define a function to get the rank for a single prompt, will be called concurrently
414
+ def get_rank(prompt, candidates):
415
+ content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1])
416
+ messages = [{"role": "user", "content": content}]
417
+ completion = self.client.chat.completions.create(model=self.model, messages=messages, max_tokens=1)
418
+ response = completion.choices[0].message.content
419
+ if response in ["0", "1"]:
420
+ return int(response)
421
+ else:
422
+ logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.")
423
+ return -1
424
+
425
+ # Call the completions concurrently
426
+ with concurrent.futures.ThreadPoolExecutor() as executor:
427
+ ranks = list(executor.map(get_rank, prompts, completions))
428
+
429
+ # Flip back the ranks to the original order if needed
430
+ if shuffle_order:
431
+ ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)]
432
+
433
+ # Update the number of requests
434
+ self.num_requests += len(prompts)
435
+
436
+ # Return the ranks
437
+ return ranks
438
+
439
+
440
+ class AllTrueJudge(BaseBinaryJudge):
441
+ """
442
+ Unify the decision of multiple [`experimental.judges.BaseBinaryJudge`] instances.
443
+
444
+ Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`. If any judge
445
+ returns `-1`, indicating a failure in its process, this judge will also return `-1`.
446
+
447
+ Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370).
448
+
449
+ Args:
450
+ judges (`list` of [`experimental.judges.BaseBinaryJudge`]):
451
+ A list of [`experimental.judges.BaseBinaryJudge`] instances whose decisions will be unified.
452
+ """
453
+
454
+ def __init__(self, judges: list[BaseBinaryJudge]):
455
+ self.judges = judges
456
+
457
+ def judge(
458
+ self,
459
+ prompts: list[str],
460
+ completions: list[str],
461
+ gold_completions: list[str] | None = None,
462
+ shuffle_order: bool = True,
463
+ ) -> list[int]:
464
+ all_binary_judgments = [
465
+ judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges
466
+ ]
467
+ output = []
468
+ for binary_judgments in zip(*all_binary_judgments, strict=True):
469
+ # Check that all values are in {0, 1, -1}
470
+ if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments):
471
+ raise ValueError(
472
+ f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}."
473
+ )
474
+
475
+ # Unify the decision
476
+ if -1 in binary_judgments:
477
+ output.append(-1)
478
+ elif all(binary_judgment == 1 for binary_judgment in binary_judgments):
479
+ output.append(1)
480
+ else:
481
+ output.append(0)
482
+ return output
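The unification rule of `AllTrueJudge`, reduced to a pure function for clarity (a sketch, not part of the module): any failure (`-1`) dominates, and a unanimous `1` is required for a positive verdict.

```python
def unify(binary_judgments):
    # -1 (failure) propagates; otherwise require all judges to return 1.
    if -1 in binary_judgments:
        return -1
    return 1 if all(j == 1 for j in binary_judgments) else 0

print([unify(row) for row in [(1, 1), (1, 0), (1, -1)]])  # [1, 0, -1]
```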
ICL/RL/trl_source/trl/experimental/kto/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .kto_config import KTOConfig
16
+ from .kto_trainer import KTOTrainer
17
+
18
+
19
+ __all__ = ["KTOConfig", "KTOTrainer"]
ICL/RL/trl_source/trl/experimental/kto/kto_config.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import Any
17
+
18
+ from transformers import TrainingArguments
19
+
20
+
21
+ @dataclass
22
+ class KTOConfig(TrainingArguments):
23
+ r"""
24
+ Configuration class for the [`experimental.kto.KTOTrainer`].
25
+
26
+ This class includes only the parameters that are specific to KTO training. For a full list of training arguments,
27
+ please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
28
+ differ from those in [`~transformers.TrainingArguments`].
29
+
30
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
31
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
32
+ command line.
33
+
34
+ Parameters:
35
+ max_length (`int` or `None`, *optional*, defaults to `1024`):
36
+ Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
37
+ to use the default data collator.
38
+ beta (`float`, *optional*, defaults to `0.1`):
39
+ Parameter controlling the deviation from the reference model. Higher β means less deviation from the
40
+ reference model.
41
+ loss_type (`str`, *optional*, defaults to `"kto"`):
42
+ Type of loss to use. Possible values are:
43
+
44
+ - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper.
45
+ - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the
46
+ [APO](https://huggingface.co/papers/2408.06266) paper.
47
+
48
+ desirable_weight (`float`, *optional*, defaults to `1.0`):
49
+ Desirable losses are weighed by this factor to counter an unequal number of desirable and undesirable pairs.
50
+ undesirable_weight (`float`, *optional*, defaults to `1.0`):
51
+ Undesirable losses are weighed by this factor to counter an unequal number of desirable and undesirable pairs.
52
+ generate_during_eval (`bool`, *optional*, defaults to `False`):
53
+ If `True`, generates and logs completions from both the model and the reference model to W&B or Comet
54
+ during evaluation.
55
+ precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
56
+ Whether to precompute reference model log probabilities for training and evaluation datasets. This is
57
+ useful when training without the reference model to reduce the total GPU memory needed.
58
+ model_init_kwargs (`dict[str, Any]`, *optional*):
59
+ Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
60
+ string.
61
+ dataset_num_proc (`int`, *optional*):
62
+ Number of processes to use for processing the dataset.
63
+ disable_dropout (`bool`, *optional*, defaults to `True`):
64
+ Whether to disable dropout in the model and reference model.
65
+ """
66
+
67
+ _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
68
+
69
+ # Parameters whose default values are overridden from TrainingArguments
70
+ learning_rate: float = field(
71
+ default=1e-6,
72
+ metadata={"help": "The initial learning rate for AdamW."},
73
+ )
74
+ logging_steps: float = field(
75
+ default=10,
76
+ metadata={
77
+ "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
78
+ "will be interpreted as ratio of total training steps."
79
+ },
80
+ )
81
+ gradient_checkpointing: bool = field(
82
+ default=True,
83
+ metadata={
84
+ "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
85
+ },
86
+ )
87
+ bf16: bool | None = field(
88
+ default=None,
89
+ metadata={
90
+ "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
91
+ "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
92
+ "`fp16` is not set."
93
+ },
94
+ )
95
+ # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue
96
+ # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary
97
+ # workaround here, which can be removed once we drop support for versions older than 4.57.5.
98
+ lr_scheduler_kwargs: dict | str | None = field(
99
+ default=None,
100
+ metadata={
101
+ "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard "
102
+ "restarts."
103
+ },
104
+ )
105
+
106
+ max_length: int | None = field(
107
+ default=1024,
108
+ metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
109
+ )
110
+ beta: float = field(
111
+ default=0.1,
112
+ metadata={
113
+ "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
114
+ "the reference model."
115
+ },
116
+ )
117
+ loss_type: str = field(
118
+ default="kto",
119
+ metadata={
120
+ "help": "Type of loss to use.",
121
+ "choices": ["kto", "apo_zero_unpaired"],
122
+ },
123
+ )
124
+ desirable_weight: float = field(
125
+ default=1.0,
126
+ metadata={
127
+ "help": "Desirable losses are weighed by this factor to counter unequal number of desirable and "
128
+ "undesirable pairs.",
129
+ },
130
+ )
131
+ undesirable_weight: float = field(
132
+ default=1.0,
133
+ metadata={
134
+ "help": "Undesirable losses are weighed by this factor to counter unequal number of desirable and "
135
+ "undesirable pairs.",
136
+ },
137
+ )
138
+ generate_during_eval: bool = field(
139
+ default=False,
140
+ metadata={
141
+ "help": "If `True`, generates and logs completions from both the model and the reference model to W&B "
142
+ "during evaluation."
143
+ },
144
+ )
145
+ disable_dropout: bool = field(
146
+ default=True,
147
+ metadata={"help": "Whether to disable dropout in the model."},
148
+ )
149
+ precompute_ref_log_probs: bool = field(
150
+ default=False,
151
+ metadata={
152
+ "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
153
+ "This is useful when training without the reference model to reduce the total GPU memory needed."
154
+ },
155
+ )
156
+ model_init_kwargs: dict[str, Any] | None = field(
157
+ default=None,
158
+ metadata={
159
+ "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
160
+ "from a string."
161
+ },
162
+ )
163
+ dataset_num_proc: int | None = field(
164
+ default=None,
165
+ metadata={"help": "Number of processes to use for processing the dataset."},
166
+ )
167
+
168
+ def __post_init__(self):
169
+ self.bf16 = not self.fp16 if self.bf16 is None else self.bf16
170
+
171
+ super().__post_init__()
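As the docstring notes, the config can be driven from the command line via `HfArgumentParser`. A sketch assuming an installed `trl` with this experimental module; the flag values are illustrative only:

```python
from transformers import HfArgumentParser

from trl.experimental.kto import KTOConfig

parser = HfArgumentParser(KTOConfig)
(config,) = parser.parse_args_into_dataclasses(
    # keep bf16 off so the sketch runs on hardware without bf16 support
    args=["--output_dir", "kto-out", "--beta", "0.2", "--bf16", "false"]
)
print(config.beta, config.loss_type)  # 0.2 kto
```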
ICL/RL/trl_source/trl/experimental/kto/kto_trainer.py ADDED
@@ -0,0 +1,1511 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState, logging
from accelerate.utils import tqdm
from datasets import Dataset, concatenate_datasets
from packaging.version import Version
from torch import autocast
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    TrainerCallback,
    TrainingArguments,
    is_comet_available,
    is_wandb_available,
)
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.utils import is_peft_available

from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ...import_utils import is_liger_kernel_available
from ...models.utils import create_reference_model, peft_module_casting_to_bf16, prepare_deepspeed
from ...trainer.base_trainer import BaseTrainer
from ...trainer.utils import (
    create_model_from_path,
    disable_dropout_in_model,
    log_table_to_comet_experiment,
    pad_to_length,
    selective_log_softmax,
)
from ..utils import DPODataCollatorWithPadding
from .kto_config import KTOConfig


if is_liger_kernel_available():
    from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


logger = logging.get_logger(__name__)

RUNNING_NAME = "running.pt"


def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]:
    """
    Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of
    completions. For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the
    same set as the matched outputs y used to estimate the rewards in that batch, just paired with different x.
    """
    batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1]
    batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1]
    return batch
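
# Illustrative sketch (not part of the trainer logic): for a batch of three matched
# pairs, `_get_kl_dataset` rotates the completions by one position, so each prompt is
# paired with a mismatched completion drawn from the same batch:
#     before: (x1, y1), (x2, y2), (x3, y3)
#     after:  (x1, y3), (x2, y1), (x3, y2)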


def _tokenize(
    batch: dict[str, list[Any]],
    tokenizer: "PreTrainedTokenizer",
) -> dict[str, list[Any]]:
    """Tokenize a batch from a KTO specific dataset."""
    prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
    prompt_input_ids = prompt_tokenized["input_ids"]
    prompt_attention_mask = prompt_tokenized["attention_mask"]
    prompt_and_completion = [
        prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True)
    ]
    full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False)
    full_input_ids = full_tokenized["input_ids"]
    full_attention_mask = full_tokenized["attention_mask"]

    answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)]
    answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)]

    # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
    full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)]
    # Prepare input tokens for token by token comparison
    full_input_ids = [np.array(f) for f in full_input_ids]
    for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True):
        if len(full) != len(concat):
            raise ValueError(
                "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length."
            )

    # On some tokenizers, like the Llama-2 tokenizer, there are occasions where tokens
    # can be merged together when tokenizing prompt+answer. This can result in the last
    # token of the prompt being different when tokenized on its own vs. as part of
    # prompt+answer.
    response_token_ids_start_idx = [len(p) for p in prompt_input_ids]

    # If the standalone prompt tokens differ from the corresponding prefix of the
    # prompt+answer tokens, the last prompt token was changed by such a merge, so we
    # move the response start index back by one.
    for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)):
        if not np.array_equal(p, f[:r]):
            response_token_ids_start_idx[idx] -= 1

    prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
    prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]

    for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True):
        if len(p) != len(m):
            raise ValueError("Prompt input ids and attention mask should have the same length.")

    answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
    answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]

    output = dict(
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        answer_input_ids=answer_input_ids,
        answer_attention_mask=answer_attention_mask,
    )

    return output
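
# Illustrative sketch (hypothetical token ids, not from any real tokenizer): suppose
# enc("Hello") == [101] but enc("Hello world") == [100, 205] because the tokenizer
# merges across the prompt/answer boundary. Then the standalone prompt tokens no
# longer match the prefix of the joint tokenization, and the adjustment above moves
# the response start index back by one so both slices stay consistent with the joint
# tokenization.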


def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict:
    """Process tokens of a KTO specific dataset.

    At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
    completion sequence is too long. We truncate from the end (completion) to fit within max_length.

    We also create the labels for the completion responses, which are of length equal to the sum of the length of the
    prompt and the completion response, with `-100` for the prompt tokens.
    """
    prompt = example["prompt"]
    completion = example["completion"]

    batch = {
        f"{kwargs['prefix']}prompt": prompt,
        f"{kwargs['prefix']}completion": completion,
        f"{kwargs['prefix']}label": example["label"],
    }

    # Check issues below for more details
    # 1. https://github.com/huggingface/trl/issues/907
    # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
    # 3. https://github.com/LianjiaTech/BELLE/issues/337

    if not isinstance(prompt, str):
        raise ValueError(f"prompt should be a str but got {type(prompt)}")

    if not isinstance(completion, str):
        raise ValueError(f"completion should be a str but got {type(completion)}")

    # keys of format prompt_* refer to just the prompt and answer_* refer to just the answer
    all_tokens = {
        "prompt_input_ids": example["prompt_input_ids"],
        "prompt_attention_mask": example["prompt_attention_mask"],
        "answer_input_ids": example["answer_input_ids"],
        "answer_attention_mask": example["answer_attention_mask"],
    }

    # calculate max length by checking if BOS/EOS is already there
    max_length = kwargs["max_length"]
    bos_token_id = kwargs["tokenizer"].bos_token_id
    eos_token_id = kwargs["tokenizer"].eos_token_id
    if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]:
        max_length -= 1
    if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]:
        max_length -= 1

    # if combined sequence is too long, truncate the completion (answer) from the end
    prompt_length = len(all_tokens["prompt_input_ids"])
    completion_length = len(all_tokens["answer_input_ids"])
    if prompt_length + completion_length > max_length:
        max_completion_length = max_length - prompt_length
        for k in ["answer_input_ids", "answer_attention_mask"]:
            all_tokens[k] = all_tokens[k][:max_completion_length]

    # Keep all input_ids and attention masks as is, then check whether BOS/EOS tokens need to be added
    batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"]
    batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"]
    batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"]
    batch[f"{kwargs['prefix']}completion_attention_mask"] = (
        all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"]
    )

    # add BOS, which affects both prompt and the full completion
    if bos_token_id is not None:
        if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
            batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
                f"{kwargs['prefix']}prompt_input_ids"
            ]
            batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
            batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
                f"{kwargs['prefix']}completion_input_ids"
            ]
            batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
                f"{kwargs['prefix']}completion_attention_mask"
            ]
    # add EOS, which affects only the full completion
    if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
        batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
            eos_token_id
        ]
        batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[
            f"{kwargs['prefix']}completion_attention_mask"
        ] + [1]

    batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:]
    batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [-100] * len(
        batch[f"{kwargs['prefix']}prompt_input_ids"]
    )

    return batch
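
# Illustrative sketch: for a 3-token prompt [p0, p1, p2] and a 2-token completion
# [c0, c1] (with an EOS appended), `_process_tokens` yields
#     completion_input_ids = [p0, p1, p2, c0, c1, eos]
#     completion_labels    = [-100, -100, -100, c0, c1, eos]
# so only the completion tokens contribute to the log-probabilities.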


class KTOTrainer(BaseTrainer):
    r"""
    Initialize KTOTrainer.

    Args:
        model ([`~transformers.PreTrainedModel`]):
            The model to train, preferably an [`~transformers.AutoModelForCausalLM`].
        ref_model ([`~transformers.PreTrainedModel`]):
            Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation
            and loss. If no reference model is provided, the trainer will create a reference model with the same
            architecture as the model to be optimized.
        args ([`experimental.kto.KTOConfig`]):
            The arguments to use for training.
        train_dataset ([`~datasets.Dataset`]):
            The dataset to use for training.
        eval_dataset ([`~datasets.Dataset`]):
            The dataset to use for evaluation.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        data_collator ([`~transformers.DataCollator`], *optional*):
            The data collator to use for training. If None is specified, the default data collator
            ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the
            maximum length of the sequences in the batch, given a dataset of paired sequences.
        model_init (`Callable[[], transformers.PreTrainedModel]`):
            The model initializer to use for training. If None is specified, the default model initializer will be
            used.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        peft_config (`dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in
            a PEFT model.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping
            metric names to metric values.
        model_adapter_name (`str`, defaults to `None`):
            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
        ref_adapter_name (`str`, defaults to `None`):
            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
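
    Example (illustrative sketch; the model name, dataset name, and import path are placeholders/assumptions, not
    requirements of the trainer):

    ```python
    from datasets import load_dataset
    from transformers import AutoTokenizer
    from trl.experimental.kto import KTOConfig, KTOTrainer

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # any causal LM
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # A preference dataset; paired datasets are unpaired automatically by the trainer
    train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

    trainer = KTOTrainer(
        model=model_id,
        args=KTOConfig(output_dir="kto-model", per_device_train_batch_size=4),
        processing_class=tokenizer,
        train_dataset=train_dataset,
    )
    trainer.train()
    ```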
    """

    _tag_names = ["trl", "kto"]
    _name = "KTO"
    _paper = {
        "title": "KTO: Model Alignment as Prospect Theoretic Optimization",
        "id": "2402.01306",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
        @article{ethayarajh2024kto,
            title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
            author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
            year = 2024,
            eprint = {arXiv:2402.01306},
        }"""),
    }

    def __init__(
        self,
        model: PreTrainedModel | nn.Module | str = None,
        ref_model: PreTrainedModel | nn.Module | str | None = None,
        args: KTOConfig = None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | dict[str, Dataset] | None = None,
        processing_class: PreTrainedTokenizerBase
        | BaseImageProcessor
        | FeatureExtractionMixin
        | ProcessorMixin
        | None = None,
        data_collator: DataCollator | None = None,
        model_init: Callable[[], PreTrainedModel] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        peft_config: dict | None = None,
        compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
        model_adapter_name: str | None = None,
        ref_adapter_name: str | None = None,
    ):
        if type(args) is TrainingArguments:
            raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")

        if not isinstance(model, str) and ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must pass a copy of it, or `None` if you use PEFT."
            )

        # Model initialization
        if isinstance(model, str):
            model_init_kwargs = args.model_init_kwargs or {}
            # Distributed training requires device_map=None ("auto" fails)
            if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
                model_init_kwargs["device_map"] = None
            model = create_model_from_path(model, **model_init_kwargs)
        else:
            if args.model_init_kwargs is not None:
                logger.warning(
                    "You passed `model_init_kwargs` to the KTOConfig, but your model is already instantiated. "
                    "The `model_init_kwargs` will be ignored."
                )

        # Reference model initialization
        if isinstance(ref_model, str):
            ref_model_init_kwargs = args.model_init_kwargs or {}
            # Distributed training requires device_map=None ("auto" fails)
            if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
                ref_model_init_kwargs["device_map"] = None
            ref_model = create_model_from_path(ref_model, **ref_model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if isinstance(model, PeftModel):
                raise ValueError(
                    "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first "
                    "merge and unload the existing adapter, save the resulting base model, and then pass that base "
                    "model along with the new `peft_config` to the trainer."
                )

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif args.gradient_checkpointing:
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            model = get_peft_model(model, peft_config)
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif args.gradient_checkpointing:
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                " Please install `wandb` or `comet-ml` to resolve."
            )

        # KTO only supports causal language models, not encoder-decoder models
        if model is not None and hasattr(model.config, "is_encoder_decoder") and model.config.is_encoder_decoder:
            raise ValueError(
                "KTO only supports causal language models. Encoder-decoder models are not supported. "
                "Please use a causal LM (e.g., GPT, Llama, Mistral) instead of an encoder-decoder model (e.g., T5, BART)."
            )

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or args.precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        if processing_class is None:
            raise ValueError(
                "A `processing_class` must be specified when using the default DPODataCollatorWithPadding"
            )
        if args.max_length is None:
            logger.warning(
                "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init;"
                " it will be set to `512` by default, but you should do it yourself in the future.",
            )
            max_length = 512
        if args.max_length is not None:
            max_length = args.max_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                logger.warning(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig;"
                    " we have set it for you, but you should do it yourself in the future.",
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.loss_type = args.loss_type
        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.processing_class = processing_class
        self.precompute_ref_log_probs = args.precompute_ref_log_probs

        # Not all losses require a KL calculation
        self.calculate_KL = True
        if self.loss_type in ["apo_zero_unpaired"]:
            self.calculate_KL = False

        # Since ref log probs are precomputed on the first call to get_train/eval_dataloader,
        # keep track of the first call to avoid recomputation on subsequent calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        # metric
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # KTO parameter
        self.beta = args.beta
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            logger.warning(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
            )

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().main_process_first():
            # Extract the prompt if needed
            train_dataset = train_dataset.map(
                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
            )
            # Unpair the dataset if needed
            train_dataset = maybe_unpair_preference_dataset(
                train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
            )
            # Apply the chat template if needed
            train_dataset = train_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
                desc="Applying chat template to train dataset",
            )
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
                )
                eval_dataset = maybe_unpair_preference_dataset(
                    eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
                )
                eval_dataset = eval_dataset.map(
                    maybe_apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                    desc="Applying chat template to eval dataset",
                )

            # Tokenize and prepare the training datasets
            train_dataset = train_dataset.map(
                _tokenize,
                batched=True,
                fn_kwargs={"tokenizer": self.processing_class},
                num_proc=args.dataset_num_proc,
                desc="Tokenizing train dataset",
            )

            fn_kwargs = {
                "prefix": "",
                "tokenizer": self.processing_class,
                "max_length": self.max_length,
            }

            train_dataset = train_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized train dataset",
            )

            # Tokenize and prepare the eval datasets
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    _tokenize,
                    fn_kwargs={"tokenizer": self.processing_class},
                    batched=True,
                    num_proc=args.dataset_num_proc,
                    desc="Tokenizing eval dataset",
                )

                eval_dataset = eval_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    desc="Processing tokenized eval dataset",
                )

            # Get KL datasets if needed
            if self.calculate_KL:
                if args.per_device_train_batch_size <= 1:
                    raise ValueError(
                        "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
                    )

                # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
                # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
                train_kl_dataset = train_dataset.map(
                    _get_kl_dataset,
                    batched=True,
                    batch_size=args.per_device_train_batch_size,
                    num_proc=args.dataset_num_proc,
                    desc="Extracting KL train dataset",
                )

                fn_kwargs["prefix"] = "KL_"
                train_kl_dataset = train_kl_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
                    desc="Processing tokenized train KL dataset",
                )

                # merge the datasets
                train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)

                if eval_dataset is not None:
                    # Get KL dataset
                    eval_kl_dataset = eval_dataset.map(
                        _get_kl_dataset,
                        batched=True,
                        batch_size=args.per_device_train_batch_size,
                        num_proc=args.dataset_num_proc,
                        desc="Extracting eval KL dataset",
                    )

                    eval_kl_dataset = eval_kl_dataset.map(
                        _process_tokens,
                        fn_kwargs=fn_kwargs,
                        num_proc=args.dataset_num_proc,
                        remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
                        desc="Processing tokenized eval KL dataset",
                    )

                    # merge the datasets
                    eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)

            # calculate dataset desirability balance
            num_desirable = max(sum(train_dataset["label"]), 1)
            num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1)  # "label" is binary

            if num_desirable != num_undesirable:
                # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
                des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
                des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
                und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
                und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)

                des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
                und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound

                if not (des_weight_in_range or und_weight_in_range):
                    logger.warning(
                        "You have different amounts of desirable/positive and undesirable/negative examples but the "
                        "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
                        "on your data, we recommend EITHER "
                        f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
                        f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
                        "See the documentation on how to optimally set these weights.",
                    )
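
            # Illustrative example (hypothetical counts): with 1,000 desirable and 250
            # undesirable examples and undesirable_weight=1.0, the recommended range for
            # desirable_weight is [250 * 1.0 / 1000, 250 * 1.0 / 1000 * 1.33] = [0.25, 0.33].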

        # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
        # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
        # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
        # default to the recommended non-reentrant behavior here, while preserving any user-provided value.
        if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
            args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
            args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # DeepSpeed ZeRO-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with DeepSpeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        # Import Liger kernel if enabled
        if self.args.use_liger_kernel:
            if not is_liger_kernel_available():
                raise ImportError(
                    "You set `use_liger_kernel=True` but the liger kernel is not available. "
                    "Please install liger-kernel first: `pip install liger-kernel`"
                )
            if self.loss_type in ["apo_zero_unpaired"]:
                raise ValueError(
                    "You cannot set `loss_type='apo_zero_unpaired'` with liger-kernel. "
                    "Only KTO loss is supported with liger-kernel."
                )
            if self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set "
                    "`precompute_ref_log_probs=False`."
                )
            if self.is_peft_model or self.ref_adapter_name is not None:
                raise ValueError(
                    "You cannot use `use_liger_kernel=True` with Peft models. Please set `use_liger_kernel=False`."
                )
            self.kto_loss_fn = LigerFusedLinearKTOLoss(beta=self.beta, use_ref_model=(self.ref_model is not None))

    @contextmanager
    def null_ref_context(self):
        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
        with (
            self.accelerator.unwrap_model(self.model).disable_adapter()
            if self.is_peft_model and not self.ref_adapter_name
            else nullcontext()
        ):
            if self.ref_adapter_name:
                self.model.set_adapter(self.ref_adapter_name)
            yield
            if self.ref_adapter_name:
                self.model.set_adapter(self.model_adapter_name or "default")
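
    # Illustrative usage sketch (assuming a PEFT model whose base weights serve as the
    # reference): inside the trainer, reference logits can be computed as
    #     with self.null_ref_context():
    #         ref_logits = self.model(input_ids, attention_mask=attention_mask).logits
    # since disabling the adapters makes the policy model behave like the frozen
    # reference, avoiding a second model copy in memory.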

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Overrides `transformers.Trainer.get_train_dataloader` to precompute `ref_log_probs`.
        """

        if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_train_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params))
            reference_completion_logps = []
            reference_KL_logps = []

            for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"):
                reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)

                reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
                reference_completion_logps.append(reference_completion_logp.cpu())

                if self.calculate_KL:
                    reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
                    reference_KL_logps.append(reference_KL_logp.cpu())

            self.train_dataset = self.train_dataset.add_column(
                name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
            )

            if self.calculate_KL:
                self.train_dataset = self.train_dataset.add_column(
                    name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
                )

            self._precomputed_train_ref_log_probs = True

        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        Overrides `transformers.Trainer.get_eval_dataloader` to precompute `ref_log_probs`.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
            dataloader_params = {
                "batch_size": self.args.per_device_eval_batch_size,
                "collate_fn": self.data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
                "shuffle": False,
            }

            # prepare dataloader
            data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

            reference_completion_logps = []
            reference_KL_logps = []

            for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"):
                reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch)

                reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp)
                reference_completion_logps.append(reference_completion_logp.cpu())

                if self.calculate_KL:
                    reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp)
                    reference_KL_logps.append(reference_KL_logp.cpu())

            eval_dataset = eval_dataset.add_column(
                name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
            )
            if self.calculate_KL:
                eval_dataset = eval_dataset.add_column(
                    name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy()
                )

            # Save the calculated `reference_logps` and `reference_KL_logps` to the eval_dataset for subsequent runs
            if self.eval_dataset is not None:
                self.eval_dataset = eval_dataset
            self._precomputed_eval_ref_log_probs = True

        return super().get_eval_dataloader(eval_dataset=eval_dataset)

    def compute_reference_log_probs(self, padded_batch: dict) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
        """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    completion_logits = self.model(
                        padded_batch["completion_input_ids"],
                        attention_mask=padded_batch["completion_attention_mask"],
                    ).logits

                    if self.calculate_KL:
                        KL_logits = self.model(
                            padded_batch["KL_completion_input_ids"],
                            attention_mask=padded_batch["KL_completion_attention_mask"],
                        ).logits
            else:
                completion_logits = self.ref_model(
                    padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"]
                ).logits

                if self.calculate_KL:
                    KL_logits = self.ref_model(
                        padded_batch["KL_completion_input_ids"],
                        attention_mask=padded_batch["KL_completion_attention_mask"],
                    ).logits

        completion_logps = self.get_batch_logps(
            completion_logits,
            padded_batch["completion_labels"],
            average_log_prob=False,
        )

        if self.calculate_KL:
            KL_logps = self.get_batch_logps(
                KL_logits,
                padded_batch["KL_completion_labels"],
                average_log_prob=False,
            )
        else:
            KL_logps = None

        return completion_logps, KL_logps

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits:
                Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
            labels:
                Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored.
                Shape: (batch_size, sequence_length)
            average_log_prob:
                If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the
                log probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the
            given logits.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")

        # For causal LM, shift labels and logits by one position
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]

        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = selective_log_softmax(logits, labels)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)
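
    # Illustrative sketch of the shift in `get_batch_logps`: for logits of shape
    # (B, T, V) and labels of shape (B, T), logits[:, t] scores the token at
    # labels[:, t + 1], so the labels are shifted left and the last logit step is
    # dropped. Positions labeled -100 are replaced by a dummy token and zeroed by
    # loss_mask, so they never contribute to the summed log-probability.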

    def forward(
        self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        KL_logps = self._compute_kl_logps(model, batch)

        model_kwargs = {}
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        outputs = model(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            **model_kwargs,
        )
        completion_logits = outputs.logits

        completion_logps = self.get_batch_logps(
            completion_logits,
            batch["completion_labels"],
            average_log_prob=False,
        )

        if completion_logps.shape[0] != len(batch["label"]):
            raise ValueError(
                "There is a mismatch between the number of examples in this batch and the number of "
                "examples for which an output sequence was predicted."
            )

        # Use torch.nonzero for efficient tensor index selection
        device = completion_logits.device
        labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)
        chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
        rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)

        # Use index_select for efficient CUDA operations
        chosen_logps = completion_logps.index_select(0, chosen_idx)
        rejected_logps = completion_logps.index_select(0, rejected_idx)

        chosen_logits = completion_logits.index_select(0, chosen_idx)
        rejected_logits = completion_logits.index_select(0, rejected_idx)

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss)
        else:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
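
    # Illustrative sketch of the index selection in `forward`: for a batch with
    # label == [1, 0, 1], chosen_idx == [0, 2] and rejected_idx == [1], so
    # index_select routes the log-probs of desirable completions to the chosen side
    # and the remainder to the rejected side.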

    def kto_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_KL_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_KL_logps: torch.FloatTensor,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the KTO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps:
                Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,)
            policy_rejected_logps:
                Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,)
            policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,)
            reference_chosen_logps:
                Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,)
            reference_rejected_logps:
                Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in
                batch_size,)
            reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,)

        Returns:
            A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). The losses tensor contains the KTO
            loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for
            the chosen and rejected responses, respectively. The KL tensor contains the detached KL divergence estimate
            between the policy and reference models.
        """
        if self.calculate_KL:
            kl = (policy_KL_logps - reference_KL_logps).mean().detach()
            kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
        else:
            kl = torch.zeros(1).to(policy_chosen_logps.device)

        # Chosen losses
        if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0:
            chosen_logratios = policy_chosen_logps - reference_chosen_logps

            if self.loss_type == "kto":
                # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306)
                chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl))
            elif self.loss_type == "apo_zero_unpaired":
                # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
                # Use this loss when you believe the chosen outputs are better than your model's default output
                chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios)

            chosen_rewards = self.beta * chosen_logratios.detach()

        else:
            # lists can't be empty -- if they are, then accelerate.gather will hang
            chosen_losses = torch.Tensor([]).to(self.accelerator.device)
            chosen_rewards = torch.Tensor([]).to(self.accelerator.device)

        # Rejected losses
        if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0:
            rejected_logratios = policy_rejected_logps - reference_rejected_logps

            if self.loss_type == "kto":
                rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios))
            elif self.loss_type == "apo_zero_unpaired":
                rejected_losses = F.sigmoid(self.beta * rejected_logratios)

            rejected_rewards = self.beta * rejected_logratios.detach()
        else:
            # lists can't be empty -- if they are, then accelerate.gather will hang
            rejected_losses = torch.Tensor([]).to(self.accelerator.device)
            rejected_rewards = torch.Tensor([]).to(self.accelerator.device)

        losses = torch.cat(
            (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses),
            0,
        )

        return losses, chosen_rewards, rejected_rewards, kl
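
    # Compact restatement of the loss above (sketch): with the log-ratio
    # l = log pi(y|x) - log pi_ref(y|x) and z the batch-level KL estimate,
    #     desirable:   loss = desirable_weight   * (1 - sigmoid(beta * (l - z)))
    #     undesirable: loss = undesirable_weight * (1 - sigmoid(beta * (z - l)))
    # so a desirable completion is pushed to a reward (beta * l) above the KL
    # baseline and an undesirable one below it.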

    def _compute_kl_logps(self, model, batch):
        """Compute KL log probabilities for a given batch."""
        KL_logps = None
        if self.calculate_KL:
            KL_model_kwargs = {
                "input_ids": batch["KL_completion_input_ids"],
                "attention_mask": batch["KL_completion_attention_mask"],
            }

            with torch.no_grad():
                KL_logits = model(**KL_model_kwargs).logits

            KL_logps = self.get_batch_logps(
                KL_logits,
                batch["KL_completion_labels"],
                average_log_prob=False,
            )
        return KL_logps

    def _compute_loss_liger(self, model, batch):
        """
        Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss.

        Args:
            model:
                The policy model used for generating log probabilities and outputs. Must be a causal language model,
                since encoder-decoder models are not supported by KTO.
            batch: A dictionary containing the input data and labels for the batch.

        Returns:
            A dictionary containing the following keys:
                - "loss": The computed KTO loss for the batch.
                - "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model.
                - "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model.
                - "chosen_logps_sum": Sum of the log probabilities of the chosen responses from the policy model.
                - "rejected_logps_sum": Sum of the log probabilities of the rejected responses from the policy model.
                - "chosen_rewards_sum": Sum of the rewards for the chosen responses.
                - "rejected_rewards_sum": Sum of the rewards for the rejected responses.
                - "kl": The KL divergence between the policy and reference models (detached).

            If auxiliary loss is enabled, the dictionary will also include:
                - "aux_loss": The auxiliary loss from the model outputs.
        """
        policy_KL_logps = self._compute_kl_logps(model, batch)
        reference_KL_logps = self._compute_kl_logps(self.ref_model, batch)
        if self.calculate_KL:
            kl = (policy_KL_logps - reference_KL_logps).mean().detach()
            kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0)
        else:
            kl = torch.zeros(1).to(self.accelerator.device)

        model_kwargs = {}
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        # skip the lm head and get the last hidden state
        base_model = model.get_decoder()
        outputs = base_model(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )

        # reference model
        ref_base_model = self.ref_model.get_decoder()
        ref_outputs = ref_base_model(
            batch["completion_input_ids"],
            attention_mask=batch["completion_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        lm_head = model.get_output_embeddings()
        ref_lm_head = self.ref_model.get_output_embeddings()

        (
            loss,
            (
                chosen_logps_sum,
                rejected_logps_sum,
                chosen_logits_sum,
                rejected_logits_sum,
                chosen_rewards_sum,
                rejected_rewards_sum,
            ),
        ) = self.kto_loss_fn(
            _input=outputs.last_hidden_state[:, :-1],
            lin_weight=lm_head.weight,
            target=batch["completion_labels"][:, 1:],
            bias=lm_head.bias if hasattr(lm_head, "bias") else None,
            preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device),
            ref_input=ref_outputs.last_hidden_state[:, :-1],
            ref_weight=ref_lm_head.weight,
            ref_bias=ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None,
            kl=kl,
        )

        output = {
            "loss": loss,
            "chosen_logits_sum": chosen_logits_sum,
            "rejected_logits_sum": rejected_logits_sum,
            "chosen_logps_sum": chosen_logps_sum,
            "rejected_logps_sum": rejected_logps_sum,
            "chosen_rewards_sum": chosen_rewards_sum,
            "rejected_rewards_sum": rejected_rewards_sum,
            "kl": kl,
        }
        if self.aux_loss_enabled:
            output["aux_loss"] = outputs.aux_loss

        return output

    def get_batch_loss_metrics(
        self,
        model,
        batch: dict[str, list | torch.LongTensor],
    ):
        """Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
        metrics = {}
        batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

        labels = torch.tensor(batch["label"])
        num_chosen = labels.sum().to(self.accelerator.device)
        num_rejected = (len(labels) - num_chosen).to(self.accelerator.device)

        if self.args.use_liger_kernel:
            model_output = self._compute_loss_liger(model, batch)
            losses = model_output["loss"]
            policy_chosen_logits = model_output["chosen_logits_sum"]
            policy_rejected_logits = model_output["rejected_logits_sum"]
            policy_chosen_logps = model_output["chosen_logps_sum"]
            policy_rejected_logps = model_output["rejected_logps_sum"]
            chosen_rewards = model_output["chosen_rewards_sum"]
            rejected_rewards = model_output["rejected_rewards_sum"]
            kl = model_output["kl"]
            if self.aux_loss_enabled:
                aux_loss = model_output["aux_loss"]
        else:
            forward_output = self.forward(model, batch)
            (
                policy_chosen_logps,
                policy_rejected_logps,
                policy_chosen_logits,
                policy_rejected_logits,
                policy_KL_logps,
            ) = forward_output[:5]
            if self.aux_loss_enabled:
                aux_loss = forward_output[5]

            # if reference_logps are in the batch, use them; otherwise, use the reference model
            if "reference_logps" in batch:
                # Convert Python lists to tensor indices for efficient CUDA operations
                device = batch["reference_logps"].device
                labels = torch.as_tensor(batch["label"], dtype=torch.bool, device=device)
                chosen_idx = torch.nonzero(labels, as_tuple=False).view(-1)
                rejected_idx = torch.nonzero(~labels, as_tuple=False).view(-1)

                # Use index_select for efficient CUDA operations
                reference_chosen_logps = batch["reference_logps"].index_select(0, chosen_idx)
                reference_rejected_logps = batch["reference_logps"].index_select(0, rejected_idx)
                if self.calculate_KL:
                    reference_KL_logps = batch["reference_KL_logps"]
                else:
                    reference_KL_logps = None
            else:
                with torch.no_grad():
                    if self.ref_model is None:
                        with self.null_ref_context():
                            (
                                reference_chosen_logps,
                                reference_rejected_logps,
                                _,
                                _,
                                reference_KL_logps,
                            ) = self.forward(self.model, batch)[:5]
                    else:
                        (
                            reference_chosen_logps,
                            reference_rejected_logps,
                            _,
                            _,
                            reference_KL_logps,
                        ) = self.forward(self.ref_model, batch)[:5]

            losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
                policy_chosen_logps,
                policy_rejected_logps,
                policy_KL_logps,
                reference_chosen_logps,
                reference_rejected_logps,
                reference_KL_logps,
            )

        metrics["kl"] = kl.item()

        all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item()
        all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item()

        if all_num_chosen > 0:
            metrics["rewards/chosen_sum"] = (
                self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item()
            )
            metrics["logps/chosen_sum"] = (
                self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item()
            )
            metrics["logits/chosen_sum"] = (
                self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item()
            )
            metrics["count/chosen"] = all_num_chosen

        if all_num_rejected > 0:
            metrics["rewards/rejected_sum"] = (
                self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item()
            )
            metrics["logps/rejected_sum"] = (
                self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item()
            )
            metrics["logits/rejected_sum"] = (
                self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item()
            )
            metrics["count/rejected"] = all_num_rejected

        loss = losses.nanmean()
        if self.aux_loss_enabled:
            loss += self.aux_loss_coef * aux_loss

        return loss, metrics

    def compute_loss(
        self,
        model: PreTrainedModel | nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        compute_loss_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with compute_loss_context_manager:
            loss, metrics = self.get_batch_loss_metrics(model, inputs)

        # Move the loss back to the device of the accumulating loss in the parent `Trainer` class
        loss = loss.to(self.args.device)
        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, metrics)
        return loss

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
        if dataset is None:
            dataset = self.train_dataset
        if dataset is None or not has_length(dataset):
            return None
        return SequentialSampler(dataset)
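
    # A sequential sampler is used rather than a random one because `_get_kl_dataset`
    # built each example's mismatched completion by rotating within fixed batches of
    # `per_device_train_batch_size`; sampling in order keeps those examples grouped, so
    # the mismatched outputs seen in a batch remain the same set as its matched outputs
    # (see the `_get_kl_dataset` docstring).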

    def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
        """Generate samples from the model and reference model for the given batch of inputs."""

        # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
        # the torch amp context manager, as some hidden states are silently cast to full precision.
        generate_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )

        with generate_context_manager:
            policy_output = model.generate(
                input_ids=batch["prompt_input_ids"],
                attention_mask=batch["prompt_attention_mask"],
                max_length=self.max_length,
                do_sample=True,
                pad_token_id=self.processing_class.pad_token_id,
            )

            # if a reference_output is in the batch, use it; otherwise, use the reference model
            if "reference_output" in batch:
                reference_output = batch["reference_output"]
            else:
                if self.ref_model is None:
                    with self.null_ref_context():
                        reference_output = self.model.generate(
                            input_ids=batch["prompt_input_ids"],
                            attention_mask=batch["prompt_attention_mask"],
                            max_length=self.max_length,
                            do_sample=True,
                            pad_token_id=self.processing_class.pad_token_id,
                        )
                else:
                    reference_output = self.ref_model.generate(
                        input_ids=batch["prompt_input_ids"],
                        attention_mask=batch["prompt_attention_mask"],
                        max_length=self.max_length,
                        do_sample=True,
                        pad_token_id=self.processing_class.pad_token_id,
                    )

        policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id)
        policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True)

        reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id)
        reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True)

        return policy_output_decoded, reference_output_decoded

    def prediction_step(
        self,
        model: PreTrainedModel | nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ):
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = (
            autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
        )
        with torch.no_grad(), prediction_context_manager:
            loss, metrics = self.get_batch_loss_metrics(model, inputs)

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return (loss.detach(), None, None)

        # logits for the chosen and rejected samples from model
        logits_dict = {}
        if "logits/chosen_sum" in metrics:
            logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
        if "logits/rejected_sum" in metrics:
            logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
        logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
        logits = torch.tensor(logits, device=self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
1406
+
1407
+ return (loss.detach(), logits, labels)
1408
+
1409
+ def evaluation_loop(
1410
+ self,
1411
+ dataloader: DataLoader,
1412
+ description: str,
1413
+ prediction_loss_only: bool | None = None,
1414
+ ignore_keys: list[str] | None = None,
1415
+ metric_key_prefix: str = "eval",
1416
+ ) -> EvalLoopOutput:
1417
+ """
1418
+ Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by
1419
+ `Trainer.evaluate()` and `Trainer.predict()`.
1420
+
1421
+ Works both with or without labels.
1422
+ """
1423
+
1424
+ # Sample and save to game log if requested (for one batch to save time)
1425
+ if self.generate_during_eval:
1426
+ # Generate random indices within the range of the total number of samples
1427
+ num_samples = len(dataloader.dataset)
1428
+ random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
1429
+
1430
+ # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
1431
+ random_batch_dataset = dataloader.dataset.select(random_indices)
1432
+ random_batch = self.data_collator(random_batch_dataset)
1433
+ random_batch = self._prepare_inputs(random_batch)
1434
+
1435
+ target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device)
1436
+ target_indices = torch.where(~target_labels)[0]
1437
+ target_batch = {
1438
+ "prompt_input_ids": random_batch["prompt_input_ids"][target_indices],
1439
+ "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices],
1440
+ "prompt": itemgetter(*target_indices)(random_batch["prompt"]),
1441
+ }
1442
+ policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch)
1443
+
1444
+ table = pd.DataFrame(
1445
+ columns=["Prompt", "Policy", "Ref Model"],
1446
+ data=[
1447
+ [prompt, pol[len(prompt) :], ref[len(prompt) :]]
1448
+ for prompt, pol, ref in zip(
1449
+ target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True
1450
+ )
1451
+ ],
1452
+ )
1453
+ if "wandb" in self.args.report_to:
1454
+ wandb.log({"game_log": wandb.Table(data=table)})
1455
+
1456
+ if "comet_ml" in self.args.report_to:
1457
+ log_table_to_comet_experiment(
1458
+ name="game_log.csv",
1459
+ table=table,
1460
+ )
1461
+
1462
+ # Base evaluation
1463
+ initial_output = super().evaluation_loop(
1464
+ dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
1465
+ )
1466
+
1467
+ return initial_output
1468
+
1469
+ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
1470
+ """
1471
+ Log `logs` on the various objects watching training, including stored metrics.
1472
+
1473
+ Args:
1474
+ logs (`dict[str, float]`):
1475
+ The values to log.
1476
+ start_time (`float`, *optional*):
1477
+ Start time of the training.
1478
+ """
1479
+ # logs either has 'loss' or 'eval_loss'
1480
+ train_eval = "train" if "loss" in logs else "eval"
1481
+ # train metrics should have no prefix, eval should have 'eval_'
1482
+ prefix = "eval_" if train_eval == "eval" else ""
1483
+ # accumulate average metrics from sums and lengths
1484
+ for split in ["chosen", "rejected"]:
1485
+ if f"count/{split}" in self._stored_metrics[train_eval]:
1486
+ count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item()
1487
+ for metric in ["rewards", "logps", "logits"]:
1488
+ logs[f"{prefix}{metric}/{split}"] = (
1489
+ torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item()
1490
+ / count_sum
1491
+ )
1492
+ # delete obsolete metric
1493
+ del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
1494
+ del self._stored_metrics[train_eval][f"count/{split}"]
1495
+ # calculate reward margin
1496
+ if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
1497
+ logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
1498
+ # Add averaged stored metrics to logs
1499
+ for key, metrics in self._stored_metrics[train_eval].items():
1500
+ logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
1501
+ del self._stored_metrics[train_eval]
1502
+ return super().log(logs, start_time)
1503
+
1504
+ # Ensure the model card is saved along with the checkpoint
1505
+ def _save_checkpoint(self, model, trial):
1506
+ if self.args.hub_model_id is None:
1507
+ model_name = Path(self.args.output_dir).name
1508
+ else:
1509
+ model_name = self.args.hub_model_id.split("/")[-1]
1510
+ self.create_model_card(model_name=model_name)
1511
+ super()._save_checkpoint(model, trial)
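
The `log` override above is the subtle part of this hunk: `store_metrics` accumulates per-step *sums* plus example *counts* (`rewards/chosen_sum`, `count/chosen`, ...), and only at logging time are they divided, so the reported value is a mean weighted by how many chosen/rejected examples each step actually contained. A minimal sketch of that aggregation, with made-up numbers (not taken from the diff):

```python
import torch

# Illustrative values as `store_metrics` would accumulate them over two logging steps
stored = {
    "rewards/chosen_sum": [12.0, 30.0],  # summed chosen rewards per step
    "count/chosen": [4, 12],             # number of chosen examples per step
}

count_sum = torch.Tensor(stored["count/chosen"]).sum().item()
reward_mean = torch.Tensor(stored["rewards/chosen_sum"]).sum().item() / count_sum

print(reward_mean)                 # 2.625 -- mean weighted by example count
print((12.0 / 4 + 30.0 / 12) / 2)  # 2.75  -- naive mean of per-step means
```

Averaging per-step means instead would over-weight small steps, which matters in KTO because the chosen/rejected split per batch is uneven; this is also why the `*_sum` and `count/*` keys are deleted once consumed.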
ICL/RL/trl_source/trl/experimental/merge_model_callback.py ADDED
@@ -0,0 +1,352 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+
+import torch
+from huggingface_hub import HfApi
+from transformers import TrainerCallback
+
+from ..import_utils import is_mergekit_available
+from ..trainer.utils import get_config_model_id
+
+
+if is_mergekit_available():
+    from mergekit.config import MergeConfiguration
+    from mergekit.merge import MergeOptions, run_merge
+
+
+# Logger for module-level logging
+logger = logging.getLogger(__name__)
+
+
+def upload_model_to_hf(folder_path: str, repo_id: str):
+    api = HfApi()
+    # Create the repository if it doesn't already exist
+    repo = api.create_repo(repo_id, repo_type="model", exist_ok=True)
+
+    # Upload the folder to the specified repository
+    api.upload_folder(
+        folder_path=folder_path,
+        repo_id=repo.repo_id,
+        repo_type=repo.repo_type,
+    )
+
+
+class MergeConfig:
+    r"""
+    Configuration class for merging two models using `mergekit`.
+
+    This class provides a structured way to configure and generate merge configurations for various merge methods,
+    such as `linear`, `ties`, `dare_ties`, and `slerp`.
+
+    Args:
+        method (`str`, *optional*, defaults to `"linear"`):
+            Merge method to use. Supported methods include:
+
+            - `"linear"`: Linearly combines two models with specified weights.
+            - `"ties"`: Combines two models using the TIES method with density parameters.
+            - `"dare_ties"`: A variant of TIES that applies DARE (drop-and-rescale) sparsification to the task
+              vectors before merging.
+            - `"slerp"`: Combines models using spherical linear interpolation.
+
+    Note:
+
+        For more details about the merge methods and how they are implemented, see the [MergeKit GitHub
+        repository](https://github.com/arcee-ai/mergekit?tab=readme-ov-file#merge-methods).
+
+    Attributes:
+        method (`str`): The merge method to use.
+        policy_model_path (`str` or `None`): Path to the policy model.
+        target_model_path (`str` or `None`): Path to the target model.
+        policy_model_weight (`float`): Weight for the policy model (for `linear` and `ties` methods).
+        target_model_weight (`float`): Weight for the target model (for `linear` and `ties` methods).
+        policy_model_density (`list[float]`): Density parameters for the policy model (for `ties` and `dare_ties`).
+        target_model_density (`list[float]`): Density parameters for the target model (for `ties` and `dare_ties`).
+        normalize (`float` or `None`): Normalization factor for the TIES method.
+        t_values (`float` or `None`): Interpolation factor for the SLERP method.
+        dtype (`str`): Data type to use for merging, e.g., `"float16"`.
+    """
+
+    def __init__(self, method: str = "linear"):
+        if not is_mergekit_available():
+            raise ImportError("MergeConfig requires the `mergekit` extra. To install, run `pip install mergekit`.")
+        self.method = method
+        self.policy_model_path = None
+        self.target_model_path = None
+
+        # Initialize relevant parameters based on the method
+        if method == "linear":
+            self.policy_model_weight = 0.5
+            self.target_model_weight = 0.5
+            self.dtype = "float16"
+        elif method == "ties":
+            self.policy_model_weight = 1.0
+            self.policy_model_density = [1.0, 0.7, 0.1]
+            self.target_model_weight = 1.0
+            self.target_model_density = [1.0]
+            self.normalize = 1.0
+            self.dtype = "float16"
+        elif method == "dare_ties":
+            self.policy_model_weight = 1.0
+            self.policy_model_density = [1.0, 0.7, 0.1]
+            self.target_model_weight = 1.0
+            self.target_model_density = [1.0]
+            self.normalize = 1.0
+            self.dtype = "float16"
+        elif method == "slerp":
+            self.t_values = 0.5
+            self.dtype = "float16"
+        else:
+            raise ValueError(f"Unsupported merge method: {method}")
+
+    def create_merge_config_linear(self) -> "MergeConfiguration":
+        """
+        Creates a merge configuration for a linear merge of two models with specified weights.
+        """
+        # Create the merge configuration dictionary
+        merge_config_dict = {
+            "dtype": self.dtype,
+            "merge_method": "linear",
+            "models": [
+                {"model": self.policy_model_path, "parameters": {"weight": self.policy_model_weight}},
+                {"model": self.target_model_path, "parameters": {"weight": self.target_model_weight}},
+            ],
+        }
+
+        # Create the MergeConfiguration from the dictionary
+        merge_config = MergeConfiguration.model_validate(merge_config_dict)
+
+        return merge_config
+
+    def create_merge_config_ties(self) -> "MergeConfiguration":
+        """
+        Creates a merge configuration for a TIES merge of two models, with specified weights and densities.
+        """
+        # Create the TIES merge configuration dictionary
+        merge_config_dict = {
+            "merge_method": "ties",
+            "slices": None,  # Optional slices if needed
+            "models": [
+                {
+                    "model": {
+                        "model": {"path": self.target_model_path, "revision": None},
+                        "lora": None,
+                        "override_architecture": None,
+                    },
+                    "parameters": {"density": self.target_model_density, "weight": self.target_model_weight},
+                },
+                {
+                    "model": {
+                        "model": {"path": self.policy_model_path, "revision": None},
+                        "lora": None,
+                        "override_architecture": None,
+                    },
+                    "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight},
+                },
+            ],
+            "parameters": {"normalize": self.normalize},
+            "base_model": {
+                "model": {"path": self.policy_model_path, "revision": None},
+                "lora": None,
+                "override_architecture": None,
+            },
+            "dtype": self.dtype,
+            "tokenizer_source": None,
+            "tokenizer": None,
+            "chat_template": None,
+            "out_dtype": None,
+        }
+
+        # Create the MergeConfiguration from the dictionary
+        merge_config = MergeConfiguration.model_validate(merge_config_dict)
+
+        return merge_config
+
+    def create_merge_config_dare_ties(self) -> "MergeConfiguration":
+        """
+        Creates a merge configuration for a DARE TIES merge of two models, with specified weights and densities.
+        """
+        # Create the DARE TIES merge configuration dictionary
+        merge_config_dict = {
+            "merge_method": "dare_ties",
+            "slices": None,  # Optional slices if needed
+            "models": [
+                {
+                    "model": {
+                        "model": {"path": self.target_model_path, "revision": None},
+                        "lora": None,
+                        "override_architecture": None,
+                    },
+                    "parameters": {"density": self.target_model_density, "weight": self.target_model_weight},
+                },
+                {
+                    "model": {
+                        "model": {"path": self.policy_model_path, "revision": None},
+                        "lora": None,
+                        "override_architecture": None,
+                    },
+                    "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight},
+                },
+            ],
+            "parameters": {"normalize": self.normalize},
+            "base_model": {
+                "model": {"path": self.policy_model_path, "revision": None},
+                "lora": None,
+                "override_architecture": None,
+            },
+            "dtype": self.dtype,
+            "tokenizer_source": None,
+            "tokenizer": None,
+            "chat_template": None,
+            "out_dtype": None,
+        }
+
+        # Create the MergeConfiguration from the dictionary
+        merge_config = MergeConfiguration.model_validate(merge_config_dict)
+
+        return merge_config
+
+    def create_merge_config_slerp(self) -> "MergeConfiguration":
+        """
+        Creates a merge configuration for a SLERP merge of a model with a base model.
+        """
+
+        # Create the SLERP merge configuration dictionary
+        merge_config_dict = {
+            "merge_method": "slerp",
+            "slices": None,  # Optional slices if needed
+            "models": [
+                {
+                    "model": {
+                        "model": {"path": self.target_model_path, "revision": None},
+                        "lora": None,
+                        "override_architecture": None,
+                    },
+                    "parameters": None,  # No specific parameters for the SLERP model
+                }
+            ],
+            "parameters": {
+                "t": self.t_values  # Set the t values for SLERP
+            },
+            "base_model": {
+                "model": {"path": self.policy_model_path, "revision": None},
+                "lora": None,
+                "override_architecture": None,
+            },
+            "dtype": self.dtype,
+            "tokenizer_source": None,
+            "tokenizer": None,
+            "chat_template": None,
+            "out_dtype": None,
+        }
+
+        # Create the MergeConfiguration from the dictionary
+        merge_config = MergeConfiguration.model_validate(merge_config_dict)
+
+        return merge_config
+
+    def create(self) -> "MergeConfiguration":
+        if self.method == "linear":
+            return self.create_merge_config_linear()
+        elif self.method == "ties":
+            return self.create_merge_config_ties()
+        elif self.method == "dare_ties":
+            return self.create_merge_config_dare_ties()
+        elif self.method == "slerp":
+            return self.create_merge_config_slerp()
+
+
+def merge_models(config: "MergeConfiguration", out_path: str):
+    """
+    Merge two models using `mergekit`.
+
+    Args:
+        config (`MergeConfiguration`): The merge configuration.
+        out_path (`str`): The output path for the merged model.
+    """
+    if not is_mergekit_available():
+        raise ImportError("merge_models requires the `mergekit` extra. To install, run `pip install mergekit`.")
+    run_merge(
+        config,
+        out_path=out_path,
+        options=MergeOptions(
+            device="auto",
+            cuda=torch.cuda.is_available(),
+            copy_tokenizer=True,
+            lazy_unpickle=False,
+            low_cpu_memory=False,
+        ),
+    )
+
+
+class MergeModelCallback(TrainerCallback):
+    r"""
+    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model
+    based on a merge configuration.
+
+    Args:
+        merge_config ([`experimental.merge_model_callback.MergeConfig`], *optional*):
+            Configuration used for the merging process. If not provided, the default
+            [`~experimental.merge_model_callback.MergeConfig`] is used.
+        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
+            Whether to merge the model at every checkpoint.
+        push_to_hub (`bool`, *optional*, defaults to `False`):
+            Whether to push the merged model to the Hub after merging.
+
+    Example:
+
+    ```python
+    from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback
+
+    config = MergeConfig()
+    merge_callback = MergeModelCallback(config)
+    trainer = DPOTrainer(..., callbacks=[merge_callback])
+    ```
+    """
+
+    def __init__(
+        self,
+        merge_config: "MergeConfig | None" = None,
+        merge_at_every_checkpoint: bool = False,
+        push_to_hub: bool = False,
+    ):
+        if not is_mergekit_available():
+            raise ImportError(
+                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
+            )
+        self.merge_config = merge_config or MergeConfig()
+        self.merge_at_every_checkpoint = merge_at_every_checkpoint
+        self.push_to_hub = push_to_hub
+
+    def _merge_and_maybe_push(self, output_dir, global_step, model):
+        checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
+        self.merge_config.policy_model_path = checkpoint_path
+        if self.merge_config.target_model_path is None:
+            self.merge_config.target_model_path = get_config_model_id(model.config)
+        merge_path = os.path.join(checkpoint_path, "merged")
+
+        merge_models(self.merge_config.create(), merge_path)
+
+        if self.push_to_hub:
+            repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
+            upload_model_to_hf(merge_path, repo_name)
+
+    def on_save(self, args, state, control, model=None, **kwargs):
+        if self.merge_at_every_checkpoint:
+            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
+
+    def on_train_end(self, args, state, control, model=None, **kwargs):
+        if not self.merge_at_every_checkpoint:
+            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
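
For readers who want the merge machinery outside the callback, here is a minimal usage sketch built only from the APIs added in this file; the checkpoint paths and the `Qwen/Qwen2-0.5B` target are placeholders, and `mergekit` must be installed (`pip install mergekit`):

```python
from trl.experimental.merge_model_callback import MergeConfig, merge_models

# Placeholder paths -- substitute real checkpoints
config = MergeConfig(method="ties")
config.policy_model_path = "output/checkpoint-500"  # the model being trained
config.target_model_path = "Qwen/Qwen2-0.5B"        # hypothetical merge target
config.policy_model_weight = 0.8
config.target_model_weight = 0.2

# `create()` dispatches to the method-specific builder (here `create_merge_config_ties`)
# and `merge_models` hands the validated config to mergekit's `run_merge`.
merge_models(config.create(), out_path="output/checkpoint-500/merged")
```

This is the same path the callback takes in `_merge_and_maybe_push`, except that there the policy path is filled in from the latest `checkpoint-{global_step}` directory automatically.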
ICL/RL/trl_source/trl/experimental/minillm/__init__.py ADDED
@@ -0,0 +1,19 @@
+# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .minillm_config import MiniLLMConfig
+from .minillm_trainer import MiniLLMTrainer
+
+
+__all__ = ["MiniLLMConfig", "MiniLLMTrainer"]
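
This `__init__.py` only re-exports the MiniLLM API. Assuming the `minillm_config` and `minillm_trainer` modules elsewhere in this commit define the usual config/trainer pair, usage reduces to:

```python
from trl.experimental.minillm import MiniLLMConfig, MiniLLMTrainer
```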