qgallouedec HF Staff committed on
Commit 0666291 · verified · 1 Parent(s): 7907c84

Create rloo_trainer.py

Files changed (1)
  1. rloo_trainer.py +1520 -0
rloo_trainer.py ADDED
@@ -0,0 +1,1520 @@
1
+ # Copyright 2020-2026 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import atexit
17
+ import copy
18
+ import inspect
19
+ import math
20
+ import textwrap
21
+ import time
22
+ from collections import defaultdict, deque
23
+ from collections.abc import Callable
24
+ from contextlib import nullcontext
25
+ from pathlib import Path
26
+ from typing import Any
27
+
28
+ import numpy as np
29
+ import pandas as pd
30
+ import torch
31
+ import torch.utils.data
32
+ import transformers
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import gather, gather_object, is_peft_model, set_seed
35
+ from datasets import Dataset, IterableDataset
36
+ from packaging.version import Version
37
+ from torch import nn
38
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
39
+ from torch.utils.data import Sampler
40
+ from transformers import (
41
+ AutoModelForSequenceClassification,
42
+ AutoProcessor,
43
+ AutoTokenizer,
44
+ GenerationConfig,
45
+ PreTrainedModel,
46
+ PreTrainedTokenizerBase,
47
+ ProcessorMixin,
48
+ TrainerCallback,
49
+ is_trackio_available,
50
+ is_wandb_available,
51
+ )
52
+ from transformers.utils import is_peft_available, is_rich_available
53
+
54
+ from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages
55
+ from ..extras.profiling import profiling_context, profiling_decorator
56
+ from ..generation.vllm_generation import VLLMGeneration
57
+ from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
58
+ from ..models.utils import disable_gradient_checkpointing
59
+ from .base_trainer import _BaseTrainer
60
+ from .callbacks import SyncRefModelCallback
61
+ from .rloo_config import RLOOConfig
62
+ from .utils import (
63
+ RepeatSampler,
64
+ create_model_from_path,
65
+ disable_dropout_in_model,
66
+ entropy_from_logits,
67
+ get_config_model_id,
68
+ identity,
69
+ nanmax,
70
+ nanmin,
71
+ nanstd,
72
+ pad,
73
+ print_prompt_completions_sample,
74
+ selective_log_softmax,
75
+ shuffle_sequence_dict,
76
+ shutdown_event_loop_in_daemon,
77
+ split_pixel_values_by_grid,
78
+ split_tensor_dict,
79
+ start_event_loop_in_daemon,
80
+ unsplit_pixel_values_by_grid,
81
+ use_adapter,
82
+ )
83
+
84
+
85
+ if is_peft_available():
86
+ from peft import PeftConfig, PeftModel, get_peft_model
87
+
88
+
89
+ if is_wandb_available():
90
+ import wandb
91
+
92
+ if is_trackio_available():
93
+ import trackio
94
+
95
+
96
+ logger = get_logger(__name__)
97
+
98
+ # A reward function can be a string, interpreted as a model ID and loaded as a pretrained model, a pretrained model, or
99
+ # a callable that returns a list of floats (the rewards). The callable receives prompts, completions, and additional
100
+ # arguments from the trainer (refer to the trainer's source for details). To ensure forward compatibility, it should
101
+ # accept **kwargs.
102
+ RewardFunc = str | PreTrainedModel | Callable[..., list[float | None]]
103
+
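+ # Illustrative sketch (hypothetical function name, not from this file): a minimal custom reward function
+ # that satisfies this callable type. It accepts the trainer-provided arguments plus **kwargs and returns
+ # one float (or None) per completion:
+ #
+ # def brevity_reward(prompts, completions, completion_ids, **kwargs):
+ #     # Toy shaping: shorter completions receive a higher (less negative) reward.
+ #     return [-float(len(ids)) for ids in completion_ids]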
104
+
105
+ class RLOOTrainer(_BaseTrainer):
106
+ """
107
+ Trainer for the REINFORCE Leave-One-Out (RLOO) method. This algorithm was initially proposed in the paper [Back to
108
+ Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in
109
+ LLMs](https://huggingface.co/papers/2402.14740).
110
+
111
+ Example:
112
+
113
+ ```python
114
+ from trl import RLOOTrainer
115
+ from trl.rewards import accuracy_reward
116
+ from datasets import load_dataset
117
+
118
+ dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
119
+
120
+ trainer = RLOOTrainer(
121
+ model="Qwen/Qwen2.5-0.5B-Instruct",
122
+ reward_funcs=accuracy_reward,
123
+ train_dataset=dataset,
124
+ )
125
+ trainer.train()
126
+ ```
127
+
128
+ Args:
129
+ model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
130
+ Model to be trained. Can be either:
131
+
132
+ - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
133
+ path to a *directory* containing model weights saved using
134
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
135
+ using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
136
+ config) with the keyword arguments in `args.model_init_kwargs`.
137
+ - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
138
+ - A [`~peft.PeftModel`] object. Only causal language models are supported.
139
+ reward_funcs (`RewardFunc | list[RewardFunc]`):
140
+ Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
141
+ functions with the prompts and completions and sum the rewards. Can be either:
142
+
143
+ - A single reward function, such as:
144
+ - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
145
+ path to a *directory* containing model weights saved using
146
+ [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
147
+ using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
148
+ keyword arguments in `args.model_init_kwargs`.
149
+ - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
150
+ - A custom reward function: The function is provided with the prompts and the generated completions,
151
+ plus any additional columns in the dataset. It should return a list of rewards. Custom reward
152
+ functions can be either synchronous or asynchronous and can also return `None` when the reward is
153
+ not applicable to those samples. This is useful for multi-task training where different reward
154
+ functions apply to different types of samples. When a reward function returns `None` for a sample,
155
+ that reward function is excluded from the reward calculation for that sample. For more details, see
156
+ [Using a custom reward
157
+ function](#using-a-custom-reward-function).
158
+
159
+ The trainer's state is also passed to the reward function. The trainer's state is an instance of
160
+ [`~transformers.TrainerState`] and can be accessed via the `trainer_state` argument in the
161
+ reward function's signature.
162
+ - A list of reward functions, where each item can independently be any of the above types. Mixing different
163
+ types within the list (e.g., a string model ID and a custom reward function) is allowed.
164
+ args ([`RLOOConfig`], *optional*):
165
+ Configuration for this trainer. If `None`, a default configuration is used.
166
+ train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
167
+ Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
168
+ ignored. The format of the samples can be either:
169
+
170
+ - [Standard](dataset_formats#standard): Each sample contains plain text.
171
+ - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
172
+ and content).
173
+ eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
174
+ Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
175
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
176
+ Processing class used to process the data. The padding side must be set to "left". If `None`, the
177
+ processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
178
+ padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
179
+ `tokenizer.eos_token` will be used as the default.
180
+ reward_processing_classes ([`~transformers.PreTrainedTokenizerBase`] or `list[PreTrainedTokenizerBase]`, *optional*):
181
+ Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
182
+
183
+ - A single processing class: Used when `reward_funcs` contains only one reward function.
184
+ - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
185
+ If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
186
+ `None`, the tokenizer for the model is automatically loaded using
187
+ [`~transformers.AutoTokenizer.from_pretrained`]. For elements in `reward_funcs` that are custom reward
188
+ functions (not [`~transformers.PreTrainedModel`]), the corresponding entries in `reward_processing_classes`
189
+ are ignored.
190
+ callbacks (list of [`~transformers.TrainerCallback`], *optional*):
191
+ List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
192
+ in [here](https://huggingface.co/docs/transformers/main_classes/callback).
193
+
194
+ If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
195
+ method.
196
+ optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`):
197
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
198
+ model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
199
+ peft_config ([`~peft.PeftConfig`], *optional*):
200
+ PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
201
+ """
202
+
203
+ _tag_names = ["trl", "rloo"]
204
+ _name = "RLOO"
205
+ _paper = {
206
+ "title": "Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
207
+ "id": "2402.14740",
208
+ # docstyle-ignore
209
+ "citation": textwrap.dedent("""\
210
+ @inproceedings{ahmadian2024back,
211
+ title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
212
+ author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
213
+ year = 2024,
214
+ booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
215
+ pages = {12248--12267},
216
+ publisher = {Association for Computational Linguistics},
217
+ editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
218
+ }"""),
219
+ }
220
+
221
+ def __init__(
222
+ self,
223
+ model: "str | PreTrainedModel | PeftModel",
224
+ reward_funcs: RewardFunc | list[RewardFunc],
225
+ args: RLOOConfig | None = None,
226
+ train_dataset: Dataset | IterableDataset | None = None,
227
+ eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
228
+ processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
229
+ reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
230
+ callbacks: list[TrainerCallback] | None = None,
231
+ optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
232
+ peft_config: "PeftConfig | None" = None,
233
+ ):
234
+ # Args
235
+ if args is None:
236
+ model_name = model if isinstance(model, str) else get_config_model_id(model.config)
237
+ model_name = model_name.split("/")[-1]
238
+ args = RLOOConfig(f"{model_name}-RLOO")
239
+
240
+ # Model
241
+ if isinstance(model, str):
242
+ model_init_kwargs = args.model_init_kwargs or {}
243
+ # Distributed training requires device_map=None ("auto" fails)
244
+ if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
245
+ model_init_kwargs["device_map"] = None
246
+ model = create_model_from_path(model, **model_init_kwargs)
247
+ else:
248
+ if args.model_init_kwargs is not None:
249
+ logger.warning(
250
+ "You passed `model_init_kwargs` to the `RLOOConfig`, but your model is already instantiated. "
251
+ "The `model_init_kwargs` will be ignored."
252
+ )
253
+
254
+ # Some models (SmolVLM/Idefics3) don't support the `logits_to_keep` argument and error out if we pass it
255
+ # Inspect the forward method before we wrap the model with PEFT
256
+ self.model_kwarg_keys = (
257
+ inspect.signature(model.forward).parameters.keys()
258
+ if not hasattr(model, "get_base_model")
259
+ else inspect.signature(model.get_base_model().forward).parameters.keys()
260
+ )
261
+
262
+ # Processing class
263
+ if processing_class is None:
264
+ processing_class = AutoProcessor.from_pretrained(
265
+ get_config_model_id(model.config), truncation_side="left", padding_side="left"
266
+ )
267
+
268
+ # Handle pad token for processors or tokenizers
269
+ if isinstance(processing_class, ProcessorMixin):
270
+ tokenizer = processing_class.tokenizer
271
+ elif isinstance(processing_class, PreTrainedTokenizerBase):
272
+ tokenizer = processing_class
273
+ else:
274
+ raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
275
+
276
+ if tokenizer.pad_token is None:
277
+ tokenizer.pad_token = tokenizer.eos_token
278
+
279
+ self.pad_token = tokenizer.pad_token
280
+ self.pad_token_id = tokenizer.pad_token_id
281
+ self.eos_token_id = tokenizer.eos_token_id
282
+
283
+ if is_peft_available() and is_peft_model(model) and peft_config is not None:
284
+ raise ValueError(
285
+ "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
286
+ "and unload the existing adapter, save the resulting base model, and then pass that base model along "
287
+ "with the new `peft_config` to the trainer."
288
+ )
289
+ if is_peft_available() and is_peft_model(model):
290
+ # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy
291
+ # of the "default" adapter, so that we can use it as the reference model during the training.
292
+ model.add_adapter("ref", model.peft_config["default"])
293
+ for name, param in model.named_parameters():
294
+ if ".default." in name:
295
+ ref_name = name.replace(".default.", ".ref.")
296
+ ref_param = model.get_parameter(ref_name)
297
+ ref_param.data.copy_(param.data)
298
+
299
+ # Create PEFT model
300
+ if peft_config is not None:
301
+ model = get_peft_model(model, peft_config)
302
+
303
+ # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
304
+ # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
305
+ if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing:
306
+ model.enable_input_require_grads()
307
+
308
+ # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
309
+ # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
310
+ # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
311
+ # quantized models. See: https://github.com/huggingface/peft/issues/2889
312
+ # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
313
+ if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
314
+ for param in model.parameters():
315
+ if param.requires_grad:
316
+ param.data = param.data.to(torch.bfloat16)
317
+
318
+ # Reward functions
319
+ if not isinstance(reward_funcs, list):
320
+ reward_funcs = [reward_funcs]
321
+ self.reward_func_names = []
322
+ for i, reward_func in enumerate(reward_funcs):
323
+ if isinstance(reward_func, str):
324
+ model_init_kwargs = args.model_init_kwargs or {}
325
+ # Distributed training requires device_map=None ("auto" fails)
326
+ if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
327
+ model_init_kwargs["device_map"] = None
328
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
329
+ reward_func, num_labels=1, **model_init_kwargs
330
+ )
331
+ if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models
332
+ self.reward_func_names.append(get_config_model_id(reward_funcs[i].config).split("/")[-1])
333
+ else:
334
+ self.reward_func_names.append(reward_funcs[i].__name__)
335
+ self.reward_funcs = reward_funcs
336
+
337
+ self._has_async_reward_funcs = any(inspect.iscoroutinefunction(func) for func in self.reward_funcs)
338
+ if self._has_async_reward_funcs:
339
+ self.async_reward_loop_thread, self.async_reward_loop, self.async_reward_loop_ready_event = (
340
+ start_event_loop_in_daemon(name="RLOOTrainer-AsyncRewardLoop")
341
+ )
342
+ # wait until the event loop is running in the daemon thread
343
+ self.async_reward_loop_ready_event.wait()
344
+ atexit.register(shutdown_event_loop_in_daemon, self.async_reward_loop_thread, self.async_reward_loop)
345
+
346
+ # Reward weights
347
+ if args.reward_weights is not None:
348
+ if len(args.reward_weights) != len(reward_funcs):
349
+ raise ValueError(
350
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
351
+ f"functions ({len(reward_funcs)})"
352
+ )
353
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
354
+ else:
355
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
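+ # Illustrative example: with two reward functions and `reward_weights=[1.0, 0.5]`, the second reward
+ # function contributes half the weight of the first when the per-sample rewards are later combined into
+ # a single reward; with the default above, every reward function is weighted equally.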
356
+
357
+ # Reward processing class
358
+ if reward_processing_classes is None:
359
+ reward_processing_classes = [None] * len(reward_funcs)
360
+ elif not isinstance(reward_processing_classes, list):
361
+ reward_processing_classes = [reward_processing_classes]
362
+ if len(reward_processing_classes) != len(reward_funcs):
363
+ raise ValueError(
364
+ f"The number of reward processing classes ({len(reward_processing_classes)}) must match the number of "
365
+ f"reward functions ({len(reward_funcs)})."
366
+ )
367
+
368
+ for i, (reward_processing_class, reward_func) in enumerate(
369
+ zip(reward_processing_classes, reward_funcs, strict=True)
370
+ ):
371
+ if isinstance(reward_func, PreTrainedModel):
372
+ if reward_processing_class is None:
373
+ reward_processing_class = AutoTokenizer.from_pretrained(get_config_model_id(reward_func.config))
374
+ if reward_processing_class.pad_token_id is None:
375
+ reward_processing_class.pad_token = reward_processing_class.eos_token
376
+ # The reward model computes the reward for the last non-padded token in the input sequence.
377
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
378
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
379
+ reward_processing_classes[i] = reward_processing_class
380
+
381
+ self.reward_processing_classes = reward_processing_classes
382
+
383
+ # Training arguments
384
+ self.max_completion_length = args.max_completion_length
385
+ self.num_generations = args.num_generations
386
+ self.num_generations_eval = args.num_generations_eval or self.num_generations
387
+ self.chat_template_kwargs = args.chat_template_kwargs or {}
388
+ self.temperature = args.temperature
389
+ self.top_p = args.top_p
390
+ self.top_k = args.top_k
391
+ self.min_p = args.min_p
392
+ self.repetition_penalty = args.repetition_penalty
393
+ self.use_transformers_paged = args.use_transformers_paged
394
+ self.pad_to_multiple_of = args.pad_to_multiple_of
395
+ self.use_vllm = args.use_vllm
396
+ self.vllm_mode = args.vllm_mode
397
+ self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
398
+ self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
399
+ self.normalize_advantages = args.normalize_advantages
400
+ self.mask_truncated_completions = args.mask_truncated_completions
401
+ self.reward_clip_range = args.reward_clip_range
402
+
403
+ # Datasets
404
+ self.shuffle_dataset = args.shuffle_dataset
405
+
406
+ if train_dataset is None:
407
+ raise ValueError("`train_dataset` is required")
408
+ elif (
409
+ isinstance(train_dataset, IterableDataset)
410
+ or isinstance(eval_dataset, IterableDataset)
411
+ or (
412
+ isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
413
+ )
414
+ ):
415
+ # See https://github.com/huggingface/trl/issues/3213
416
+ raise NotImplementedError(
417
+ "Iterable datasets are not yet supported in RLOOTrainer. Please use a standard dataset instead."
418
+ )
419
+
420
+ # Multi-step
421
+ self.num_iterations = args.num_iterations
422
+ self.epsilon_low = args.epsilon
423
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
424
+ # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
425
+ self._step = 0
426
+ # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
427
+ # `_get_train_sampler` and `_prepare_inputs`.
428
+ self._buffered_inputs = None
429
+
430
+ # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was
431
+ # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream
432
+ # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we
433
+ # default to the recommended non-reentrant behavior here, while preserving any user-provided value.
434
+ if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"):
435
+ args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
436
+ args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False)
437
+
438
+ super().__init__(
439
+ model=model,
440
+ args=args,
441
+ data_collator=identity, # No data collation is needed in RLOO
442
+ train_dataset=train_dataset,
443
+ eval_dataset=eval_dataset,
444
+ processing_class=processing_class,
445
+ callbacks=callbacks,
446
+ optimizers=optimizers,
447
+ )
448
+
449
+ # Reference model
450
+ self.beta = args.beta
451
+ if self.beta == 0.0:
452
+ # If beta is 0.0, the reference model is not needed
453
+ self.ref_model = None
454
+ elif is_peft_model(model):
455
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
456
+ # to revert to the initial model.
457
+ self.ref_model = None
458
+ else:
459
+ # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
460
+ model_init_kwargs = args.model_init_kwargs or {}
461
+ # Distributed training requires device_map=None ("auto" fails)
462
+ if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
463
+ model_init_kwargs["device_map"] = None
464
+ self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs)
465
+
466
+ # Disable dropout in the models
467
+ if args.disable_dropout:
468
+ disable_dropout_in_model(model)
469
+ if self.ref_model is not None:
470
+ disable_dropout_in_model(self.ref_model)
471
+
472
+ # Initialize the metrics
473
+ self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
474
+ self._total_train_tokens = 0
475
+ self._current_train_step_time = 0.0
476
+ self.log_completions = args.log_completions
477
+ self.log_unique_prompts = args.log_unique_prompts
478
+ self.num_completions_to_print = args.num_completions_to_print
479
+ # Keep logs sized to the generation batch to record only outputs from the latest model update.
480
+ self._logs = {
481
+ "images": deque(maxlen=args.generation_batch_size),
482
+ "prompt": deque(maxlen=args.generation_batch_size),
483
+ "completion": deque(maxlen=args.generation_batch_size),
484
+ "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
485
+ "advantages": deque(maxlen=args.generation_batch_size),
486
+ "extra": defaultdict(lambda: deque(maxlen=args.generation_batch_size)),
487
+ }
488
+ # Buffers for user-logged data from reward functions, flushed after gathering
489
+ self._pending_extra_logs = defaultdict(list)
490
+ self._pending_metrics = defaultdict(list)
491
+
492
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
493
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
494
+ # it's safer to set it in all cases.
495
+ set_seed(args.seed, device_specific=True)
496
+
497
+ if self.use_vllm:
498
+ # Initialize vLLM generation backend
499
+ self.vllm_generation = VLLMGeneration(
500
+ model=self.model,
501
+ accelerator=self.accelerator,
502
+ is_fsdp_enabled=self.is_fsdp_enabled,
503
+ processing_class=self.processing_class,
504
+ # vLLM configuration
505
+ mode=args.vllm_mode,
506
+ structured_outputs_regex=args.vllm_structured_outputs_regex,
507
+ # Server mode configuration
508
+ server_base_url=args.vllm_server_base_url,
509
+ server_host=args.vllm_server_host,
510
+ server_port=args.vllm_server_port,
511
+ group_port=args.vllm_group_port,
512
+ server_timeout=args.vllm_server_timeout,
513
+ # Colocate mode configuration
514
+ tensor_parallel_size=args.vllm_tensor_parallel_size,
515
+ gpu_memory_utilization=args.vllm_gpu_memory_utilization,
516
+ max_model_length=args.vllm_max_model_length,
517
+ max_num_seqs=args.per_device_train_batch_size
518
+ * args.vllm_tensor_parallel_size
519
+ * args.steps_per_generation,
520
+ enable_sleep_mode=args.vllm_enable_sleep_mode,
521
+ model_impl=args.vllm_model_impl,
522
+ # Generation configuration
523
+ repetition_penalty=self.repetition_penalty,
524
+ temperature=self.temperature,
525
+ top_p=self.top_p,
526
+ top_k=self.top_k,
527
+ min_p=self.min_p,
528
+ max_completion_length=self.max_completion_length,
529
+ logprobs=None, # we don't need logprobs from vLLM in RLOO
530
+ generation_kwargs=args.generation_kwargs,
531
+ )
532
+ self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
533
+ else:
534
+ generation_kwargs = {
535
+ "max_new_tokens": self.max_completion_length,
536
+ "do_sample": True,
537
+ "pad_token_id": tokenizer.pad_token_id,
538
+ "bos_token_id": tokenizer.bos_token_id,
539
+ "eos_token_id": tokenizer.eos_token_id,
540
+ "temperature": self.temperature,
541
+ "top_p": self.top_p,
542
+ "top_k": self.top_k,
543
+ "min_p": self.min_p,
544
+ "repetition_penalty": self.repetition_penalty,
545
+ "cache_implementation": args.cache_implementation,
546
+ }
547
+ if args.generation_kwargs is not None:
548
+ generation_kwargs.update(args.generation_kwargs)
549
+ self.generation_config = GenerationConfig(**generation_kwargs, disable_compile=True)
550
+ # Keep training-specific generation kwargs to overwrite model's original generation config
551
+ self.generation_kwargs = generation_kwargs
552
+
553
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
554
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
555
+ # self.model_accepts_loss_kwargs to False to enable scaling.
556
+ self.model_accepts_loss_kwargs = False
557
+
558
+ # Add tags to the model
559
+ self.model.add_model_tags(self._tag_names)
560
+
561
+ if self.ref_model is not None:
562
+ if self.is_deepspeed_enabled:
563
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
564
+ elif self.is_fsdp_enabled:
565
+ self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
566
+ else:
567
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
568
+
569
+ if args.sync_ref_model:
570
+ if self.beta == 0.0:
571
+ raise ValueError(
572
+ "You passed `sync_ref_model=True` while `beta=0.0`, which means the reference model is not used "
573
+ "during training. Consequently, RLOOTrainer does not create a `ref_model` instance, and there is "
574
+ "nothing to synchronize. Please set `sync_ref_model=False`, or set `beta` to a non-zero value."
575
+ )
576
+ if is_peft_model(model):
577
+ raise NotImplementedError(
578
+ "You passed `sync_ref_model=True` while using a PEFT model, which is currently not supported. "
579
+ "With PEFT, RLOOTrainer does not keep a separate reference model in memory; instead, it recovers "
580
+ "reference behavior by temporarily disabling the adapter. As a result, there is no standalone "
581
+ "`ref_model` instance to synchronize. Use `sync_ref_model=False`, or opt for full fine-tuning if "
582
+ "you need a synced reference model. If you need `sync_ref_model` to work with PEFT, please open a "
583
+ "feature request at https://github.com/huggingface/trl/issues."
584
+ )
585
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
586
+
587
+ for i, reward_func in enumerate(self.reward_funcs):
588
+ if isinstance(reward_func, PreTrainedModel):
589
+ if self.is_deepspeed_enabled:
590
+ self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
591
+ else:
592
+ # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
593
+ self.reward_funcs[i] = self.accelerator.prepare_model(
594
+ reward_func, evaluation_mode=True, device_placement=True
595
+ )
596
+
597
+ def _set_signature_columns_if_needed(self):
598
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
599
+ # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
600
+ # and "attention_mask"). In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't
601
+ # work. Instead, we set them to the columns expected by the `training_step` method, hence the override.
602
+ if self._signature_columns is None:
603
+ self._signature_columns = ["prompt", "image", "images"]
604
+
605
+ # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy.
606
+ # Instead of returning a standard per-step batch (i.e., `per_device_batch_size`), our dataloader loads an
607
+ # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions
608
+ # once every steps_per_generation step—rather than once per accumulation step—which is significantly more
609
+ # efficient. The only change from the original implementation is multiplying the batch size by
610
+ # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the
611
+ # splitting internally.
612
+ # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line
613
+ # modification.
614
+ def get_train_dataloader(self):
615
+ return self._get_dataloader(
616
+ dataset=self.train_dataset,
617
+ description="Training",
618
+ batch_size=self._train_batch_size * self.args.steps_per_generation, # < this is the change
619
+ sampler_fn=self._get_train_sampler,
620
+ is_training=True,
621
+ )
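+ # Worked example (illustrative numbers): with per_device_train_batch_size=4 and steps_per_generation=2,
+ # each dataloader step yields 8 samples per device; `_prepare_inputs` then splits this generation batch
+ # back into two per-step batches of 4.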
622
+
623
+ def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
624
+ # Returns a sampler that
625
+ # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
626
+ # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
627
+ # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
628
+ # in group formation.
629
+ # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
630
+ # _prepare_inputs to see how the generations are stored and reused.
631
+
632
+ # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
633
+ # second row shows the second sampled batch, and so on.
634
+ #
635
+ # | GPU 0 | GPU 1 |
636
+ #
637
+ # global_step step <-───> num_generations=2
638
+ # <-───────> per_device_train_batch_size=3
639
+ # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss
640
+ # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss
641
+ # |
642
+ # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss
643
+ # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss
644
+ #
645
+ # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss
646
+ # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss
647
+ # ...
648
+ if dataset is None:
649
+ dataset = self.train_dataset
650
+ return RepeatSampler(
651
+ data_source=dataset,
652
+ mini_repeat_count=self.num_generations,
653
+ batch_size=self.args.generation_batch_size // self.num_generations,
654
+ repeat_count=self.num_iterations * self.args.steps_per_generation,
655
+ shuffle=self.shuffle_dataset,
656
+ seed=self.args.seed,
657
+ )
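+ # Illustrative example: with num_generations=2, generation_batch_size=8, num_iterations=1, and
+ # steps_per_generation=2, the sampler yields indices roughly like [0, 0, 1, 1, 2, 2, 3, 3] and repeats
+ # that block twice, so the stored generations can be reused across the corresponding update steps.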
658
+
659
+ def _get_eval_sampler(self, eval_dataset) -> Sampler:
660
+ # See _get_train_sampler for an explanation of the sampler.
661
+ return RepeatSampler(
662
+ data_source=eval_dataset,
663
+ mini_repeat_count=self.num_generations_eval,
664
+ seed=self.args.seed,
665
+ )
666
+
667
+ @profiling_decorator
668
+ def _get_per_token_logps_and_entropies(
669
+ self,
670
+ model,
671
+ input_ids,
672
+ attention_mask,
673
+ logits_to_keep,
674
+ batch_size=None,
675
+ compute_entropy=False,
676
+ pixel_values=None,
677
+ image_grid_thw=None,
678
+ num_images=None,
679
+ pixel_attention_mask=None,
680
+ image_sizes=None,
681
+ token_type_ids=None,
682
+ mm_token_type_ids=None,
683
+ pixel_position_ids=None,
684
+ ) -> dict[str, torch.Tensor | None]:
685
+ """Compute log-probs and (optionally) entropies for each token."""
686
+ batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
687
+ all_logps = []
688
+ all_entropies = []
689
+ for start in range(0, input_ids.size(0), batch_size):
690
+ input_ids_batch = input_ids[start : start + batch_size]
691
+ attention_mask_batch = attention_mask[start : start + batch_size]
692
+
693
+ # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
694
+ model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
695
+ if image_grid_thw is not None and pixel_values is not None:
696
+ rows_per_image = image_grid_thw.prod(dim=-1)
697
+ rows_per_sample = torch.split(rows_per_image, num_images)
698
+ rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])
699
+ cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)])
700
+ row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item()
701
+ model_inputs["pixel_values"] = pixel_values[row_start:row_end]
702
+ cum_imgs = torch.tensor([0] + num_images).cumsum(0)
703
+ img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
704
+ model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
705
+ elif pixel_values is not None:
706
+ model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
707
+ if pixel_attention_mask is not None:
708
+ model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
709
+ if image_sizes is not None:
710
+ model_inputs["image_sizes"] = image_sizes[start : start + batch_size]
711
+ if token_type_ids is not None:
712
+ model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
713
+ if mm_token_type_ids is not None:
714
+ model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size]
715
+ if pixel_position_ids is not None:
716
+ model_inputs["pixel_position_ids"] = pixel_position_ids[start : start + batch_size]
717
+
718
+ # Only add logits_to_keep if the model supports it
719
+ if "logits_to_keep" in self.model_kwarg_keys:
720
+ # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
721
+ model_inputs["logits_to_keep"] = logits_to_keep + 1
722
+
723
+ model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings
724
+
725
+ logits = model(**model_inputs).logits
726
+ # Exclude the last value: it corresponds to the next token pred
727
+ logits = logits[:, :-1, :] # (B, L-1, H)
728
+ # Only keep the last `logits_to_keep` logits. For models that support `logits_to_keep`, this is a no-op.
729
+ logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
730
+ # Divide logits by sampling temperature.
731
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
732
+ logits.div_(self.temperature)
733
+ completion_ids = input_ids_batch[:, -logits_to_keep:]
734
+ logps = selective_log_softmax(logits, completion_ids) # compute logprobs
735
+ all_logps.append(logps)
736
+
737
+ if compute_entropy:
738
+ with torch.no_grad():
739
+ entropies = entropy_from_logits(logits)
740
+ all_entropies.append(entropies)
741
+
742
+ logps = torch.cat(all_logps, dim=0)
743
+ entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
744
+ return logps, entropies
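+ # Shape sketch (illustrative): for input_ids of shape (B, P+C) and logits_to_keep=C (the completion
+ # length), the returned log-probs have shape (B, C); entropies have the same shape when
+ # compute_entropy=True and are None otherwise.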
745
+
746
+ def training_step(self, model, inputs, num_items_in_batch):
747
+ time_before = time.perf_counter()
748
+ output = super().training_step(model, inputs, num_items_in_batch)
749
+ self._step += 1
750
+ time_after = time.perf_counter()
751
+ self._current_train_step_time += time_after - time_before
752
+ if self._step % self.current_gradient_accumulation_steps == 0:
753
+ self._metrics["train"]["step_time"].append(self._current_train_step_time)
754
+ self._current_train_step_time = 0.0
755
+ return output
756
+
757
+ @profiling_decorator
758
+ def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
759
+ # Prepares inputs for model training/evaluation by managing completion generation and batch handling.
760
+ # During training:
761
+ # - Receives the local generation batch (Per-GPU batch size × steps per generation)
762
+ # from the modified training dataloader instead of the standard local batch
763
+ # - Generates completions once for the entire generation batch and splits it into batches of size
764
+ # `per_device_train_batch_size`
765
+ # - Buffers these completions and returns the appropriate slice for the current accumulation step
766
+ # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations)
767
+ # During evaluation:
768
+ # - The input is treated as a standard local batch (no accumulation, no multiple iterations)
769
+ # - Completions are generated for each batch without buffering or reuse
770
+ # Returns a single local batch in both cases.
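+ # Worked example (illustrative numbers): with steps_per_generation=4 and num_iterations=2,
+ # generate_every=8, so new completions are generated at steps 0, 8, 16, ... and the buffered slices are
+ # reused for the seven steps in between.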
771
+
772
+ mode = "train" if self.model.training else "eval"
773
+ if mode == "train":
774
+ generate_every = self.args.steps_per_generation * self.num_iterations
775
+ if self._step % generate_every == 0 or self._buffered_inputs is None:
776
+ # self._buffered_inputs=None can occur when resuming from a checkpoint
777
+ generation_batch = self._generate_and_score_completions(generation_batch)
778
+ generation_batch = split_pixel_values_by_grid(generation_batch)
779
+ generation_batch = shuffle_sequence_dict(generation_batch)
780
+ generation_batches = split_tensor_dict(generation_batch, self.args.steps_per_generation)
781
+ self._buffered_inputs = [unsplit_pixel_values_by_grid(batch) for batch in generation_batches]
782
+ inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
783
+ else:
784
+ # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
785
+ # local generation batch == local eval batch
786
+ inputs = self._generate_and_score_completions(generation_batch)
787
+ return inputs
788
+
789
+ def _log_completion_extra(self, column: str, values: list):
790
+ """
791
+ Log extra columns to the completions table. Called from reward functions via the `log_extra` kwarg.
792
+
793
+ Args:
794
+ column (`str`):
795
+ Name of the column to add.
796
+ values (`list`):
797
+ Values for the column, one per sample in the batch.
798
+ """
799
+ self._pending_extra_logs[column].extend(values)
800
+
801
+ def _log_metric(self, name: str, value: float):
802
+ """
803
+ Log a scalar metric from a reward function. Called via the `log_metric` kwarg. Values are averaged over each
804
+ logging step and reported alongside built-in metrics like `kl` and `entropy`.
805
+
806
+ Args:
807
+ name (`str`):
808
+ Name of the metric.
809
+ value (`float`):
810
+ Scalar value for this batch.
811
+ """
812
+ self._pending_metrics[name].append(value)
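+ # Usage sketch (hypothetical reward function): both logging hooks arrive in the reward function's kwargs, e.g.
+ #
+ # def my_reward(prompts, completions, completion_ids, log_metric=None, log_extra=None, **kwargs):
+ #     log_metric("my_reward/mean_length", sum(len(c) for c in completion_ids) / len(completion_ids))
+ #     log_extra("judge_notes", ["ok"] * len(completions))
+ #     return [1.0] * len(completions)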
813
+
814
+ @profiling_decorator
815
+ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
816
+ device = self.accelerator.device
817
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
818
+
819
+ # Repeat all input columns (except "prompt", "completion", and "completion_ids") to match the number of generations
820
+ keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
821
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
822
+
823
+ # This allows for dynamic reward shaping based on training progress.
824
+ reward_kwargs["trainer_state"] = self.state
825
+
826
+ # Allow reward functions to log extra columns to the completions table.
827
+ reward_kwargs["log_extra"] = self._log_completion_extra
828
+
829
+ # Allow reward functions to log additional scalar metrics.
830
+ reward_kwargs["log_metric"] = self._log_metric
831
+
832
+ async_funcs_info = [] # async custom functions for asyncio.gather
833
+
834
+ for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
835
+ zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True)
836
+ ):
837
+ if isinstance(reward_func, nn.Module): # Use Module (not PretrainedModel) for compat with compiled models
838
+ with profiling_context(self, reward_func_name):
839
+ if is_conversational(inputs[0]):
840
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
841
+ texts = [
842
+ apply_chat_template(x, reward_processing_class, **self.chat_template_kwargs)["text"]
843
+ for x in messages
844
+ ]
845
+ else:
846
+ texts = [p + c for p, c in zip(prompts, completions, strict=True)]
847
+ reward_inputs = reward_processing_class(
848
+ text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
849
+ )
850
+ reward_inputs = super()._prepare_inputs(reward_inputs)
851
+ with torch.inference_mode():
852
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
853
+ elif inspect.iscoroutinefunction(reward_func): # Separate async reward funcs to run them in parallel later
854
+ async_funcs_info.append((i, reward_func, reward_func_name))
855
+ else:
856
+ # Run synchronous reward function
857
+ with profiling_context(self, reward_func_name):
858
+ output_reward_func = reward_func(
859
+ prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
860
+ )
861
+ # Convert None values to NaN
862
+ output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
863
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
864
+
865
+ # Execute async custom functions in parallel using asyncio.gather
866
+ if async_funcs_info:
867
+
868
+ async def _invoke_async_reward(index, func, func_name):
869
+ with profiling_context(self, func_name):
870
+ output = await func(
871
+ prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
872
+ )
873
+ output = [r if r is not None else torch.nan for r in output]
874
+ return index, output
875
+
876
+ async def _run_async_funcs():
877
+ coros = [_invoke_async_reward(i, func, func_name) for (i, func, func_name) in async_funcs_info]
878
+ return await asyncio.gather(*coros)
879
+
880
+ async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_reward_loop).result()
881
+ for idx, output_reward_func in async_results:
882
+ rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
883
+
884
+ # If all reward functions return None for a given row, issue a detailed warning
885
+ if torch.isnan(rewards_per_func).all(dim=1).any():
886
+ nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
887
+ row_reward_kwargs = {
888
+ key: value[nan_row_idx]
889
+ for key, value in reward_kwargs.items()
890
+ if key not in ("trainer_state", "log_extra", "log_metric")
891
+ }
892
+ row_reward_kwargs["prompt"] = prompts[nan_row_idx]
893
+ row_reward_kwargs["completion"] = completions[nan_row_idx]
894
+ logger.warning(
895
+ f"All reward functions returned None for the following kwargs:\n{row_reward_kwargs}\n"
896
+ "Please ensure that at least one reward function returns a valid reward."
897
+ )
898
+
899
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
900
+ # completions may be distributed across processes
901
+ rewards_per_func = gather(rewards_per_func)
902
+ return rewards_per_func
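+ # Note (illustrative reading of the shapes): after the gather above, `rewards_per_func` stacks all
+ # processes along dim 0, giving shape (num_processes * local_batch_size, num_reward_funcs), with NaN
+ # marking samples for which a reward function returned None.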
903
+
904
+ def _tokenize_prompts(self, prompts: list):
905
+ """Tokenize prompts and extract images/multimodal fields for generation."""
906
+ if is_conversational({"prompt": prompts[0]}):
907
+ # Extract images from messages for VLM support
908
+ images = []
909
+ has_images = False
910
+ for prompt in prompts:
911
+ prompt_images = []
912
+ for message in prompt:
913
+ if isinstance(message["content"], list):
914
+ for part in message["content"]:
915
+ if part["type"] == "image":
916
+ prompt_images.append(part["image"])
917
+ has_images = True
918
+ images.append(prompt_images if prompt_images else None)
919
+ images = images if has_images else None
920
+
921
+ # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors
922
+ # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask.
923
+ # See: https://github.com/huggingface/transformers/issues/44514
924
+ tokenized = self.processing_class.apply_chat_template(
925
+ conversation=prompts,
926
+ add_generation_prompt=True,
927
+ tokenize=True,
928
+ return_dict=True,
929
+ padding=True,
930
+ **self.chat_template_kwargs,
931
+ )
932
+ # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists
933
+ prompt_ids = [
934
+ [tok for tok, m in zip(ids, mask, strict=True) if m]
935
+ for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True)
936
+ ]
937
+ # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.)
938
+ multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")}
939
+ else:
940
+ prompt_ids = self.processing_class(text=prompts)["input_ids"]
941
+ images = None
942
+ multimodal_fields = {}
943
+ return prompt_ids, images, multimodal_fields
944
+
945
+ def _generate_single_turn(self, prompt_ids, images, multimodal_fields):
946
+ device = self.accelerator.device
947
+ mode = "train" if self.model.training else "eval"
948
+
949
+ # Generate completions using either vLLM or regular generation
950
+ if self.use_vllm:
951
+ # Sync weights if training step changed
952
+ if self.state.global_step != self._last_loaded_step:
953
+ with profiling_context(self, "sync_weights"):
954
+ self.vllm_generation.sync_weights()
955
+ self._last_loaded_step = self.state.global_step
956
+
957
+ # Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them)
958
+ num_generations = self.num_generations if mode == "train" else self.num_generations_eval
959
+ _, completion_ids, _, _ = self.vllm_generation.generate(
960
+ prompts=prompt_ids,
961
+ images=images,
962
+ num_generations=num_generations,
963
+ profiler=profiling_context(self, "vLLM.generate"),
964
+ )
965
+
966
+ elif self.use_transformers_paged:
967
+ with (
968
+ profiling_context(self, "transformers.generate_batch"),
969
+ unwrap_model_for_generation(
970
+ self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
971
+ ) as unwrapped_model,
972
+ torch.no_grad(),
973
+ FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
974
+ ):
975
+ # Cast to the appropriate dtype based on training configuration
976
+ if self.args.bf16:
977
+ unwrapped_model.to(torch.bfloat16)
978
+ elif self.args.fp16:
979
+ unwrapped_model.to(torch.float16)
980
+ with torch.inference_mode():
981
+ # Continuous batching API expects 'inputs' arg only
982
+ all_outputs = unwrapped_model.generate_batch(
983
+ prompt_ids, generation_config=self.generation_config, progress_bar=False
984
+ )
985
+ unwrapped_model.train() # restore training mode, as generate_batch forces eval mode
986
+ completion_ids = [output.generated_tokens for output in all_outputs.values()]
987
+
988
+ else:
989
+ # Regular generation path: left-pad token IDs into tensors
990
+ prompt_tensors = [torch.tensor(ids) for ids in prompt_ids]
991
+ padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left")
992
+ attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left")
993
+ generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask}
994
+ # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.)
995
+ for k, v in multimodal_fields.items():
996
+ if isinstance(v, torch.Tensor):
997
+ generate_inputs[k] = v
998
+ elif isinstance(v, list) and v and isinstance(v[0], list):
999
+ # Per-token field (e.g., token_type_ids): left-pad like input_ids
1000
+ generate_inputs[k] = pad([torch.tensor(x) for x in v], padding_value=0, padding_side="left")
1001
+ else:
1002
+ generate_inputs[k] = torch.tensor(np.array(v))
1003
+ generate_inputs = super()._prepare_inputs(generate_inputs)
1004
+
1005
+ with (
1006
+ profiling_context(self, "transformers.generate"),
1007
+ unwrap_model_for_generation(
1008
+ self.model_wrapped,
1009
+ self.accelerator,
1010
+ gather_deepspeed3_params=self.args.ds3_gather_for_generation,
1011
+ generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762
1012
+ ) as unwrapped_model,
1013
+ torch.no_grad(),
1014
+ FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
1015
+ ):
1016
+ prompt_completion_ids = unwrapped_model.generate(
1017
+ **generate_inputs, generation_config=self.generation_config
1018
+ )
1019
+ # Compute prompt length and extract completion ids
1020
+ prompt_length = generate_inputs["input_ids"].size(1)
1021
+ completion_ids = prompt_completion_ids[:, prompt_length:]
1022
+
1023
+ # Mask everything after the first EOS token
1024
+ is_eos = completion_ids == self.eos_token_id
1025
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
1026
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
1027
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
1028
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
1029
+ completion_ids = [
1030
+ c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True)
1031
+ ]
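+ # Worked example (illustrative): for a generated row [5, 9, eos, 7, 7], eos_idx is 2, so the mask keeps
+ # positions 0-2 and the completion becomes [5, 9, eos]; rows with no EOS token are kept in full.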
1032
+
1033
+ return completion_ids
1034
+
1035
+ def _generate(self, prompts: list):
1036
+ device = self.accelerator.device
1037
+ mode = "train" if self.model.training else "eval"
1038
+
1039
+ # Copy the prompts to avoid modifying the original list
1040
+ prompts = copy.deepcopy(prompts)
1041
+
1042
+ prompt_ids, images, multimodal_fields = self._tokenize_prompts(prompts)
1043
+ completion_ids = self._generate_single_turn(prompt_ids, images, multimodal_fields)
1044
+
1045
+ # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls.
1046
+ if is_conversational({"prompt": prompts[0]}):
1047
+ contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1048
+ completions = [[{"role": "assistant", "content": content}] for content in contents]
1049
+ else:
1050
+ completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1051
+
1052
+ # Get completion length per sequence, used for logging
1053
+ prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
1054
+ completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
1055
+ agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
1056
+ agg_completion_lengths = self.accelerator.gather(completion_lengths)
1057
+ total_prompt_tokens = agg_prompt_lengths.sum()
1058
+ total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss
1059
+
1060
+ # Log the metrics
1061
+ if mode == "train":
1062
+ self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item()
1063
+ self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
1064
+
1065
+ # Log completion lengths, mean, min, max
1066
+ self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
1067
+ self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
1068
+ self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
1069
+
1070
+ # Identify sequences that terminated with EOS and log their lengths
1071
+ eos_and_pad = [self.eos_token_id, self.pad_token_id]
1072
+ is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
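+ # A completion whose last token is EOS (or pad, for safety) terminated on its own; any other final
+ # token means generation was cut off at the length limit, which is what clipped_ratio tracks.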
1073
+ agg_is_truncated = self.accelerator.gather(is_truncated)
1074
+ self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
1075
+ term_completion_lengths = agg_completion_lengths[~agg_is_truncated]
1076
+ if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found
1077
+ term_completion_lengths = torch.zeros(1, device=device)
1078
+ self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
1079
+ self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
1080
+ self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
1081
+
1082
+ return prompt_ids, completion_ids, completions
1083
+
1084
+ def _generate_and_score_completions(
1085
+ self, inputs: list[dict[str, torch.Tensor | Any]]
1086
+ ) -> dict[str, torch.Tensor | Any]:
1087
+ device = self.accelerator.device
1088
+ mode = "train" if self.model.training else "eval"
1089
+
1090
+ prompts = [x["prompt"] for x in inputs]
1091
+
1092
+ if "images" in inputs[0]:
1093
+ images = [example.get("images") for example in inputs]
1094
+ elif "image" in inputs[0]:
1095
+ images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
1096
+ else:
1097
+ images = None
1098
+ # Transformers requires at least one image in the batch, otherwise it throws an error
1099
+ if images is not None and all(img_list == [] for img_list in images):
1100
+ images = None
1101
+
1102
+ # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
1103
+ # [{"role": "user", "content": "What color is the sky?"}] to
1104
+ # [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
1105
+ if images is not None:
1106
+ if not is_conversational(inputs[0]):
1107
+ raise ValueError(
1108
+ "Multimodal training requires conversational prompts. It looks like the dataset contains "
1109
+ "non-conversational inputs, likely because a chat template was applied before passing the dataset "
1110
+ "to the trainer. Please provide the raw conversational prompts and let the trainer apply the chat "
1111
+ "template internally."
1112
+ )
1113
+ prompts = [
1114
+ prepare_multimodal_messages(prompt, image_list)
1115
+ for prompt, image_list in zip(prompts, images, strict=True)
1116
+ ]
1117
+
1118
+ prompt_ids_list, completion_ids_list, completions = self._generate(prompts)
1119
+
1120
+ # Convert lists of token IDs to padded tensors
1121
+ prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]
1122
+ prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
1123
+ prompt_ids = pad(
1124
+ prompt_ids,
1125
+ padding_value=self.pad_token_id,
1126
+ padding_side="left",
1127
+ pad_to_multiple_of=self.pad_to_multiple_of,
1128
+ ).to(device=device)
1129
+ prompt_mask = pad(
1130
+ prompt_mask, padding_value=0, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of
1131
+ ).to(device=device)
1132
+ completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
1133
+ completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
1134
+ completion_ids = pad(
1135
+ completion_ids,
1136
+ padding_value=self.pad_token_id,
1137
+ padding_side="right",
1138
+ pad_to_multiple_of=self.pad_to_multiple_of,
1139
+ ).to(device=device)
1140
+ completion_mask = pad(
1141
+ completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
1142
+ ).to(device=device)
1143
+
1144
+ # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
1145
+ if self.mask_truncated_completions:
1146
+ eos_and_pad = [self.eos_token_id, self.pad_token_id]
1147
+ # Zero out completion_mask for truncated sequences so they are excluded from attention and the loss
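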
1148
+ is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
1149
+ completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
1150
+
1151
+ # Concatenate prompt_mask with completion_mask for logit computation
1152
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
1153
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
1154
+
1155
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
1156
+ batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
1157
+
1158
+ num_images = [len(img_list) for img_list in images] if images is not None else None
1159
+
1160
+ # Get forward_kwargs for models with multimodal inputs
1161
+ if images is not None:
1162
+ prompts_text = [
1163
+ apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
1164
+ for prompt in prompts
1165
+ ]
1166
+ prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
1167
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
1168
+ forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
1169
+ else:
1170
+ forward_kwargs = {}
1171
+
1172
+ # If token_type_ids are used, extend them with zeros for the completion part
1173
+ if "token_type_ids" in forward_kwargs:
1174
+ token_type_ids = forward_kwargs["token_type_ids"]
1175
+ if self.pad_to_multiple_of is not None:
1176
+ # Only needed with pad_to_multiple_of: without it, prompt_ids and token_type_ids already have the same length
1177
+ padding_size = prompt_ids.size(1) - token_type_ids.size(1)
1178
+ if padding_size > 0:
1179
+ token_type_ids = torch.cat(
1180
+ [token_type_ids.new_zeros((token_type_ids.size(0), padding_size)), token_type_ids], dim=1
1181
+ )
1182
+ forward_kwargs["token_type_ids"] = torch.cat(
1183
+ [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
1184
+ )
1185
+ # If mm_token_type_ids are used, extend them with zeros for the completion part
1186
+ if "mm_token_type_ids" in forward_kwargs:
1187
+ mm_token_type_ids = forward_kwargs["mm_token_type_ids"]
1188
+ if self.pad_to_multiple_of is not None:
1189
+ # Only needed with pad_to_multiple_of: without it, prompt_ids and mm_token_type_ids already have the same length
1190
+ padding_size = prompt_ids.size(1) - mm_token_type_ids.size(1)
1191
+ if padding_size > 0:
1192
+ mm_token_type_ids = torch.cat(
1193
+ [mm_token_type_ids.new_zeros((mm_token_type_ids.size(0), padding_size)), mm_token_type_ids],
1194
+ dim=1,
1195
+ )
1196
+ forward_kwargs["mm_token_type_ids"] = torch.cat(
1197
+ [mm_token_type_ids, mm_token_type_ids.new_zeros(completion_ids.shape)], dim=1
1198
+ )
1199
+
1200
+ # When gradient checkpointing is enabled with use_reentrant=True (non-default), calling the model inside a
1201
+ # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True").
1202
+ # Temporarily disable checkpointing to avoid this warning during inference.
1203
+ with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
1204
+ # Compute the per-token log probabilities for the current model
1205
+ old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
1206
+ self.model,
1207
+ prompt_completion_ids,
1208
+ attention_mask,
1209
+ logits_to_keep,
1210
+ batch_size,
1211
+ num_images=num_images,
1212
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
1213
+ )
1214
+ old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS
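+ # Summing the masked per-token logps gives the sequence log-probability log pi_old(completion | prompt),
+ # later used in _compute_loss to form the sequence-level importance ratio exp(logps - old_logps).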
1215
+
1216
+ # Compute the per-token log probabilities for the reference model
1217
+ if self.beta != 0.0:
1218
+ if self.ref_model is not None:
1219
+ ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
1220
+ self.ref_model,
1221
+ prompt_completion_ids,
1222
+ attention_mask,
1223
+ logits_to_keep,
1224
+ batch_size=batch_size,
1225
+ num_images=num_images,
1226
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
1227
+ )
1228
+ else:
1229
+ # When training a PEFT adapter, how we obtain the reference depends on the setup:
1230
+ # - New adapter: disabling adapters yields the base model.
1231
+ # - Re-training an existing adapter: an initial copy is loaded under the name "ref".
1232
+ model = self.accelerator.unwrap_model(self.model)
1233
+ with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None):
1234
+ ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
1235
+ self.model,
1236
+ prompt_completion_ids,
1237
+ attention_mask,
1238
+ logits_to_keep,
1239
+ batch_size=batch_size,
1240
+ num_images=num_images,
1241
+ **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids
1242
+ )
1243
+ else:
1244
+ ref_per_token_logps = None
1245
+
1246
+ # Decode
1247
+ prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True)
1248
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
1249
+
1250
+ # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
1251
+ # important because rewards will be normalized per group, and completions are distributed. We will later slice
1252
+ # rewards_per_func to extract each process's subset.
1253
+ rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list)
1254
+ num_generations = self.num_generations if mode == "train" else self.num_generations_eval
1255
+
1256
+ # Apply weights to each reward function's output and sum
1257
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
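+ # nansum treats NaN (a reward function that did not apply to a sample) as 0. For example, with
+ # reward_weights [1.0, 0.5] and per-function rewards [0.8, NaN], the combined reward is 0.8.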
1258
+
1259
+ # Apply reward clipping if specified
1260
+ if self.reward_clip_range:
1261
+ rewards = rewards.clamp(min=self.reward_clip_range[0], max=self.reward_clip_range[1])
1262
+
1263
+ # Include the KL penalty in the reward
1264
+ if self.beta != 0.0:
1265
+ per_token_kl = old_per_token_logps - ref_per_token_logps
1266
+ # Apply sequence-level KL penalty to rewards (sum KL across tokens first, then apply to each sequence)
1267
+ kl = (per_token_kl * completion_mask).sum(-1)
1268
+ kl = gather(kl) # rewards are gathered, so kl must be too
1269
+ rewards = rewards - self.beta * kl
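+ # per_token_kl is a single-sample estimate of the divergence between the sampling policy and the
+ # reference model (log pi_old - log pi_ref). Summed over the completion, it is folded into the reward
+ # as reward - beta * KL, so the KL penalty acts through the reward rather than as a separate loss term.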
1270
+
1271
+ grouped_rewards = rewards.view(-1, num_generations)
1272
+ mean_grouped_rewards = grouped_rewards.mean(dim=1)
1273
+ if num_generations > 1:
1274
+ std_rewards = grouped_rewards.std(dim=1)
1275
+ else: # doesn't occur during training, but could occur in eval when num_generations_eval=1
1276
+ std_rewards = torch.zeros_like(mean_grouped_rewards)
1277
+
1278
+ # RLOO advantages computation
1279
+ grouped_sum = grouped_rewards.sum(dim=1, keepdim=True) # (num_prompts, 1)
1280
+ if num_generations > 1:
1281
+ baselines = (grouped_sum - grouped_rewards) / (num_generations - 1) # (num_prompts, num_generations)
1282
+ baselines = baselines.view(-1) # Flatten back to match rewards shape
1283
+ advantages = rewards - baselines
1284
+ else: # this case doesn't occur during training, but could in eval when num_generations_eval=1
1285
+ advantages = torch.zeros_like(rewards)
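+ # Each sample's baseline is the mean reward of the other (num_generations - 1) completions for the
+ # same prompt. Example with num_generations = 3 and group rewards [1.0, 0.5, 0.0]:
+ # baselines = [(0.5 + 0.0) / 2, (1.0 + 0.0) / 2, (1.0 + 0.5) / 2] = [0.25, 0.50, 0.75]
+ # advantages = [0.75, 0.00, -0.75]  (they always sum to zero within a group)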
1286
+
1287
+ # Normalize advantages
1288
+ if self.normalize_advantages:
1289
+ advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-4)
1290
+
1291
+ is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) # for logging
1292
+
1293
+ # Slice to keep only the local part of the data
1294
+ process_slice = slice(
1295
+ self.accelerator.process_index * len(prompts),
1296
+ (self.accelerator.process_index + 1) * len(prompts),
1297
+ )
1298
+ all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
1299
+ advantages = advantages[process_slice]
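+ # Example: with 2 processes and len(prompts) = 8 per process, the gathered advantages have 16 rows;
+ # rank 0 keeps rows 0..7 and rank 1 keeps rows 8..15, matching each rank's local completions.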
1300
+
1301
+ # Calculate and log the mean KL divergence between current and reference model
1302
+ if self.beta != 0.0:
1303
+ mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
1304
+ self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
1305
+
1306
+ # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
1307
+ for i, reward_func_name in enumerate(self.reward_func_names):
1308
+ mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
1309
+ self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
1310
+ std_func_rewards = nanstd(rewards_per_func[:, i]).item()
1311
+ self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
1312
+ rewards = (rewards_per_func * self.reward_weights.to(rewards_per_func.device).unsqueeze(0)).nansum(dim=1)
1313
+ self._metrics[mode]["reward"].append(rewards.mean().item())
1314
+ self._metrics[mode]["reward_std"].append(rewards.std().item())
1315
+ self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
1316
+
1317
+ # Log prompt and completion texts
1318
+ self._logs["prompt"].extend(gather_object(prompts_text))
1319
+ self._logs["completion"].extend(gather_object(completions_text))
1320
+ for i, name in enumerate(self.reward_func_names):
1321
+ self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
1322
+ self._logs["advantages"].extend(all_process_advantages.tolist())
1323
+
1324
+ # Flush user-logged extra columns (from log_extra), gathering across processes.
1325
+ # Keys must be sorted so that all ranks call gather_object in the same order, otherwise values
1326
+ # get mis-attributed across columns (dict insertion order may differ between processes).
1327
+ for column in sorted(self._pending_extra_logs):
1328
+ self._logs["extra"][column].extend(gather_object(self._pending_extra_logs[column]))
1329
+ self._pending_extra_logs.clear()
1330
+
1331
+ # Flush user-logged metrics (from log_metric), averaging across processes.
1332
+ # Keys must be sorted so that all ranks call accelerator.gather in the same order, otherwise values
1333
+ # get mis-attributed across metrics (dict insertion order may differ between processes).
1334
+ for name in sorted(self._pending_metrics):
1335
+ values = self._pending_metrics[name]
1336
+ local_mean = sum(values) / len(values)
1337
+ global_mean = self.accelerator.gather(torch.tensor(local_mean, device=device)).mean().item()
1338
+ self._metrics[mode][name].append(global_mean)
1339
+ self._pending_metrics.clear()
1340
+
1341
+ if images is not None:
1342
+ self._logs["images"].extend(gather_object(images))
1343
+
1344
+ output = {
1345
+ "prompt_ids": prompt_ids,
1346
+ "prompt_mask": prompt_mask,
1347
+ "completion_ids": completion_ids,
1348
+ "completion_mask": completion_mask,
1349
+ "old_logps": old_logps,
1350
+ "advantages": advantages,
1351
+ }
1352
+ if "pixel_values" in forward_kwargs:
1353
+ output["pixel_values"] = forward_kwargs["pixel_values"]
1354
+ if "image_grid_thw" in forward_kwargs:
1355
+ output["image_grid_thw"] = forward_kwargs["image_grid_thw"]
1356
+ if "pixel_attention_mask" in forward_kwargs:
1357
+ output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"]
1358
+ if "image_sizes" in forward_kwargs:
1359
+ output["image_sizes"] = forward_kwargs["image_sizes"]
1360
+ if "token_type_ids" in forward_kwargs:
1361
+ output["token_type_ids"] = forward_kwargs["token_type_ids"]
1362
+ if "mm_token_type_ids" in forward_kwargs:
1363
+ output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
1364
+ if "pixel_position_ids" in forward_kwargs:
1365
+ output["pixel_position_ids"] = forward_kwargs["pixel_position_ids"]
1366
+ if images is not None:
1367
+ output["num_images"] = num_images
1368
+ return output
1369
+
1370
+ @profiling_decorator
1371
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
1372
+ if return_outputs:
1373
+ raise ValueError("The RLOOTrainer does not support returning outputs")
1374
+ return self._compute_loss(model, inputs)
1375
+
1376
+ def _compute_loss(self, model, inputs):
1377
+ # Compute the per-token log probabilities for the model
1378
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
1379
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
1380
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
1381
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
1382
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
1383
+
1384
+ # Compute the per_token_logps and the entropy at each position in the completion
1385
+ per_token_logps, entropies = self._get_per_token_logps_and_entropies(
1386
+ model,
1387
+ input_ids,
1388
+ attention_mask,
1389
+ logits_to_keep,
1390
+ compute_entropy=True,
1391
+ pixel_values=inputs.get("pixel_values"),
1392
+ image_grid_thw=inputs.get("image_grid_thw"),
1393
+ num_images=inputs.get("num_images"),
1394
+ pixel_attention_mask=inputs.get("pixel_attention_mask"),
1395
+ image_sizes=inputs.get("image_sizes"),
1396
+ token_type_ids=inputs.get("token_type_ids"),
1397
+ mm_token_type_ids=inputs.get("mm_token_type_ids"),
1398
+ pixel_position_ids=inputs.get("pixel_position_ids"),
1399
+ )
1400
+
1401
+ logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS
1402
+ old_logps = inputs["old_logps"]
1403
+ log_ratio = logps - old_logps
1404
+
1405
+ # Compute the loss
1406
+ advantages = inputs["advantages"]
1407
+ coef_1 = torch.exp(log_ratio)
1408
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
1409
+ per_sequence_loss1 = coef_1 * advantages
1410
+ per_sequence_loss2 = coef_2 * advantages
1411
+ per_sequence_loss = -torch.min(per_sequence_loss1, per_sequence_loss2)
1412
+ loss = per_sequence_loss.mean()
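+ # PPO-style clipped surrogate applied at the sequence level: coef_1 is the importance ratio
+ # pi_theta / pi_old for the whole completion. For example, with epsilon_low = epsilon_high = 0.2
+ # (illustrative values), a ratio of 1.5 on a positive advantage is capped at 1.2, so the update gains
+ # nothing from pushing the ratio further outside the trust region.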
1413
+
1414
+ # Log the metrics
1415
+ mode = "train" if self.model.training else "eval"
1416
+
1417
+ # Entropy
1418
+ mean_entropy = (entropies * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
1419
+ self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
1420
+
1421
+ # Compute the clipped probability ratios
1422
+ is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0)
1423
+ is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0)
1424
+ is_region_clipped = is_low_clipped | is_high_clipped
1425
+ gathered_low_clip = self.accelerator.gather(is_low_clipped.float().mean())
1426
+ self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
1427
+ self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
1428
+ gathered_high_clip = self.accelerator.gather(is_high_clipped.float().mean())
1429
+ self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
1430
+ self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
1431
+ gathered_clip_ratio = self.accelerator.gather(is_region_clipped.float().mean())
1432
+ self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
1433
+ return loss
1434
+
1435
+ # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and
1436
+ # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels.
1437
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
1438
+ inputs = self._prepare_inputs(inputs)
1439
+ with torch.no_grad():
1440
+ with self.compute_loss_context_manager():
1441
+ loss = self.compute_loss(model, inputs)
1442
+ loss = loss.mean().detach()
1443
+ return loss, None, None
1444
+
1445
+ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
1446
+ mode = "train" if self.model.training else "eval"
1447
+ # Average the metrics
1448
+ metrics = {}
1449
+ for key, val in self._metrics[mode].items():
1450
+ # Filter out NaN values before averaging. A reward function that returns None for all samples
1451
+ # in a batch produces NaN for that batch's metric. With logging_steps > 1, a naive sum()/len()
1452
+ # would let a single NaN contaminate valid data from other batches. Only return None when no
1453
+ # valid values remain (e.g. JSON loggers crash on float NaN).
1454
+ valid = [v for v in val if not math.isnan(v)]
1455
+ metrics[key] = sum(valid) / len(valid) if valid else None
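+ # For example, val = [0.5, nan, 0.7] averages to 0.6: the NaN batch is skipped rather than
+ # poisoning the logged mean.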
1456
+
1457
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1458
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1459
+ if mode == "eval":
1460
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
1461
+
1462
+ logs = {**logs, **metrics}
1463
+ super().log(logs, start_time)
1464
+ self._metrics[mode].clear()
1465
+
1466
+ if self.accelerator.is_main_process and self.log_completions:
1467
+ if is_rich_available():
1468
+ print_prompt_completions_sample(
1469
+ self._logs["prompt"],
1470
+ self._logs["completion"],
1471
+ self._logs["rewards"],
1472
+ self._logs["advantages"],
1473
+ self.state.global_step,
1474
+ self.num_completions_to_print,
1475
+ )
1476
+
1477
+ logging_backends = []
1478
+ if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
1479
+ logging_backends.append(wandb)
1480
+ if self.args.report_to and "trackio" in self.args.report_to:
1481
+ logging_backends.append(trackio)
1482
+
1483
+ table = {
1484
+ "step": [self.state.global_step] * len(self._logs["prompt"]),
1485
+ "prompt": self._logs["prompt"],
1486
+ "completion": self._logs["completion"],
1487
+ **self._logs["rewards"],
1488
+ **self._logs["extra"],
1489
+ "advantage": self._logs["advantages"],
1490
+ }
1491
+
1492
+ df_base = pd.DataFrame(table)
1493
+ images_raw = self._logs["images"] or []
1494
+
1495
+ for logging_backend in logging_backends:
1496
+ if images_raw:
1497
+ images = []
1498
+ for image_list in self._logs["images"]:
1499
+ images.append([logging_backend.Image(image) for image in image_list])
1500
+ df = pd.concat(
1501
+ [df_base, pd.Series(images, name="image")],
1502
+ axis=1,
1503
+ copy=False,
1504
+ )
1505
+ else:
1506
+ df = df_base
1507
+
1508
+ if self.log_unique_prompts:
1509
+ df = df.drop_duplicates(subset=["prompt"])
1510
+
1511
+ logging_backend.log({"completions": logging_backend.Table(dataframe=df)})
1512
+
1513
+ # Ensure the model card is saved along with the checkpoint
1514
+ def _save_checkpoint(self, model, trial):
1515
+ if self.args.hub_model_id is None:
1516
+ model_name = Path(self.args.output_dir).name
1517
+ else:
1518
+ model_name = self.args.hub_model_id.split("/")[-1]
1519
+ self.create_model_card(model_name=model_name)
1520
+ super()._save_checkpoint(model, trial)