diff --git a/ICL/DAPO/verl-recipe/char_count/README.md b/ICL/DAPO/verl-recipe/char_count/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f86fa195484597991d1a6db949599e857abe2c8a --- /dev/null +++ b/ICL/DAPO/verl-recipe/char_count/README.md @@ -0,0 +1,59 @@ +# Char Count +## Introduction +Char count is a simple NLP task. We create it for beginners to grasp the idea of RLVR. The task can be trained using a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB. + +## Problem formulation +The prompt is: "How many {char} are there in {word}?". In order for LLM to better answer this question, we create SFT dataset with intermediate steps. For example, + +```text +Question: How many n are there in n-i-n-e? +Answer: +n = n +i != n +n = n +e != n +\boxed{2} +``` + +Note that +- We add a dash between each individual char to make the task easier because each individual char will be tokenized to the same token by most tokenizer. +- In the SFT dataset, we create a CoT by listing all the individual chars and whether it equals to the target. In the end, it outputs the final answer inside the box. +- The task can be verified. +- The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range. + +## Scripts +Installation + +```bash +pip install verl==0.6.1 +``` + + +To create the dataset, run +```bash +python3 create_dataset.py +``` +We create a train set and a val set. Both of them are used of SFT and RL. You can specify the total number of data, min/max length and data path. + +To run the SFT +```bash +BACKEND=fsdp bash train_sft.sh # use fsdp +BACKEND=megatron bash train_sft.sh # use megatron +``` +We train SFT for 1 epoch. After 1 epoch, the validation score is around 0.435. + +Merge checkpoint trained from SFT +```bash +# sft +export CKPT_PATH=$HOME/experiments/char_count/models/sft/fsdp/global_step_140 +python3 -m verl.model_merger merge --backend fsdp --local_dir $CKPT_PATH --target_dir $CKPT_PATH/huggingface/ +# megatron +export CKPT_PATH=$HOME/experiments/char_count/models/sft/megatron/global_step_140 +python3 -m verl.model_merger merge --backend megatron --local_dir $CKPT_PATH --target_dir $CKPT_PATH/huggingface/ +``` + +To run GRPO +```bash +bash train_grpo.sh +``` +We train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.6. diff --git a/ICL/DAPO/verl-recipe/char_count/create_dataset.py b/ICL/DAPO/verl-recipe/char_count/create_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..754dacde6034256ca415c61713841b5120bc9640 --- /dev/null +++ b/ICL/DAPO/verl-recipe/char_count/create_dataset.py @@ -0,0 +1,198 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Task description: +Given a random word and a random char, count the number of occurrence of char in the word. + +Create CoT dataset that split the word into separate char. Then list the char and count the occurrence. + +The word set comes from shakespeare +""" + +import os.path +import random + +prompt_template = "How many {} are there in word {}?" + + +def generate_random_char(): + return chr(97 + random.randint(0, 25)) + + +def create_prompt_response(min_length=3, max_length=5): + # randomly generate a length + word_length = random.randint(min_length, max_length) + # randomly generate a target count number. This makes the target number + target_count_number = random.randint(1, word_length) + + char_lst = [] + # generate the word + # step 1: generate the target word + target_char = generate_random_char() + + for _ in range(target_count_number): + char_lst.append(target_char) + + # step 2: generate other words + for _ in range(word_length - target_count_number): + while True: + char = generate_random_char() + if char != target_char: + char_lst.append(char) + break + + # step 3: random permute char_lst + random.shuffle(char_lst) + + word = "-".join(char_lst) + + prompt = prompt_template.format(target_char, word) + final_answer = [] + + # cot + number = 0 + for i, char in enumerate(char_lst): + cot = f"{char}" + if char != target_char: + cot += " != " + else: + cot += " = " + number += 1 + cot += f"{target_char}." + + final_answer.append(cot) + + conclusion = f"\\boxed{{{number}}} {target_char} in {word}." + + final_answer.append(conclusion) + + final_answer = "\n".join(final_answer) + + return prompt, final_answer + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--total_number", type=int, default=10000) + parser.add_argument("--min_length", type=int, default=5) + parser.add_argument("--max_length", type=int, default=20) + parser.add_argument("--data_path", type=str, default="~/data/char_count") + + args = vars(parser.parse_args()) + + total_number = args["total_number"] + min_length = args["min_length"] + max_length = args["max_length"] + data_path = args["data_path"] + data_path = os.path.expanduser(data_path) + + full_output = [] + for _ in range(total_number): + output = create_prompt_response(min_length=min_length, max_length=max_length) + full_output.append(output) + + # random reorder + random.shuffle(full_output) + + # split for train and test + train_split_len = int(0.9 * len(full_output)) + train_outputs = full_output[:train_split_len] + test_output = full_output[train_split_len:] + + sft_train_dataset = {"messages": []} + + for o in train_outputs: + messages = [ + {"role": "user", "content": o[0]}, + {"role": "assistant", "content": o[1]}, + ] + + sft_train_dataset["messages"].append(messages) + + sft_test_dataset = {"messages": []} + + for o in test_output: + messages = [ + {"role": "user", "content": o[0]}, + {"role": "assistant", "content": o[1]}, + ] + sft_test_dataset["messages"].append(messages) + + import pandas as pd + + sft_train_dataset = pd.DataFrame(data=sft_train_dataset) + sft_test_dataset = pd.DataFrame(data=sft_test_dataset) + + folder = os.path.join(data_path, "sft") + + os.makedirs(folder, exist_ok=True) + + sft_train_dataset.to_parquet(os.path.join(folder, "train.parquet")) + sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) + + # build RL dataset + rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} + + rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} + + from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed + + for o in train_outputs: + prompt = o[0] + response = o[1] + prompt_with_template = [ + { + "role": "user", + "content": prompt, + } + ] + + rl_train_dataset["prompt"].append(prompt_with_template) + rl_train_dataset["data_source"].append("char_count") + rl_train_dataset["ability"].append("other") + rl_train_dataset["reward_model"].append( + {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} + ) + rl_train_dataset["extra_info"].append({"response": response}) + + for o in test_output: + prompt = o[0] + response = o[1] + prompt_with_template = [ + { + "role": "user", + "content": prompt, + } + ] + + rl_test_dataset["prompt"].append(prompt_with_template) + rl_test_dataset["data_source"].append("char_count") + rl_test_dataset["ability"].append("other") + rl_test_dataset["reward_model"].append( + {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} + ) + rl_test_dataset["extra_info"].append({"response": response}) + + rl_train_dataset = pd.DataFrame(data=rl_train_dataset) + rl_test_dataset = pd.DataFrame(data=rl_test_dataset) + + folder = os.path.join(data_path, "rl") + + os.makedirs(folder, exist_ok=True) + + rl_train_dataset.to_parquet(os.path.join(folder, "train.parquet")) + rl_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) diff --git a/ICL/DAPO/verl-recipe/char_count/reward_function.py b/ICL/DAPO/verl-recipe/char_count/reward_function.py new file mode 100644 index 0000000000000000000000000000000000000000..7c87ea49a1b105a4e1035f5ee07b9eb19384f38a --- /dev/null +++ b/ICL/DAPO/verl-recipe/char_count/reward_function.py @@ -0,0 +1,34 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Reward function +""" + +from verl.utils.reward_score import math_reward + + +def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None): + try: + last_boxed_string = math_reward.last_boxed_only_string(solution_str) + if last_boxed_string is None: + return 0 + solution = math_reward.remove_boxed(last_boxed_string) + if solution == ground_truth: + return 1 + else: + return 0 + except Exception: + print(ground_truth, solution_str) + return 0 diff --git a/ICL/DAPO/verl-recipe/char_count/train_grpo.sh b/ICL/DAPO/verl-recipe/char_count/train_grpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..e1aa43e25b276f24bebc420ed663327e5d1976ba --- /dev/null +++ b/ICL/DAPO/verl-recipe/char_count/train_grpo.sh @@ -0,0 +1,45 @@ +set -x + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/char_count/rl/train.parquet \ + data.val_files=$HOME/data/char_count/rl/test.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=128 \ + data.max_response_length=128 \ + data.filter_overlong_prompts=False \ + data.truncation='error' \ + actor_rollout_ref.model.path=$HOME/experiments/char_count/models/sft/megatron/global_step_140/huggingface \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","tensorboard"]' \ + trainer.project_name='verl_example' \ + trainer.experiment_name='smol135m_grpo-1128a1' \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.use_legacy_worker_impl=disable \ + custom_reward_function.path=./reward_function.py \ + custom_reward_function.name=char_count_reward_function diff --git a/ICL/DAPO/verl-recipe/char_count/train_sft.sh b/ICL/DAPO/verl-recipe/char_count/train_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..4aacea828d4f6c6357374d49a1d68cf87c678024 --- /dev/null +++ b/ICL/DAPO/verl-recipe/char_count/train_sft.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"} + +TRAIN_FILES=${TRAIN_FILES:-$HOME/data/char_count/sft/train.parquet} +TEST_FILES=${TEST_FILES:-$HOME/data/char_count/sft/test.parquet} + +backend=${BACKEND:-fsdp} + +project_name=char_count-sft + +RESUME_MODE=auto +MODEL_ID=${MODEL_ID:-HuggingFaceTB/SmolLM2-135M-Instruct} + +SP_SIZE=${SP_SIZE:-1} +FSDP_SIZE=${FSDP_SIZE:-1} +FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"} + +TP_SIZE=${TP_SIZE:-1} +PP_SIZE=${PP_SIZE:-1} +VPP_SIZE=${VPP_SIZE:-null} +CP_SIZE=${CP_SIZE:-1} + +PAD_MODE=${PAD_MODE:-no_padding} + +USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} + +FSDP_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=2e-5 \ + optim.lr_warmup_steps_ratio=0.01 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_ratio=0.1 \ + optim.warmup_style=cosine \ + engine.ulysses_sequence_parallel_size=${SP_SIZE} \ + engine.strategy=${FSDP_STRATEGY} \ + engine.fsdp_size=${FSDP_SIZE}" + + +MEGATRON_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=2e-5 \ + optim.lr_warmup_steps_ratio=0.01 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.lr_warmup_init=0 \ + optim.lr_decay_style=cosine \ + optim.min_lr=2e-6 \ + engine.tensor_model_parallel_size=${TP_SIZE} \ + engine.pipeline_model_parallel_size=${PP_SIZE} \ + engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.use_mbridge=False" + +if [ "$backend" = "fsdp" ]; then + ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" + echo "Using fsdp engine" + exp_name=char_count-sft-SmolLM2-135M-Instruct-fsdp +else + ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" + echo "Using megatron engine" + exp_name=char_count-sft-SmolLM2-135M-Instruct-megatron +fi + +CKPT_HOME=${CKPT_HOME:-$HOME/experiments/char_count/models/sft/$backend} +mkdir -p "${CKPT_HOME}" + +torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-1} \ + ${ENTRYPOINT} \ + data.train_files="${TRAIN_FILES}" \ + data.train_batch_size=64 \ + data.val_files="${TEST_FILES}" \ + data.max_length=256 \ + data.pad_mode=${PAD_MODE} \ + data.truncation=error \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=1792 \ + data.messages_key=messages \ + model.path=$MODEL_ID \ + model.use_remove_padding=${USE_REMOVE_PADDING} \ + ${ENGINE_CONFIG} \ + trainer.test_freq=-1 \ + trainer.save_freq=70 \ + trainer.logger=['console'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPT_HOME}" \ + trainer.resume_mode=${RESUME_MODE} \ + trainer.max_ckpt_to_keep=5 \ + checkpoint.save_contents=[model,optimizer,extra] \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/collabllm/README.md b/ICL/DAPO/verl-recipe/collabllm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..953b08544cc0c4a83ef7a86e7520aab724fa0f92 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/README.md @@ -0,0 +1,74 @@ +# CollabLLM + +This repository implements [CollabLLM](https://arxiv.org/pdf/2502.00640) (ICML 2025) using the verl framework. For the original implementation, see the [CollabLLM repository](https://github.com/Wuyxin/collabllm). + + +CollabLLM is a method for training language models to collaborate effectively in multi-turn conversations. This implementation adapts the original imlpementation to work with the Verl training framework. + +## Quick start + +### 0. Environment +Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below). + +### 1. Prepare Your Dataset + +First, process your dataset using the provided script: + +```bash +python process_dataset.py --dataset <> ... --dataset_type +``` + + +**Requirements:** +- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` in one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (*-large are the datasets used in the CollabLLM paper) +- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard) +- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository + +*Note: Check `process_dataset.py` for example commands and usage.* + +### 2. Train Your Model + +**(Optional) For Supervised Fine-Tuning (SFT):** +```bash +bash train_sft_collabllm.sh +``` + +**For Reinforcement Learning (RL):** + +```bash +bash train_rl_collabllm.sh +``` + +The RL script shows an example to train CollabLLM on `math-hard-large`. + +- The config to sample future conversations are in `recipe/collabllm/config/collabllm_interaction_config.yaml`. +- The Multiturn-aware Reward is aggregated from these three conversational-level rewards: + + ``` + +reward_model.reward_kwargs.metric_weights.accuracy=1 \ + +reward_model.reward_kwargs.metric_weights.interactivity=1 \ + +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \ + ``` + + You can remove, add, or modify the weights depending on your task. A list of implemented metrics you can already add are under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via + ``` + +reward_model.reward_kwargs.metric_weights.bleu_score=1 + ``` + which will instead apply bleu score on the sampled future conversations. + +## Configuration +Read [doc](https://verl.readthedocs.io/en/latest/) for detailed configurations. + +## Citation +If you find CollabLLM useful in your research, please cite the following: + +```bibtex +@inproceedings{collabllm2025, + title={CollabLLM: From Passive Responders to Active Collaborators}, + author={Shirley Wu and Michel Galley and Baolin Peng and Hao Cheng and + Gavin Li and Yao Dou and Weixin Cai and James Zou and + Jure Leskovec and Jianfeng Gao}, + booktitle={International Conference on Machine Learning (ICML)}, + year={2025} +} +``` diff --git a/ICL/DAPO/verl-recipe/collabllm/utils.py b/ICL/DAPO/verl-recipe/collabllm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..588b2d006528dd117f98641cd18553b1f62b110e --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/utils.py @@ -0,0 +1,280 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import re + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def parse_messages(messages, strip_sys_prompt=True): + """ + Args: + messages: List[dict] + List of dictionaries with keys 'role' and 'content' + Example: messages = [{'role': 'user', 'content': 'Hello!'}, + {'role': 'assistant', 'content': 'Hi!'}, ...] + """ + if messages is None: + return "" + + if strip_sys_prompt: + messages = strip_system_prompt(messages) + + chat = "\n".join(f"**{m.role.capitalize()}**: {m.content}" for m in messages) + + return chat + + +def strip_system_prompt(messages): + """ + Args: + messages: List[dict] + List of dictionaries with keys 'role' and 'content' + Example: messages = [{'role': 'user', 'content': 'Hello!'}, + {'role': 'assistant', 'content': 'Hi!'}, ...] + """ + return [msg for msg in messages if msg.role != "system"] + + +def extract_json(s): + def convert_value(value): + true_values = {"true": True, "false": False, "null": None} + value_lower = value.lower() + if value_lower in true_values: + return true_values[value_lower] + try: + if "." in value or "e" in value.lower(): + return float(value) + else: + return int(value) + except ValueError: + return value # Return as string if not a number + + def parse_number(s, pos): + start = pos + while pos < len(s) and s[pos] in "-+0123456789.eE": + pos += 1 + num_str = s[start:pos] + try: + if "." in num_str or "e" in num_str.lower(): + return float(num_str), pos + else: + return int(num_str), pos + except ValueError: + logger.error(f"Invalid number at position {start}: {num_str}") + raise + + def skip_whitespace(s, pos): + while pos < len(s) and s[pos] in " \t\n\r": + pos += 1 + return pos + + def parse_string(s, pos): + quote_char = s[pos] + assert quote_char in ('"', "'") + pos += 1 + result = "" + while pos < len(s): + c = s[pos] + if c == "\\": + pos += 1 + if pos >= len(s): + raise ValueError("Invalid escape sequence") + c = s[pos] + escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char} + result += escape_sequences.get(c, c) + elif c == quote_char: + pos += 1 + # Attempt to convert to a number if possible + converted_value = convert_value(result) + return converted_value, pos + else: + result += c + pos += 1 + raise ValueError("Unterminated string") + + def parse_key(s, pos): + pos = skip_whitespace(s, pos) + if s[pos] in ('"', "'"): + key, pos = parse_string(s, pos) + return key, pos + else: + raise ValueError(f"Expected string for key at position {pos}") + + def parse_object(s, pos): + obj = {} + assert s[pos] == "{" + pos += 1 + pos = skip_whitespace(s, pos) + while pos < len(s) and s[pos] != "}": + pos = skip_whitespace(s, pos) + key, pos = parse_key(s, pos) + pos = skip_whitespace(s, pos) + if pos >= len(s) or s[pos] != ":": + raise ValueError(f'Expected ":" at position {pos}') + pos += 1 + pos = skip_whitespace(s, pos) + value, pos = parse_value(s, pos) + obj[key] = value + pos = skip_whitespace(s, pos) + if pos < len(s) and s[pos] == ",": + pos += 1 + pos = skip_whitespace(s, pos) + elif pos < len(s) and s[pos] == "}": + break + elif pos < len(s) and s[pos] != "}": + raise ValueError(f'Expected "," or "}}" at position {pos}') + if pos >= len(s) or s[pos] != "}": + raise ValueError(f'Expected "}}" at position {pos}') + pos += 1 + return obj, pos + + def parse_array(s, pos): + lst = [] + assert s[pos] == "[" + pos += 1 + pos = skip_whitespace(s, pos) + while pos < len(s) and s[pos] != "]": + value, pos = parse_value(s, pos) + lst.append(value) + pos = skip_whitespace(s, pos) + if pos < len(s) and s[pos] == ",": + pos += 1 + pos = skip_whitespace(s, pos) + elif pos < len(s) and s[pos] == "]": + break + elif pos < len(s) and s[pos] != "]": + raise ValueError(f'Expected "," or "]" at position {pos}') + if pos >= len(s) or s[pos] != "]": + raise ValueError(f'Expected "]" at position {pos}') + pos += 1 + return lst, pos + + def parse_triple_quoted_string(s, pos): + if s[pos : pos + 3] == "'''": + quote_str = "'''" + elif s[pos : pos + 3] == '"""': + quote_str = '"""' + else: + raise ValueError(f"Expected triple quotes at position {pos}") + pos += 3 + result = "" + while pos < len(s): + if s[pos : pos + 3] == quote_str: + pos += 3 + # Attempt to convert to a number if possible + converted_value = convert_value(result) + return converted_value, pos + else: + result += s[pos] + pos += 1 + raise ValueError("Unterminated triple-quoted string") + + def parse_value(s, pos): + pos = skip_whitespace(s, pos) + if pos >= len(s): + raise ValueError("Unexpected end of input") + if s[pos] == "{": + return parse_object(s, pos) + elif s[pos] == "[": + return parse_array(s, pos) + elif s[pos : pos + 3] in ("'''", '"""'): + return parse_triple_quoted_string(s, pos) + elif s[pos] in ('"', "'"): + return parse_string(s, pos) + elif s[pos : pos + 4].lower() == "true": + return True, pos + 4 + elif s[pos : pos + 5].lower() == "false": + return False, pos + 5 + elif s[pos : pos + 4].lower() == "null": + return None, pos + 4 + elif s[pos] in "-+0123456789.": + return parse_number(s, pos) + else: + raise ValueError(f"Unexpected character at position {pos}: {s[pos]}") + + json_start = s.index("{") + json_end = s.rfind("}") + s = s[json_start : json_end + 1] + + s = s.strip() + result, pos = parse_value(s, 0) + pos = skip_whitespace(s, pos) + if pos != len(s): + raise ValueError(f"Unexpected content at position {pos}") + return result + + +def remove_think_block(msg: dict): + """ + remove .*? from content + """ + if "content" in msg and isinstance(msg["content"], str): + msg["content"] = re.sub(r".*?", "", msg["content"], flags=re.DOTALL).strip() + return msg + + +def is_valid_messages(msg: dict) -> bool: + """ + check if is valid messages, including: + 1. is paried with + 2. is not empty inside and outside + 3. is not nested, and at most one block is allowed. + 4. can not be empty if remove ending "<|im_end|>" + """ + content = msg.get("content") + if not isinstance(content, str): + return True + + # Base case: empty or whitespace-only content is invalid. + if not content.strip(): + return False + + num_think_open = content.count("") + num_think_close = content.count("") + + # Rule 1: Check for paired tags. + if num_think_open != num_think_close: + return False + + # Rule 3: Allow at most one think block. + if num_think_open > 1: + return False + + # Case 1: No blocks. + if num_think_open == 0: + visible_content = content + # Case 2: Exactly one block. + else: + # Rule 2: Check for empty content inside the think block. + match = re.search(r"(.*?)", content, re.DOTALL) + if not match or not match.group(1).strip(): + return False + + # The "visible" content is what's outside the think block. + visible_content = re.sub(r".*?", "", content, flags=re.DOTALL) + + visible_content = visible_content.strip() + + # Rule 4 & 2 (outside): Check if visible content is empty after handling <|im_end|>. + if visible_content.endswith("<|im_end|>"): + visible_content = visible_content[: -len("<|im_end|>")] + + if not visible_content.strip(): + return False + + return True diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..9c34fa7e06ca12c7d1b355b8bfd8608435a9c254 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh @@ -0,0 +1,138 @@ +#!/bin/bash +project_name='DAPO' +exp_name='DAPO-Qwen3-8B-Base' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=16 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=1 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Performance Related Parameter +sp_size=2 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=2 + +ray job submit --runtime-env="${RUNTIME_ENV}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=20 \ + trainer.total_epochs=1 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + data.shuffle=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.entropy_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True diff --git a/ICL/DAPO/verl-recipe/deepeyes/deepeyes.py b/ICL/DAPO/verl-recipe/deepeyes/deepeyes.py new file mode 100644 index 0000000000000000000000000000000000000000..bedcd79471985ee8f0e0c562821043501fa93f10 --- /dev/null +++ b/ICL/DAPO/verl-recipe/deepeyes/deepeyes.py @@ -0,0 +1,408 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import logging +import os +import random +import re + +import requests +from openai import OpenAI +from PIL import Image + +import verl.utils.torch_functional as verl_F +from verl.utils.dataset.rl_dataset import RLHFDataset +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + +openai_api_key = "EMPTY" +openai_api_base = os.environ.get("LLM_AS_A_JUDGE_BASE", "http://10.1.100.71:18901/v1") + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +model_name = "" +if openai_api_base: + try: + response = requests.get(f"{openai_api_base}/models") + response.raise_for_status() + models = response.json() + if models.get("data"): + model_name = models["data"][0]["id"] + else: + logger.warning("No models found at the specified API base for reward scoring.") + except (requests.exceptions.RequestException, KeyError, IndexError) as e: + logger.warning(f"Failed to get model from {openai_api_base}: {e}. Reward scoring will be disabled.") + + +class CustomRLHFDataset(RLHFDataset): + def __getitem__(self, item): + """ + Note that we also return the raw_input_ids so that it can be combined with other chat template + """ + row_dict: dict = self.dataframe[item] + row_dict[self.prompt_key] = [ + { + "role": "system", + # We don't need tool description, because custom_chat_template will add it. + "content": ( + "You are a helpful assistant. You can call functions to assist with the user query. " + "Important: You must call only one function at a time. After each function call, " + "wait for the execution result before making the next function call if needed." + ), + }, + { + "role": "user", + "content": row_dict[self.prompt_key][1]["content"], + }, + ] + + images = [] + row_dict_images = row_dict.get(self.image_key, None) + if row_dict_images: + images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict_images] + messages = self._build_messages(row_dict) + + if self.processor is not None: + raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + model_inputs = self.processor(text=[raw_prompt], images=images, return_tensors="pt") + + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") + + if "second_per_grid_ts" in model_inputs: + model_inputs.pop("second_per_grid_ts") + + else: + raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") + + input_ids, attention_mask = verl_F.postprocess_data( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=self.max_prompt_length, + pad_token_id=self.tokenizer.pad_token_id, + left_pad=True, + truncation=self.truncation, + ) + + if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__: + from verl.models.transformers.qwen2_vl import get_rope_index + + position_ids = [ + get_rope_index( + self.processor, + input_ids=input_ids[0], + image_grid_thw=model_inputs.get("image_grid_thw"), + video_grid_thw=model_inputs.get("video_grid_thw"), + second_per_grid_ts=model_inputs.get("second_per_grid_ts"), + attention_mask=attention_mask[0], + ) + ] # (1, 3, seq_len) + + else: + position_ids = compute_position_id_with_mask(attention_mask) + + row_dict["input_ids"] = input_ids[0] + row_dict["attention_mask"] = attention_mask[0] + row_dict["position_ids"] = position_ids[0] + + raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) + if len(raw_prompt_ids) > self.max_prompt_length: + if self.truncation == "left": + raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] + elif self.truncation == "right": + raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] + elif self.truncation == "middle": + left_half = self.max_prompt_length // 2 + right_half = self.max_prompt_length - left_half + raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + elif self.truncation == "error": + raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") + + row_dict["raw_prompt_ids"] = raw_prompt_ids + # encode prompts without chat template + if self.return_raw_chat: + row_dict["raw_prompt"] = messages + + # get prompts with chat template + if self.return_full_prompt: + row_dict["full_prompts"] = raw_prompt # array of strings + + # add index for each prompt + index = row_dict.get("extra_info", {}).get("index", 0) + tools_kwargs = { + "image_zoom_in_tool": { + "create_kwargs": {"image": images[0]}, + # "execute_kwargs": {}, + # "calc_reward_kwargs": {}, + # "release_kwargs": {}, + } + } + row_dict["index"] = index + row_dict["tools_kwargs"] = tools_kwargs + row_dict["agent_name"] = "tool_agent" + return row_dict + + +def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float: + """ + Compute reward score for model solutions with robust handling of various formats. + + Returns a weighted combination of: + - Accuracy reward (0.8 weight): Whether the answer is semantically correct + - Format reward (0.2 weight): Whether the output follows expected format + - Tool reward (1.2 weight): Whether tools were used when answer is correct + """ + + # Initialize tracking variables + is_format_error = False + + # 1. Check tag format + count_think_1 = solution_str.count("") + count_think_2 = solution_str.count("") + if count_think_1 != count_think_2: + is_format_error = True + + # 2. Check vision tokens (skip this since tokenizer removes special tokens) + # We'll use and instead to detect tool usage + + # 3. Extract answer text with multiple fallback strategies + answer_text = "" + + # Strategy 1: Try to extract from tags first + predict_no_think = ( + solution_str.split("")[-1].strip() if "" in solution_str else solution_str.strip() + ) + + # Check tag format + count_answer_1 = predict_no_think.count("") + count_answer_2 = predict_no_think.count("") + if count_answer_1 != count_answer_2: + is_format_error = True + + # Try to extract from tags + answer_match = re.search(r"(.*?)", predict_no_think, re.DOTALL) + if answer_match: + answer_text = answer_match.group(1).strip() + else: + # No proper tags found - this is a format error + is_format_error = True + + # Strategy 2: If no tags, extract content after tool responses + # Look for pattern: ...assistant\n[actual_answer] + tool_response_match = re.search( + r"\s*assistant\s*\n(.*?)$", predict_no_think, re.DOTALL | re.MULTILINE + ) + if tool_response_match: + answer_text = tool_response_match.group(1).strip() + else: + # Strategy 3: If no tool responses, look for content after + if "" in solution_str: + # Remove any remaining tool-related tags and extract meaningful content + remaining_content = predict_no_think + # Remove tool calls and responses + remaining_content = re.sub(r".*?", "", remaining_content, flags=re.DOTALL) + remaining_content = re.sub( + r".*?", "", remaining_content, flags=re.DOTALL + ) + # Remove user/assistant markers + remaining_content = re.sub(r"\b(user|assistant)\b", "", remaining_content) + answer_text = remaining_content.strip() + else: + # Strategy 4: Use the entire solution_str as fallback + answer_text = solution_str.strip() + + # Clean up answer text + answer_text = answer_text.strip() + + # If answer is still empty after all strategies, mark as format error + if not answer_text: + is_format_error = True + answer_text = solution_str.strip() # Use full text as last resort + + # 4. Evaluate correctness using LLM judge + question_text = extra_info.get("question", "") if extra_info else "" + + if not client or not model_name: + logger.warning("Reward function client not initialized or model name not found.") + return 0.0 + + system_prompt = ( + "You are an expert evaluator. Your task is to determine if a model's answer is semantically equivalent to a " + "provided standard answer, given a specific question.\n" + "Your evaluation must be strict. The model's answer is only correct if it fully matches the meaning of the " + "standard answer.\n" + 'You must provide your final judgement as a single word: either "CORRECT" or "INCORRECT". Do not provide ' + "any explanation or other text." + ) + + user_prompt = ( + f"I will provide a question, a standard answer, and a model's answer. You must evaluate if the model's " + f"answer is correct.\n\n" + f"---\n" + f"**Example 1:**\n" + f"[Question]: Is the countertop tan or blue?\n" + f"[Standard Answer]: The countertop is tan.\n" + f"[Model's Answer]: tan\n" + f"[Your Judgement]: CORRECT\n" + f"---\n" + f"**Example 2:**\n" + f"[Question]: Is the man phone both blue and closed?\n" + f"[Standard Answer]: Yes, the man phone is both blue and closed.\n" + f"[Model's Answer]: No.\n" + f"[Your Judgement]: INCORRECT\n" + f"---\n" + f"**Task:**\n" + f"[Question]: {question_text}\n" + f"[Standard Answer]: {ground_truth}\n" + f"[Model's Answer]: {answer_text}\n" + f"[Your Judgement]:" + ) + + try: + chat_response = client.chat.completions.create( + model=model_name, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + seed=random.randint(0, 1000000), + temperature=0.1, # Lower temperature for more deterministic judgement + extra_body={ + "chat_template_kwargs": {"enable_thinking": False}, + }, + ) + response = chat_response.choices[0].message.content.strip() + except Exception as e: + logger.warning(f" [WARNING] Chat completion request failed: {e}") + return 0.0 + + # Parse LLM judge response + if re.search(r"\bCORRECT\b", response, re.IGNORECASE): + acc_reward = 1.0 + elif re.search(r"\bINCORRECT\b", response, re.IGNORECASE): + acc_reward = 0.0 + else: + logger.warning( + f" [WARNING] Judgement format error. Expected 'CORRECT' or 'INCORRECT'.\n" + f"Response: '{response}'\n" + f"Model Answer: '{answer_text}'\n" + f"Ground Truth: '{ground_truth}'" + ) + acc_reward = 0.0 + + # Penalize excessively long answers (potential judge hacking) + if len(answer_text) >= 1000: + acc_reward = 0.0 + is_format_error = True + + # 5. Check tool usage - look for tool_call/tool_response patterns instead of vision tokens + has_tool_usage = bool( + re.search(r".*?", solution_str, re.DOTALL) + or re.search(r".*?", solution_str, re.DOTALL) + ) + + # Tool reward: only give if tools were used AND answer is correct + tool_reward = 1.0 if has_tool_usage and acc_reward > 0.5 else 0.0 + + # Format reward: penalty for format errors + format_reward = -1.0 if is_format_error else 0.0 + + # Log debug information for problematic cases + if is_format_error or not answer_text: + logger.debug( + f"Format issue detected:\n" + f"Solution: {solution_str[:200]}...\n" + f"Extracted answer: '{answer_text}'\n" + f"Format error: {is_format_error}\n" + f"Tool usage: {has_tool_usage}" + ) + + # Final weighted score + final_score = 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward + + return final_score + + +if __name__ == "__main__": + # Test case 1: Original test case + predict_str = "The answer is 2 + 2 = 4 right left " + ground_truth = "left" + extra_info = { + "answer": "The woman is to the left of the man who is holding the camera.", + "id": 0, + "image": "/cpfs/user/honglingyi/DATA/LLM/Vstar/gqa/images/713270.jpg", + "pred_ans": "The woman is to the right of the man who is holding the camera.", + "question": "Is the woman to the left or to the right of the man who is holding the camera?", + } + print("=== Test Case 1: Original test ===") + import time + + time_start = time.time() + score = compute_score("common_reasoning", predict_str, ground_truth, extra_info) + print(f"Score: {score}") + time_end = time.time() + print(f"Time: {time_end - time_start}") + + # Test case 2: Problematic case mentioned by user + problematic_solution = """ +{"name": "image_zoom_in_tool", "arguments": {"bbox_2d": [226, 399, 265, 464], "label": "white van"}} +user + +Zoomed in on the image to the region [226, 399, 265, 464] with label white van. + +assistant +The white van is visible in the lower section of the image, near the diagonal road.""" + + problematic_ground_truth = "Yes, the white van is indeed situated in the bottom part of the picture." + problematic_extra_info = { + "question": "Is the white van in the bottom part of the picture?", + } + + print("\n=== Test Case 2: Problematic case (no answer tags) ===") + print(f"Solution: {problematic_solution}") + print(f"Ground truth: {problematic_ground_truth}") + + time_start = time.time() + score2 = compute_score("common_reasoning", problematic_solution, problematic_ground_truth, problematic_extra_info) + print(f"Score: {score2}") + time_end = time.time() + print(f"Time: {time_end - time_start}") + + # Test case 3: Well-formatted case with tools + well_formatted_solution = """ +I need to use the image zoom tool to get a better look at the specific area. + + +{"name": "image_zoom_in_tool", "arguments": {"bbox_2d": [226, 399, 265, 464], "label": "white van"}} + + +Zoomed in on the image to the region [226, 399, 265, 464] with label white van. + +Yes, the white van is indeed situated in the bottom part of the picture.""" + + print("\n=== Test Case 3: Well-formatted case ===") + time_start = time.time() + score3 = compute_score( + "common_reasoning", well_formatted_solution, problematic_ground_truth, problematic_extra_info + ) + print(f"Score: {score3}") + time_end = time.time() + print(f"Time: {time_end - time_start}") diff --git a/ICL/DAPO/verl-recipe/fault_recover/async_llm.py b/ICL/DAPO/verl-recipe/fault_recover/async_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..13a1de9fec98ed753fe99c418a36d569b5beade0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/async_llm.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio + +import numpy as np +from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE +from vllm.utils import cdiv +from vllm.v1.engine.async_llm import AsyncLLM, logger +from vllm.v1.metrics.stats import IterationStats + + +class AsyncFaultRecoverLLM(AsyncLLM): + def _run_output_handler(self): + """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" + + if self.output_handler is not None: + return + + # Ensure that the task doesn't have a circular ref back to the AsyncLLM + # object, or else it won't be garbage collected and cleaned up properly. + engine_core = self.engine_core + output_processor = self.output_processor + log_stats = self.log_stats + logger_manager = self.logger_manager + + async def output_handler(q): + try: + while True: + # 1) Pull EngineCoreOutputs from the EngineCore. + outputs = await engine_core.get_output_async() + + if q is not None: + req_info = {} + for output in outputs.outputs: + req_info[output.request_id] = {} + req_info[output.request_id]["new_token_ids"] = output.new_token_ids + req_info[output.request_id]["finished"] = output.finished + await q.put.remote(req_info) + + num_outputs = len(outputs.outputs) + + iteration_stats = IterationStats() if (log_stats and num_outputs) else None + + # Split outputs into chunks of at most + # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the + # event loop for too long. + if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: + slices = (outputs.outputs,) + else: + slices = np.array_split(outputs.outputs, cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + + for i, outputs_slice in enumerate(slices): + # 2) Process EngineCoreOutputs. + processed_outputs = output_processor.process_outputs( + outputs_slice, outputs.timestamp, iteration_stats + ) + # NOTE: RequestOutputs are pushed to their queues. + assert not processed_outputs.request_outputs + + # Allow other asyncio tasks to run between chunks + if i + 1 < len(slices): + await asyncio.sleep(0) + + # 3) Abort any reqs that finished due to stop strings. + await engine_core.abort_requests_async(processed_outputs.reqs_to_abort) + + # 4) Logging. + # TODO(rob): make into a coroutine and launch it in + # background thread once Prometheus overhead is non-trivial. + if logger_manager: + logger_manager.record( + engine_idx=outputs.engine_index, + scheduler_stats=outputs.scheduler_stats, + iteration_stats=iteration_stats, + ) + except Exception as e: + logger.exception("AsyncLLM output_handler failed.") + output_processor.propagate_error(e) + + from recipe.fault_recover.fault_manager import get_tokens_queue + + tokens_queue = get_tokens_queue() + + self.output_handler = asyncio.create_task(output_handler(tokens_queue)) diff --git a/ICL/DAPO/verl-recipe/flash_rl_ascend/README.md b/ICL/DAPO/verl-recipe/flash_rl_ascend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f76db8f69f4b13ca2fd54d913bf72f39570ff1e6 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flash_rl_ascend/README.md @@ -0,0 +1,121 @@ +## 在线量化权重: + +介绍在昇腾设备上,使用 [Flash-RL](https://github.com/yaof20/Flash-RL) 工具,修改推理后端,通过比较 INT8 模型和 BF16 模型,对权重和激活值执行在线量化。下文以 Qwen3-30B int8 为例,在 NPU 上跑通端到端功能。 + +### 环境依赖 + +## +| PyTorch版本 | torch_npu版本 | CANN版本 | Python版本 | +| ------------ |-----------| ---------- | ---------- | +| 2.7.1 | 2.7.1 | 8.5.0 | Python3.10 | + + +#### 1、安装 vllm 和 vllm-ascend +```bash +# vllm==0.10.1 +git clone https://github.com/vllm-project/vllm.git +cd vllm +git checkout e03940762b43812fccd3c214bda60201cff9d16a +pip install -r requirements/build.txt +VLLM_TARGET_DEVICE=empty pip install -v . +cd .. + +# vllm-ascend==0.10.1 +git clone https://github.com/vllm-project/vllm-ascend.git +cd vllm-ascend +git checkout 7e16b4a7cdb15723c63c1c0efe58672a056eace8 +pip install -r requirements.txt +export COMPILE_CUSTOM_KERNELS=1 +python setup.py install +cd .. + +# 源码安装transformers +git clone -b v4.57.6 https://github.com/huggingface/transformers.git +cd transformers +pip install -e . +``` + +#### 2、安装 MindSpeed 与 Megatron +```bash +# MindSpeed +git clone https://gitcode.com/Ascend/MindSpeed.git +cd MindSpeed +git checkout 1cdd0abd75e40936ad31721c092f57c695dd72c4 +pip install -e . +cd .. + +# Megatron +pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 +``` + +#### 3、安装 verl +```bash +# verl +git clone https://github.com/volcengine/verl.git +cd verl +pip install -e . +cd .. +``` + +### 使用步骤: + +#### 1、安装包: + +``` +pip install flash-llm-rl # need to be installed in all nodes in multi-node training +``` + +#### 2、打patch + +安装 FlashRL 后,默认采用自动 patch,推荐改用手动方式,减少过程中的错误: + +1. 在 `verl/verl/__init__.py` 文件中添加 `import flash_rl`; +2. 在 shell 脚本中添加 `flashrl cleanup`,这将禁用自动 patch; + +#### 3、生成性能分析文件 + +具体来说,profile 文件会比较 bf16 模型和 int8 模型,以确定如何对更新后的模型执行在线量化: + +``` +flashrl profile -m Qwen3-30B-A3B -q Qwen3-30B-A3B-w8a8 -o ${PROFILE_PATH:-"$HOME/profile.30b.pt"} --fn int8 +``` + +`-m` 参数后是 bf16 模型路径,`-q` 参数后是 int8 模型路径,`-o` 参数后是生成文件路径; +[RedHatAI](https://huggingface.co/RedHatAI/collections) 提供了各种量化模型; + +#### 4、生成配置文件 + +通过以下命令生成 yaml 配置文件,供 patch 程序使用: + +``` +flashrl setup -m Qwen3-30B-A3B-w8a8 -p $HOME/profile.30b.pt --fn int8 -o ${CONFIG_PATH:-"$HOME/.flashrl_config.30b.yaml"} +``` + +`-m` 参数后是 int8 模型路径,`-p` 参数后是 profile 文件路径,`-o` 参数后是生成文件路径; + +(可选)为了缩小 rollout 生成和梯度计算之间的差距,FlashRL 提供了在 DP 工作线程间以混合方式进行 16 位和 8 位 rollout 生成的功能。具体来说,运行以下命令会将第二个配置附加到现有的 yaml 配置文件中。 + +``` +flashrl setup -a --fn bf16 -o ${CONFIG_PATH:-"$HOME/.flashrl_config.30b.yaml"} +``` + +#### 5、开始训练 + +脚本中添加以下环境变量: + +``` +# 打印详细日志,查看是否 patch 成功: +export FLASHRL_LOGGING_LEVEL=DEBUG +# 指定配置文件: +export FLASHRL_CONFIG=$HOME/.flashrl_config.30b.yaml +# 强制 lm-head 使用 bf16,减小精度损失: +export FLASHRL_LMHEAD_FP32=1 +``` + +以上步骤已在 `test_qwen3-30b_int8_npu.sh` 提供实例,修改脚本中的模型路径即可自动执行,有具体问题可根据上述步骤排查; + +在 `run.sh` 文件中补充机器 IP、网络接口,运行以下命令启动训练: + +``` +bash ./run.sh +``` diff --git a/ICL/DAPO/verl-recipe/flowrl/README.md b/ICL/DAPO/verl-recipe/flowrl/README.md new file mode 100644 index 0000000000000000000000000000000000000000..515cc0b795c5e096b2f98cdcc0368b8f63618feb --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/README.md @@ -0,0 +1,182 @@ +

+ FlowRL +

+ +

+ Matching Reward Distributions via Flow Balance +

+

+ 📄 arXiv Paper | + 🤗 #1 Paper of the Day +

+

+ 𝕏 Post 1 | + 𝕏 Post 2 | + 𝕏 Post 3 | + 𝕏 Post 4 +

+ +

+ FlowRL Overview +

+ +## Table of Contents + +- [FlowRL Objective](#flowrl-objective) +- [Trained Models & Experiment Logs](#trained-models--experiment-logs) +- [Quick Start](#quick-start) + - [Option 1: Original Paper Reproduction (verl 0.4.0)](#option-1-original-paper-reproduction-verl-040--recommended) + - [Step 1: Installation](#step-1-installation) + - [Step 2: Data Preparation](#step-2-data-preparation) + - [Step 3: Model Preparation](#step-3-model-preparation) + - [Step 4: Training Scripts](#step-4-training-scripts) + - [Option 2: Latest verl Recipe FlowRL](#option-3-latest-verl-recipe-flowrl) + - [Step 1: Prepare Data and Model](#step-1-prepare-data-and-model) + - [Step 2: Run Training](#step-2-run-training) + - [Option 3: Implement FlowRL Yourself](#option-4-implement-flowrl-yourself) +- [Testing](#testing) +- [Citation](#citation) + +## FlowRL Objective + +$$ +\mathcal{L}_{\text{FlowRL}} = w \cdot \left( \log Z_{\phi}(x) + \frac{1}{|y|} \log \pi_{\theta}(y \mid x) - \beta \hat{r}(x, y) - \frac{1}{|y|} \log \pi_{\text{ref}}(y \mid x) \right)^2 +$$ + +FlowRL is a flow-balanced reinforcement learning method that matches full reward distributions instead of maximizing rewards, promoting diverse exploration and generalizable reasoning trajectories in LLMs. + +## Trained Models & Experiment Logs + +| Base Model | Domain | WandB Logs | Hugging Face Model | +|-------|--------|------------|-------------------| +| Qwen2.5-7B | Math | [🔗 View Run](https://wandb.ai/xuekaizhu0/FlowRL/runs/pa62rs4x?nw=nwuserxuekaizhu0) | [🤗 Model](https://huggingface.co/xuekai/FlowRL-Qwen2.5-7B-math) | +| DeepSeek-7B | Code | [🔗 View Run](https://wandb.ai/xuekaizhu0/FlowRL/runs/wbw72gdv?nw=nwuserxuekaizhu0) | [🤗 Model](https://huggingface.co/xuekai/FlowRL-DeepSeek-7B-code) | +| Qwen2.5-32B | Math | - | [🤗 Model](https://huggingface.co/xuekai/FlowRL-Qwen2.5-32B-math) | + +## Quick Start + +There are three ways to use FlowRL: + +--- + +**⭐ We recommend using Option 1 as the default choice.** Since verl updates frequently, the newest versions may have unstable factors such as training and inference mismatches. Option 1 uses verl 0.4.0, which is stable and has been thoroughly tested with our paper results. + +--- + +### Option 1: Original Paper Reproduction (verl 0.4.0) ⭐ Recommended + +For exact reproduction of results from the paper, use the original repository with verl 0.4.0: + +👉 **Original Code:** [https://github.com/Xuekai-Zhu/FlowRL](https://github.com/Xuekai-Zhu/FlowRL) + +#### Step 1: Installation + +Install [verl](https://github.com/volcengine/verl) first before using FlowRL. + +#### Step 2: Data Preparation + +```bash +# Option A: Download our pre-processed datasets directly +bash preprocess/down_load_dataset.sh +# Move data to default directory +mv data/xuekai/flowrl-data-collection/math_data data/math_data +mv data/xuekai/flowrl-data-collection/code_data data/code_data +``` + +```bash +# Option B: Process data from original sources +# For detailed processing instructions, see data/README.md +``` + +#### Step 3: Model Preparation + +For Math Tasks: `Qwen/Qwen2.5-7B` (default in script) ; `Qwen/Qwen2.5-32B` + +For Code Tasks: `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` + +```bash +# Download default model (Qwen2.5-7B for math) +bash preprocess/down_load_model.sh + +# For other models, modify MODEL_NAME in the script before running +``` + +#### Step 4: Training Scripts + +```bash +cd verl_FlowRL + +# For 7B math training +bash command/training/math/flowrl_7B_math.sh + +# For 32B math training +bash command/training/math/flowrl_32B_math.sh + +# For 7B code training +bash command/training/code/flowrl_7B_code.sh +``` +---- +### Option 2: Latest verl Recipe FlowRL + +For running FlowRL using the latest verl framework: + +**Latest verl:** + +- verl recipe: [https://github.com/volcengine/verl/tree/main/recipe/flowrl](https://github.com/volcengine/verl/tree/main/recipe/flowrl) + +#### Step 1: Prepare Data and Model + +```bash +# Prepare dataset +bash recipe/flowrl/prepare/prepare_data.sh + +# Prepare model +bash recipe/flowrl/prepare/prepare_model.sh +``` + +#### Step 2: Run Training + +```bash +# Train FlowRL with Qwen2.5-7B +bash recipe/flowrl/run_flowrl_qwen2.5_7b.sh +``` +---- +### Option 3: Implement FlowRL Yourself + +If you want to implement FlowRL in your own codebase, we provide a detailed implementation guide: + +📖 **[FlowRL Implementation Guide](FLOWRL_SIMPLE_GUIDE.md)** + +This guide walks you through the key components and steps needed to integrate FlowRL into your existing training pipeline. + +## Testing + +After training your FlowRL models, you can evaluate them using the following commands: + +```bash +cd verl_Test + +# First merge the model +bash command/eval/merge_model.sh + +# For math testing +bash command/eval/math/flowrl_math_test.sh + +# For code testing +bash command/eval/code/flowrl_code_test.sh +``` + +**Reference:** For verl v0.5.0.dev merge model script, see [merge_model.sh](https://github.com/Xuekai-Zhu/verl_FlowRL/blob/flowrl-v0.5.0.dev/recipe/flowrl/eval/merge_model.sh) + +## Citation + +If you think this repo helps you, please kindly consider citing our paper: + +```bibtex +@article{zhu2025flowrl, + title={FlowRL: Matching Reward Distributions for LLM Reasoning}, + author={Zhu, Xuekai and Cheng, Daixuan and Zhang, Dinghuai and Li, Hengli and Zhang, Kaiyan and Jiang, Che and Sun, Youbang and Hua, Ermo and Zuo, Yuxin and Lv, Xingtai and others}, + journal={arXiv preprint arXiv:2509.15207}, + year={2025} +} +``` diff --git a/ICL/DAPO/verl-recipe/flowrl/__init__.py b/ICL/DAPO/verl-recipe/flowrl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fcb391c6680083dfa1749fc9ca311063a4e750 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FlowRL recipe package.""" + +__all__ = [] diff --git a/ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py b/ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..aed3063a2a69e45905764e692c2a4733c9843421 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py @@ -0,0 +1,495 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FlowRL FSDP Worker that uses FlowRLActor instead of standard DPActor.""" + +import logging +import os +import warnings + +import torch +import torch.distributed +from peft import LoraConfig, TaskType, get_peft_model +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +from recipe.flowrl.flowrl_actor import FlowRLActor, ProjZModule + +from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import ( + get_device_id, + get_torch_device, + set_expandable_segments, +) +from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + collect_lora_params, + fsdp2_load_full_state_dict, + get_fsdp_wrap_policy, + get_init_weight_context_manager, + get_shard_placement_fn, + init_fn, + load_fsdp_model_to_gpu, + offload_fsdp_model_to_cpu, + replace_lora_wrapper, +) +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.model import convert_weight_keys +from verl.utils.profiler import log_gpu_memory_usage +from verl.utils.py_functional import convert_to_regular_types +from verl.workers.config import FSDPEngineConfig +from verl.workers.fsdp_workers import ActorRolloutRefWorker, get_sharding_strategy, get_vl_model_vision_tower + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class FlowRLActorRolloutRefWorker(ActorRolloutRefWorker): + """ + FlowRL version of ActorRolloutRefWorker that uses FlowRLActor. + + This worker adds FlowRL-specific modifications: + - ProjZModule for log Z estimation (added in _build_model_optimizer) + - FlowRLActor with trajectory balance loss (replaces standard DPActor) + """ + + def _build_model_optimizer( + self, + model_path, + fsdp_config: FSDPEngineConfig, + optim_config, + override_model_config, + use_remove_padding=False, + use_fused_kernels=False, + enable_gradient_checkpointing=False, + trust_remote_code=False, + use_liger=False, + role="actor", + enable_activation_offload=False, + ): + from torch import optim + from torch.distributed.fsdp import CPUOffload, MixedPrecision + from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoModelForVision2Seq, + ) + + from verl.utils.model import get_generation_config, print_model_size, update_model_config + from verl.utils.torch_dtypes import PrecisionType + + assert role in ["actor", "ref"] + + log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) + local_path = model_path + + # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect + # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + vllm_dtype = PrecisionType.to_dtype(self.config.rollout.dtype) + torch_dtype = fsdp_config.get("model_dtype", None) + if torch_dtype is None: + torch_dtype = torch.float32 if self._is_actor else vllm_dtype + else: + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + # override model kwargs + actor_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" + ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"): + actor_model_config.vision_config._attn_implementation = "eager" + + # patch for kimi-vl + if getattr(actor_model_config, "model_type", None) == "kimi_vl": + actor_model_config.text_config.topk_method = "greedy" + + self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) + + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config) + update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) + if self.rank == 0: + print(f"Model config after override: {actor_model_config}") + + # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang + init_context = get_init_weight_context_manager( + use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + has_remote_code = hasattr(actor_model_config, "auto_map") and any( + actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values() + ) + if has_remote_code: + auto_class = next( + k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v + ) + match auto_class: + case "AutoModelForVision2Seq": + actor_module_class = AutoModelForVision2Seq + case "AutoModelForCausalLM": + actor_module_class = AutoModelForCausalLM + case "AutoModelForImageTextToText": + actor_module_class = AutoModelForImageTextToText + case _: + actor_module_class = AutoModel + else: + if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): + actor_module_class = AutoModelForVision2Seq + elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys(): + actor_module_class = AutoModelForCausalLM + elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys(): + actor_module_class = AutoModelForImageTextToText + else: + actor_module_class = AutoModel + + actor_module = actor_module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=actor_model_config, + trust_remote_code=trust_remote_code, + ) + + # ==== FlowRL: inject ProjZ BEFORE FSDP wrap ==== + if role == "actor" and self._is_actor: + n_dim = actor_module.config.hidden_size + proj_layers = getattr(self.config.actor, "proj_layer", 3) + actor_module.add_module("proj_z", ProjZModule(n_dim, num_layers=proj_layers)) + + if self.rank == 0: + print(f"[FlowRL] Added proj_z (layers={proj_layers}, hidden={n_dim}) BEFORE FSDP wrap") + # =============================================== + + # Apply Liger kernel to the model if use_liger is set to True + if use_liger: + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + + _apply_liger_kernel_to_instance(model=actor_module) + + fused_kernel_options = self.config.model.get("fused_kernel_options", None) + fused_kernels_backend = ( + fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + ) + + apply_monkey_patch( + model=actor_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, + ) + + # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 + actor_module.to(torch_dtype) + + if enable_gradient_checkpointing: + actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if self._is_lora: + print("Applying LoRA to actor module") + actor_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), + "bias": "none", + } + actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) + + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.actor.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(actor_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[actor model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[actor model] No vision tower found.") + + torch.distributed.barrier() + + if self.rank == 0: + print_model_size(actor_module) + + log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) + + # We wrap FSDP for rollout as well + mixed_precision_config = fsdp_config.get("mixed_precision", None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + else: + param_dtype = PrecisionType.to_dtype(self.config.actor.get("dtype", "bfloat16")) + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy( + module=actor_module, + config=fsdp_config.get("wrap_policy", None), + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) + + if self._is_rollout and self.config.rollout.name == "hf": + # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma + auto_wrap_policy = None + + if self.rank == 0: + print(f"wrap_policy: {auto_wrap_policy}") + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + # TODO: add transformer policy + # We force reference policy to use CPUOffload to save memory. + # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation + cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) + fsdp_strategy = self.config.actor.strategy + if fsdp_strategy == "fsdp": + actor_module_fsdp = FSDP( + actor_module, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, # zero3 + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + use_orig_params=self.use_orig_params, + forward_prefetch=fsdp_config.get("forward_prefetch", False), + ) + elif fsdp_strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + if role == "actor" and fsdp_config.offload_policy: + cpu_offload = CPUOffloadPolicy(pin_memory=True) + self._is_offload_param = False + self._is_offload_optimizer = False + else: + cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = actor_module.state_dict() + apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) + actor_module_fsdp = actor_module + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") + + if enable_activation_offload: + enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) + + log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) + + # TODO: add more optimizer args into config + if role == "actor" and optim_config is not None: + from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + + actor_optimizer = optim.AdamW( + actor_module_fsdp.parameters(), + lr=optim_config.lr, + betas=optim_config.get("betas", (0.9, 0.999)), + weight_decay=optim_config.get("weight_decay", 1e-2), + ) + + total_steps = optim_config.get("total_training_steps", 0) + num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) + warmup_style = optim_config.get("warmup_style", "constant") + min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) + num_cycles = optim_config.get("num_cycles", 0.5) + if num_warmup_steps < 0: + num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + + if warmup_style == "constant": + actor_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps + ) + elif warmup_style == "cosine": + actor_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=actor_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + else: + raise NotImplementedError(f"Warmup style {warmup_style} is not supported") + + log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) + else: + actor_optimizer = None + actor_lr_scheduler = None + + return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + """Override init_model to use FlowRLActor instead of DataParallelPPOActor.""" + # Call parent's init_model to set up the FSDP model (with proj_z already added) + super().init_model() + + # Replace the actor with FlowRLActor if this worker is an actor + if self._is_actor: + if self.rank == 0: + print("[FlowRL] Replacing DataParallelPPOActor with FlowRLActor") + + # Convert actor config to dataclass + actor_cfg = omega_conf_to_dataclass(self.config.actor) + + # Create FlowRLActor with trajectory balance loss + self.actor = FlowRLActor( + config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) + + async def rollout_mode(self): + """ + Override rollout_mode to filter out proj_z parameters before syncing to vLLM. + + FlowRL's proj_z module is only needed during training for estimating log Z. + It should not be loaded into vLLM since vLLM is only used for rollout generation. + """ + aggressive_empty_cache(force_sync=True) + + log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger) + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger) + + peft_config = None + peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + if hasattr(peft_model, "peft_config"): # LoRA + peft_config = peft_model.peft_config.get("default", None) + params = collect_lora_params( + module=self.actor_module_fsdp, + layered_summon=self.config.rollout.get("layered_summon", False), + base_sync_done=self.base_sync_done, + ) + if not self.base_sync_done: + params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()} + else: + params = self.actor_module_fsdp.state_dict() + + # ==== FlowRL: Filter out proj_z parameters ==== + params = {k: v for k, v in params.items() if not k.startswith("proj_z")} + num_proj_z_filtered = len([k for k in self.actor_module_fsdp.state_dict().keys() if k.startswith("proj_z")]) + if num_proj_z_filtered > 0 and self.rank == 0: + print(f"[FlowRL] Filtered {num_proj_z_filtered} proj_z parameters before syncing to vLLM") + # =============================================== + + params = convert_weight_keys( + params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + + # Special handling for LoRA with sleep_level=2: + if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2: + base_model_params = collect_lora_params( + module=self.actor_module_fsdp, + layered_summon=self.layered_summon, + base_sync_done=False, + ) + base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()} + # Filter proj_z from base model params as well + base_model_params = {k: v for k, v in base_model_params.items() if not k.startswith("proj_z")} + base_model_params = convert_weight_keys( + base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + + log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger) + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger) + + set_expandable_segments(False) + + if peft_config is not None and self.base_sync_done: + per_tensor_param = params + else: + device = get_device_id() + per_tensor_param = ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in params.items() + ) + + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["weights"]) + log_gpu_memory_usage("After resume weights", logger=logger) + + if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2: + per_tensor_base_params = ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in base_model_params.items() + ) + await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False) + del base_model_params, per_tensor_base_params + + await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done) + log_gpu_memory_usage("After update_weights", logger=logger) + del params, per_tensor_param + aggressive_empty_cache(force_sync=True) + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["kv_cache"]) + log_gpu_memory_usage("After resume kv_cache", logger=logger) + + self.base_sync_done = True + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) diff --git a/ICL/DAPO/verl-recipe/flowrl/main_flowrl.py b/ICL/DAPO/verl-recipe/flowrl/main_flowrl.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9912c05c3a850b74e89248ec5dea0e81802a78 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/main_flowrl.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main training script for FlowRL algorithm.""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import is_cuda_available + + +@hydra.main(config_path="config", config_name="flowrl_trainer", version_base=None) +def main(config): + run_flowrl(config) + + +def run_flowrl(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + default_runtime_env = { + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + } + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + try: + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and OmegaConf.select(config.global_profiler, "steps") is not None + and len(OmegaConf.select(config.global_profiler, "steps")) > 0 + ): + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + finally: + if ray.is_initialized(): + ray.shutdown() + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + tokenizer = hf_tokenizer(local_path) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + from verl.single_controller.ray import RayWorkerGroup + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + + # Use FlowRL custom worker instead of standard worker + from recipe.flowrl.flowrl_fsdp_worker import FlowRLActorRolloutRefWorker + + from verl.workers.fsdp_workers import CriticWorker # , ActorRolloutRefWorker + + ActorRolloutRefWorker = FlowRLActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_fn = load_reward_manager( + config, + tokenizer, + 0, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) + + # Note that we always use function-based RM for validation + val_reward_fn = load_reward_manager( + config, + tokenizer, + 1, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from recipe.flowrl.flowrl_ray_trainer import RayFlowRLTrainer + + trainer = RayFlowRLTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) + trainer.init_workers() + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh b/ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..406f4d10730dd43469cda496bc10c4b153870445 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FlowRL' +exp_name='FlowRL-Qwen2.5-7B' + +# Algorithm settings +adv_estimator=grpo + +# KL settings (ref policy needed for FlowRL, but KL penalty disabled) +use_kl_in_reward=False # Enable ref policy for ref_log_prob (needed for FlowRL loss) +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.0 + +# Clip parameters +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Sequence lengths +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) + +# Overlong buffer for very long responses +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Batch sizes +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Checkpoint saving frequency (-1 to disable periodic saves) +save_freq=-1 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} + +# Paths +MODEL_PATH=${MODEL_PATH:-"${WORKING_DIR}/downloads/models/Qwen/Qwen2.5-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/outputs/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${WORKING_DIR}/downloads/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${WORKING_DIR}/downloads/data/aime-2024.parquet"} + +# Sampling +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +n_gpus=8 +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=False +gen_tp=1 + + +python3 -m recipe.flowrl.main_flowrl \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.warmup_style='constant' \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=${n_gpus} \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=${save_freq} \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto diff --git a/ICL/DAPO/verl-recipe/infigui-g1/README.md b/ICL/DAPO/verl-recipe/infigui-g1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..01ec072aa8cc001b26b4479b80baf98be67fe141 --- /dev/null +++ b/ICL/DAPO/verl-recipe/infigui-g1/README.md @@ -0,0 +1,56 @@ +# Recipe for InfiGUI-G1 + +This directory contains the official implementation for the paper [InfiGUI-G1: Advancing GUI Grounding with Adaptive Exploration Policy Optimization](https://arxiv.org/abs/2508.05731). + +This work introduces Adaptive Exploration Policy Optimization (AEPO), a policy optimization framework designed to enhance GUI grounding in Multimodal Large Language Models (MLLMs). AEPO improves exploration efficiency by employing a multi-answer generation strategy and a theoretically grounded Adaptive Exploration Reward (AER) function. This approach effectively addresses the challenge of semantic alignment in complex GUI grounding tasks. + +We provide training scripts for both 3B and 7B models, configured for a single machine with 8 GPUs by default. + +## Environment Setup + +Please follow the main environment setup guide for `verl`. + +The provided scripts use the following Docker image: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2` + +## Data Preparation + +Before starting the training, you need to download the example dataset. This dataset is a filtered version of [omniact](https://huggingface.co/datasets/Writer/omniact), containing only grounding tasks and excluding easy samples. + +The data is hosted on the Hugging Face. You can download it using the `huggingface-cli`: + +```bash +huggingface-cli download --repo-type dataset --resume-download InfiX-ai/omniact_grounding_filtered --local-dir data/omniact_grounding_filtered +``` + +This command will download the training and validation parquet files into the `data/omniact_grounding_filtered` directory, which is the default path used by the scripts. + +## Training + +We provide scripts to train the 3B and 7B models. Please run them from the root directory of `verl`. + +- **Train the 3B model:** + + ```bash + bash recipe/infigui-g1/run_3b.sh + ``` + +- **Train the 7B model:** + + ```bash + bash recipe/infigui-g1/run_7b.sh + ``` + +## Using Custom Data + +If you wish to train on your own dataset, please format your data to match the structure of the example files located in `data/omniact_grounding_filtered`. + +Once your data is ready, you need to update the data path arguments in the training script. + +In `run_3b.sh` or `run_7b.sh`, modify the following lines: + +```bash + data.train_files=./path/to/your/train_data.parquet \ + data.val_files=./path/to/your/val_data.parquet \ +``` + +Replace the paths with the location of your custom data files. diff --git a/ICL/DAPO/verl-recipe/langgraph_agent/__init__.py b/ICL/DAPO/verl-recipe/langgraph_agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/ICL/DAPO/verl-recipe/langgraph_agent/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py b/ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0365a56a999bdc7139aabd8abacbf28ac49bf89e --- /dev/null +++ b/ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py @@ -0,0 +1,393 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Ref: https://python.langchain.com/docs/how_to/custom_chat_model/ +""" + +import asyncio +import json +import logging +import os +import uuid +from typing import Any, Optional + +from langchain_core.language_models import BaseChatModel +from langchain_core.language_models.base import LanguageModelInput +from langchain_core.messages import ( + AIMessage, + BaseMessage, + convert_to_openai_messages, +) +from langchain_core.messages.tool import InvalidToolCall, ToolCall +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.tools import StructuredTool +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import Field + +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager +from verl.experimental.agent_loop.tool_parser import ToolParser +from verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class MaxTokenExceededError(Exception): + """Indicate that history chat messages + tool message exceeds LLM max_tokens.""" + + pass + + +class ChatModel(BaseChatModel): + model_name: str = Field(alias="model") + """The name of the model""" + + client: AsyncLLMServerManager + """AsyncLLM server manager""" + + tokenizer: Any + """Tokenizer for the model""" + + max_tokens: int + """Max tokens to generate""" + + tool_parser: str = "hermes" + """Tool parser for the model""" + + max_parallel_calls: int = 1 + """Max parallel tool calls""" + + temperature: float = 1.0 + """Temperature for sampling""" + + top_p: float = 1.0 + """Top p for sampling""" + + repetition_penalty: float = 1.0 + """Repetition penalty for sampling""" + + def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools to the model. + + Args: + tools: Sequence of tools to bind to the model. + + Returns: + A Runnable that returns a message. + """ + formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools] + + # used to remove system prompt prefix when encoding tool response + system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) + kwargs["system_prompt"] = system_prompt + + return self.bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: dict | type, + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, dict | BaseChatModel]: + """Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/""" + raise NotImplementedError + + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> ChatResult: + raise NotImplementedError + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> ChatResult: + """Asynchronously generate chat completion message. + + Args: + messages (list[BaseMessage]): List of list of messages. + stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. Defaults to None. + + Returns: + ChatResult: Chat result. + """ + request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs) + + sampling_params = { + "temperature": self.temperature, + "top_p": self.top_p, + "repetition_penalty": self.repetition_penalty, + } + if "sampling_params" in kwargs: + sampling_params.update(kwargs["sampling_params"]) + + output = await self.client.generate( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + + message = await self._postprocess(request_id, prompt_ids, response_mask, output.token_ids, **kwargs) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model.""" + return self.model_name + + async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]: + """Preprocess messages for chat completion. + + To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out + instead of messages list. + + But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory + (prompt_ids, response_mask) in lastest AIMessage.response_metadata. + + 1. Encode ToolMessage to token ids. + 2. Retrieve trajectory (prompt_ids, response_mask) from lastest AIMessage.response_metadata. + 3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask. + + Ref: https://python.langchain.com/docs/concepts/chat_history/ + + Args: + messages (list[BaseMessage]): List of messages. + + Returns: + tuple[str, list[int], list[int]]: Request id, prompt ids, response mask. + """ + # messages: [system], human, ai, human|tool, ai, human|tool, ... + assert messages[-1].type in ["human", "tool"], ( + f"Last message must be human or tool, but got {messages[-1].type}" + ) + loop = asyncio.get_running_loop() + + # Case 1: initial chat completion: [system], human + if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"): + prompt_ids = await loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + convert_to_openai_messages(messages), + tools=kwargs.get("tools"), + add_generation_prompt=True, + tokenize=True, + ), + ) + return str(uuid.uuid4()), prompt_ids, [] + + # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ... + for i in range(len(messages) - 1, -1, -1): + if messages[i].type == "ai": + break + assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata" + assert "response_mask" in messages[i].response_metadata, ( + "Last message must have response_mask in response_metadata" + ) + + # encode tool response + tool_responses = convert_to_openai_messages(messages[i + 1 :]) + if self.tool_parser == "hermes": + tool_response_ids = await loop.run_in_executor( + None, + lambda messages=tool_responses: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True + ), + ) + tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :] + elif self.tool_parser == "gpt-oss": + # Format tool responses manually + # since gpt-oss chat template requires tool call messages to parse tool response messages + # we need to format the tool response messages manually + tool_response_texts = [] + for tool_msg in tool_responses: + if tool_msg["role"] == "tool": + # Use tool message's name if available (for multiple tool calls) + actual_tool_name = tool_msg.get("name", "unknown") + if actual_tool_name == "unknown": + logger.error(f"actual_tool_name: {actual_tool_name}") + formatted = format_gpt_oss_tool_response_manually(tool_msg["content"], actual_tool_name) + tool_response_texts.append(formatted) + + # Tokenize the manually formatted tool responses + tool_response_text = "".join(tool_response_texts) + # need to add generation tokens for gpt-oss manually since add_generation_prompt is True + tool_response_text = add_generation_prompt_for_gpt_oss(tool_response_text) + logger.debug(f"tool_response_text: {tool_response_text}") + + tool_response_ids = await loop.run_in_executor( + None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False) + ) + else: + raise ValueError(f"Unsupported tool parser: {self.tool_parser}") + + # stop generation if response length exceeds max response length + if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens: + raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded") + + # append tool response to prompt + request_id = messages[i].response_metadata.pop("request_id") + prompt_ids = messages[i].response_metadata.pop("prompt_ids") + response_mask = messages[i].response_metadata.pop("response_mask") + prompt_ids += tool_response_ids + response_mask += [0] * len(tool_response_ids) + + return request_id, prompt_ids, response_mask + + async def _postprocess( + self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any + ) -> AIMessage: + """Postprocess response_ids when chat completion is done. + + 1. Decode response_ids, parse tool calls to AIMessage. + 2. Append response_ids to prompt_ids, and append 1 to response_mask. + 3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata. + + Args: + request_id (str): Unique request id. + prompt_ids (list[int]): Input prompt token ids in this chat completion. + response_mask (list[int]): Response mask before this chat completion. + response_ids (list[int]): LLM generated token ids in this chat completion. + + Returns: + AIMessage: Postprocessed message. + """ + prompt_ids += response_ids + response_mask += [1] * len(response_ids) + + tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer) + content, function_calls = await tool_parser.extract_tool_calls(response_ids) + + tool_calls, invalid_tool_calls = [], [] + + for function_call in function_calls: + error = None + try: + args = json.loads(function_call.arguments) + if not isinstance(args, dict): + error = f"Tool arguments must be a JSON object, got {type(args).__name__}" + except json.JSONDecodeError as e: + error = f"Invalid JSON tool arguments: {e}" + + if error: + logger.warning(error) + invalid_tool_calls.append( + InvalidToolCall( + name=function_call.name, + args=function_call.arguments, + id=str(uuid.uuid4()), + error=error, + ) + ) + else: + tool_calls.append( + ToolCall( + name=function_call.name, + args=args, + id=str(uuid.uuid4()), + ) + ) + + message = AIMessage( + content=content, + tool_calls=tool_calls[: self.max_parallel_calls], + invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls], + response_metadata={ + "request_id": request_id, + "prompt_ids": prompt_ids, + "response_mask": response_mask, + }, + ) + return message + + +class TruncateStructuredTool(StructuredTool): + """Structured tool with response truncation.""" + + tool_response_truncate_side: str + """truncate side of tool response: left, middle, right""" + + max_tool_response_length: int + """max length of tool response""" + + async def _arun( + self, + *args: Any, + config: RunnableConfig, + **kwargs: Any, + ) -> Any: + tool_response = await super()._arun(*args, config=config, **kwargs) + tool_response = str(tool_response) + + if len(tool_response) > self.max_tool_response_length: + if self.tool_response_truncate_side == "left": + tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" + elif self.tool_response_truncate_side == "right": + tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] + else: + length = self.max_tool_response_length // 2 + tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] + + return tool_response + + +def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput: + """Convert messages to AgentLoopOutput. + + Args: + messages (List[BaseMessage]): List of messages, last message must be assistant + with response_metadata containing `prompt_ids` and `response_mask`. + response_length (int): Max length of response. + + Returns: + AgentLoopOutput: agent loop output trajectory used for training. + """ + # skip last tool calls + for i in range(len(messages) - 1, -1, -1): + if messages[i].type != "tool": + break + last_message = messages[i] + assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}" + assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata" + assert "response_mask" in last_message.response_metadata, ( + "Last message must have response_mask in response_metadata" + ) + + num_turns = 0 + for i in range(len(messages)): + if messages[i].type == "system": + continue + # parallel tool calls are in single turn + if i == 0 or messages[i].type != messages[i - 1].type: + num_turns += 1 + + prompt_ids = last_message.response_metadata["prompt_ids"] + response_mask = last_message.response_metadata["response_mask"] + + response_ids = prompt_ids[-len(response_mask) :] + prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[:response_length], + response_mask=response_mask[:response_length], + num_turns=num_turns, + metrics={}, + ) + return output diff --git a/ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py b/ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..7c705b8da6b8ef3de5ed3bec20ae090d2cd375a3 --- /dev/null +++ b/ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py @@ -0,0 +1,188 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +LangGraph React Agent Loop. + +This implementation is exact same as `ToolAgentLoop`. + +Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/ +""" + +import logging +from typing import Any, Literal + +from langchain_core.messages import AIMessage +from langchain_core.runnables import RunnableConfig +from langgraph.graph import END, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode +from recipe.langgraph_agent.chat_model import ( + ChatModel, + MaxTokenExceededError, + convert_to_agent_output, +) + +from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput + +logger = logging.getLogger(__name__) + + +async def call_model(state: MessagesState, config: RunnableConfig): + model = config["configurable"]["model"] + sampling_params = config["configurable"]["sampling_params"] + try: + message = await model.ainvoke(state["messages"], sampling_params=sampling_params) + return {"messages": [message]} + except MaxTokenExceededError: + # last message is ToolMessage + return {"messages": []} + + +def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]: + # Safely extract max_assistant_turns from config + max_assistant_turns = None + try: + if config and "configurable" in config: + max_assistant_turns = config["configurable"].get("max_assistant_turns") + except Exception as e: + logger.warning(f"Failed to extract max_assistant_turns from config: {e}") + + num_assistant_turns = 0 + for message in state["messages"]: + if message.type == "ai": + num_assistant_turns += 1 + + last_message = state["messages"][-1] + + # LLM call failed, e.g: max response length exceeded + if last_message.type == "tool": + return END + + # max assistant turns exceeded + # Use a reasonable default limit (25) if max_assistant_turns is not set + # This prevents infinite loops + effective_max_turns = max_assistant_turns if max_assistant_turns is not None else 25 + if num_assistant_turns >= effective_max_turns: + return END + + # no tool calls + if not getattr(last_message, "tool_calls", None): + return END + + return "tools" + + +class ReactAgentLoop(AgentLoopBase): + # Recursion limit calculation constants + DEFAULT_MAX_ASSISTANT_TURNS = 25 + MIN_RECURSION_LIMIT = 50 + NODES_PER_TURN = 2 # Each AI turn involves agent + tools nodes + RECURSION_LIMIT_SAFETY_FACTOR = 1.5 # 50% buffer for edge cases + + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + if cls._class_initialized: + return + cls._class_initialized = True + print("Performing class-level ReactAgentLoop initialization") + + # build graph + cls.graph = cls.build_graph() + + @classmethod + def build_graph(cls) -> StateGraph: + workflow = StateGraph(MessagesState) + + workflow.add_node("agent", call_model) + workflow.add_node("tools", ToolNode(cls.tools)) + workflow.set_entry_point("agent") + workflow.add_conditional_edges( + "agent", + should_continue, + { + "tools": "tools", + END: END, + }, + ) + + workflow.add_edge("tools", "agent") + graph = workflow.compile() + return graph + + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + messages = list(kwargs["raw_prompt"]) + + model_path = self.config.actor_rollout_ref.model.path + model_name = "/".join(model_path.split("/")[-2:]) + + rollout = self.config.actor_rollout_ref.rollout + model = ChatModel( + model=model_name, + client=self.server_manager, + tokenizer=self.tokenizer, + max_tokens=rollout.response_length, + max_parallel_calls=rollout.multi_turn.max_parallel_calls, + tool_parser=rollout.multi_turn.format, + ) + + model = model.bind_tools(self.tools, tool_choice="any") + + # Calculate recursion_limit dynamically based on max_assistant_turns + max_assistant_turns = ( + rollout.multi_turn.max_assistant_turns + if rollout.multi_turn.max_assistant_turns + else self.DEFAULT_MAX_ASSISTANT_TURNS + ) + + # Formula: nodes_per_turn * max_turns * safety_buffer, with minimum threshold + recursion_limit = max( + self.MIN_RECURSION_LIMIT, + int(max_assistant_turns * self.NODES_PER_TURN * self.RECURSION_LIMIT_SAFETY_FACTOR), + ) + logger.info(f"Configured recursion_limit={recursion_limit} (max_assistant_turns={max_assistant_turns})") + + config = { + "configurable": { + "model": model, + "sampling_params": sampling_params, + "max_user_turns": rollout.multi_turn.max_user_turns, + "max_assistant_turns": rollout.multi_turn.max_assistant_turns, + }, + "recursion_limit": recursion_limit, + } + + # TODO: how to handle multiple trajectories in an graph invocation? + # Each graph node may has its own LLM calls and state, e.g: + # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart + try: + state = await self.graph.ainvoke(input={"messages": messages}, config=config) + except Exception as e: + logger.error(f"Agent loop execution failed: {type(e).__name__}: {e}") + logger.error("Falling back to a minimal dummy trajectory.") + + # Fallback to a minimal assistant message so that + # convert_to_agent_output and downstream padding logic + # can still run without crashing. + dummy_id = 0 + fallback_message = AIMessage( + content="[Agent execution failed - no valid trajectory]", + response_metadata={ + "request_id": "fallback", + "prompt_ids": [dummy_id, dummy_id], + "response_mask": [1], + }, + ) + state = {"messages": [fallback_message]} + + output = convert_to_agent_output(state["messages"], rollout.response_length) + return output diff --git a/ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py b/ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..43b288dcc66fe8a797c8dceeb2f5e08efc1b2d94 --- /dev/null +++ b/ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py @@ -0,0 +1,202 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +import numpy as np +import pytest +import ray +from langchain_core.tools import tool +from omegaconf import DictConfig +from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop +from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager + +from verl.protocol import DataProto +from verl.utils import hf_tokenizer + + +@pytest.fixture +def init_config() -> DictConfig: + from hydra import compose, initialize_config_dir + + with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): + config = compose(config_name="ppo_trainer") + model_path = "Qwen/Qwen2.5-1.5B-Instruct" + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + config.actor_rollout_ref.rollout.n = 4 + config.actor_rollout_ref.rollout.agent.num_workers = 2 + + config.actor_rollout_ref.actor.use_dynamic_bsz = True + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + return config + + +@tool(parse_docstring=True) +def get_current_temperature(location: str, unit: str = "celsius"): + """Get current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, and the unit in a dict + """ + print(f"[DEBUG] get_current_temperature: {location}, {unit}") + return { + "temperature": 26.1, + "location": location, + "unit": unit, + } + + +@tool(parse_docstring=True) +def get_temperature_date(location: str, date: str, unit: str = "celsius"): + """Get temperature at a location and date. + + Args: + location: The location to get the temperature for, in the format "City, State, Country". + date: The date to get the temperature for, in the format "Year-Month-Day". + unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) + + Returns: + the temperature, the location, the date and the unit in a dict + """ + print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}") + return { + "temperature": 25.9, + "location": location, + "date": date, + "unit": unit, + } + + +class TestReactAgentLoop(ReactAgentLoop): + @classmethod + def init_class(cls, config, tokenizer, **kwargs): + # TODO: find better way to configure tools + cls.tools = [get_current_temperature, get_temperature_date] + super().init_class(config, tokenizer, **kwargs) + + +def test_react_agent(init_config): + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + "VLLM_USE_V1": "1", + } + } + ) + + # =========================== 1. Init rollout manager =========================== + agent_loop_config = [ + { + "_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop", + "name": "react_agent", + }, + ] + agent_loop_config_path = "/tmp/agent_loop_config.json" + with open(agent_loop_config_path, "w") as f: + json.dump(agent_loop_config, f) + + n = 2 + init_config.actor_rollout_ref.rollout.n = n + # init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path + init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 + init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path + agent_loop_manager = init_agent_loop_manager(init_config) + + # =========================== 2. Generate sequences =========================== + raw_prompts = [ + [ + {"role": "user", "content": "How are you?"}, + ], + [ + {"role": "user", "content": "What's the temperature in Los Angeles now?"}, + ], + [ + {"role": "user", "content": "What's the temperature in New York now?"}, + ], + [ + { + "role": "system", + "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" + "Current Date: 2024-09-30", + }, + {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, + ], + ] + batch = DataProto( + non_tensor_batch={ + "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), + "agent_name": np.array(["react_agent"] * len(raw_prompts)), + "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)), + "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)), + }, + ) + batch = batch.repeat(n) + result = agent_loop_manager.generate_sequences(prompts=batch) + assert len(result) == len(raw_prompts) * n + + # Check turns + num_turns = result.non_tensor_batch["__num_turns__"] + print(f"num_turns: {num_turns}") + for i in range(len(num_turns)): + if i // n == 0: + # [user, assistant] + assert num_turns[i] == 2 + else: + # [user, assistant, tool, assistant] + assert num_turns[i] == 4 + + # Check response_mask + tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) + responses = result.batch["responses"] + response_mask = result.batch["response_mask"] + attention_mask = result.batch["attention_mask"] + assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" + response_length = response_mask.size(1) + + for i in range(len(responses)): + # response with tool response + valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] + response_with_obs = tokenizer.decode(valid_tokens) + + # response without tool response + valid_tokens = responses[i][response_mask[i].bool()] + response_without_obs = tokenizer.decode(valid_tokens) + + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + assert "" not in response_without_obs, ( + f"found in response: {response_without_obs}" + ) + print("=========================") + print(response_with_obs) + print("---") + print(response_without_obs) + + print("Test passed!") + ray.shutdown() diff --git a/ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py b/ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5f6eb1b3f0cb893ca173902f76fb8aa51c1b7714 --- /dev/null +++ b/ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py @@ -0,0 +1,571 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import os +import re +from typing import Optional + +import datasets +import torch +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from transformers import PreTrainedTokenizer, ProcessorMixin + +import verl.utils.torch_functional as verl_F +from verl.utils.dataset.vision_utils import process_image +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + + +def build_transform(): + IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN + IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + ] + ) + + +def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None): + if new_schema: + start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id) + end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id) + else: + start_cond = input_ids == tokenizer.im_start_id + end_cond = input_ids == tokenizer.im_end_id + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + if len(image_start_tokens) != len(image_end_tokens): + logger.error("image start token != image end tokens") + raise Exception("image start token != image end tokens") + if len(image_start_tokens) > 0: + image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]) + else: + image_bound = [] + return image_bound + + +def preprocess( + images_dict, + conversations, + tokenizer, + transform, + query_nums=64, + slice_config=None, + llm_type=None, + patch_size=14, + batch_vision=False, + max_length=2048, + truncation="error", + apply_chat_template_kwargs=None, + logger=None, +): + """ + single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation + """ + conversations = copy.deepcopy(conversations) + assert conversations[0]["role"] == "user", "the first role must be user" + + if slice_config is not None: + assert isinstance(slice_config, dict) + assert "patch_size" in slice_config + assert "max_slice_nums" in slice_config + assert "scale_resolution" in slice_config + default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end + new_schema = False + use_image_id = False + if llm_type == "qwen": + new_schema = True + use_image_id = True + image_placeholder_dict = {} + images = [] + image_id_cnt = 0 + for img_name, image in images_dict.items(): + if slice_config: + source_image, patches, best_grid = slice_image( + image, + slice_config["max_slice_nums"], + slice_config["scale_resolution"], + slice_config["patch_size"], + ) + images.append(source_image) + image_placeholder = default_image_placeholder + if len(patches) > 0: + for i in range(len(patches)): + for j in range(len(patches[0])): + images.append(patches[i][j]) + if use_image_id: + image_placeholder = ( + f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder + ) + image_id_cnt += 1 + image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema) + image_placeholder_dict[img_name] = image_placeholder + else: + images.append(image) + if use_image_id: + image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder + image_id_cnt += 1 + else: + image_placeholder = default_image_placeholder + image_placeholder_dict[img_name] = image_placeholder + + images = [transform(i) for i in images] + + if len(images_dict) == 1 and "" in images_dict: + if "" in conversations[0]["content"]: + conversations[0]["content"] = conversations[0]["content"].replace("", image_placeholder) + else: + conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"] + else: + pattern = r"" + new_conversations = [] + for conversation in conversations: + content = conversation["content"] + parts = re.split(f"({pattern})", content) + for i, part in enumerate(parts): + if not part.strip(): + continue + if re.match(pattern, part): + if part in image_placeholder_dict: + parts[i] = image_placeholder_dict[part] + else: + raise Exception(f"not found {part} in image dict") + conversation["content"] = "\n".join(parts) + new_conversations.append(conversation) + conversations = new_conversations + + # TODO change role in conversation for different llm + prompt_with_chat_template = tokenizer.apply_chat_template( + conversations, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {}) + ) + + input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( + prompt=prompt_with_chat_template, + tokenizer=tokenizer, + max_length=max_length, + pad_token_id=tokenizer.pad_token_id, + left_pad=True, + truncation=truncation, + ) + position_ids = compute_position_id_with_mask(attention_mask) + image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger) + + input_dict = { + "input_ids": input_ids[0], + "attention_mask": attention_mask[0], + "position_ids": position_ids[0], + "image_bound": image_bound, + } + + if batch_vision: + tgt_sizes = [] + reshape_images = [] + for image in images: + H, W = image.shape[1:] + reshape_image = reshape_by_patch(image, patch_size) + reshape_images.append(reshape_image) + tgt_sizes.append([H // patch_size, W // patch_size]) + if tgt_sizes: + tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32) + + input_dict["pixel_values"] = reshape_images + input_dict["tgt_sizes"] = tgt_sizes + + else: + input_dict["pixel_values"] = images + input_dict["tgt_sizes"] = [] + + return input_dict + + +def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False): + original_size = image.size + original_width, original_height = original_size + log_ratio = math.log(original_width / original_height) + ratio = original_width * original_height / (scale_resolution * scale_resolution) + multiple = min(math.ceil(ratio), max_slice_nums) + + source_image = None + best_grid = None + patches = [] + + if multiple <= 1 or never_split: + # dont need to slice, upsample + best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True) + source_image = image.resize(best_size, Image.Resampling.BICUBIC) + else: + candidate_split_grids_nums = [] + for i in [multiple - 1, multiple, multiple + 1]: + if i == 1 or i > max_slice_nums: + continue + candidate_split_grids_nums.append(i) + + # source image, down-sampling and ensure divided by patch_size + best_resize = find_best_resize(original_size, scale_resolution, patch_size) + source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) + candidate_grids = [] + + # find best grid + for split_grids_nums in candidate_split_grids_nums: + m = 1 + while m <= split_grids_nums: + if split_grids_nums % m == 0: + candidate_grids.append([m, split_grids_nums // m]) + m += 1 + + best_grid = [1, 1] + min_error = float("inf") + for grid in candidate_grids: + error = abs(log_ratio - math.log(grid[0] / grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True) + + refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) + patches = split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + +def ensure_divide(length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + +def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + if (width * height > scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = ensure_divide(width, patch_size) + best_height = ensure_divide(height, patch_size) + return (best_width, best_height) + + +def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False): + width, height = original_size + grid_x, grid_y = grid + + refine_width = ensure_divide(width, grid_x) + refine_height = ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = refine_height / grid_y + + best_grid_size = find_best_resize( + (grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale, + ) + + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + + return refine_size + + +def split_to_patches(image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + + return patches + + +def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): + if new_schema: + image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end + else: + image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(image_placeholder) + slices.append("".join(lines)) + if new_schema: + slice_placeholder = "\n".join(slices) + else: + slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end + return slice_placeholder + + +def reshape_by_patch(image_tensor, patch_size): + """ + :param image_tensor: shape [3, H, W] + :param patch_size: + :return: [3, patch_size, HW/patch_size] + """ + patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)) + + patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1) + patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1) + return patches + + +def init_minicpmo_config(processor, config): + """Initialize MiniCPM-o specific configuration""" + minicpmo_config = { + "transform": build_transform(), + "patch_size": config.get("patch_size", 14), + "query_nums": config.get("query_nums", 64), + "slice_config": config.get( + "slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448} + ), + "llm_type": config.get("llm_type", "qwen"), + "batch_vision": config.get("batch_vision", True), + } + return minicpmo_config + + +def process_minicpmo_data( + row_dict, + messages, + tokenizer, + minicpmo_config, + image_key, + max_prompt_length, + truncation, + apply_chat_template_kwargs, + logger, +): + """Process data for MiniCPM-o model""" + if len(row_dict[image_key]) == 1: + multi_modal_data = {} + image = process_image(row_dict.pop(image_key)[0]) + multi_modal_data["image"] = [image] + images_dict = {"": image} + else: + raise NotImplementedError + + model_inputs = preprocess( + images_dict, + messages, + tokenizer, + minicpmo_config["transform"], + query_nums=minicpmo_config["query_nums"], + slice_config=minicpmo_config["slice_config"], + llm_type=minicpmo_config["llm_type"], + patch_size=minicpmo_config["patch_size"], + batch_vision=minicpmo_config["batch_vision"], + max_length=max_prompt_length, + truncation=truncation, + apply_chat_template_kwargs=apply_chat_template_kwargs, + logger=logger, + ) + + raw_prompt = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {}) + ) + raw_prompt = raw_prompt.replace("", "(./)") + + return model_inputs, multi_modal_data, raw_prompt + + +class RLHFDataset(Dataset): + """ + Load and preprocess RLHF data from Parquet files. + + - Caches files locally. + - Reads into a HuggingFace Dataset and tokenizes prompts. + - Optionally handles images/videos via a ProcessorMixin. + - Filters prompts over a max length. + - Supports resuming from checkpoints. + + Args: + data_files (str or list): Path(s) to Parquet file(s). + tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. + config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. + processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. + """ + + def __init__( + self, + data_files: str | list[str], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + ): + if not isinstance(data_files, list | ListConfig): + data_files = [data_files] + + self.data_files = copy.deepcopy(data_files) + self.original_data_files = copy.deepcopy(data_files) # use for resume + self.tokenizer = tokenizer + self.processor = processor + self.config = config + + self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.prompt_key = config.get("prompt_key", "prompt") + self.image_key = config.get("image_key", "images") + self.video_key = config.get("video_key", "videos") + self.max_prompt_length = config.get("max_prompt_length", 1024) + self.return_raw_chat = config.get("return_raw_chat", False) + self.return_full_prompt = config.get("return_full_prompt", False) + self.truncation = config.get("truncation", "error") + self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {}) + + self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = min(self.num_workers, os.cpu_count()) + self.use_shm = config.get("use_shm", False) + self.chat_template_func = config.get("chat_template_func", None) + self.need_tools_kwargs = config.get("need_tools_kwargs", False) + self.filter_prompts = config.get("filter_prompts", True) + self.serialize_dataset = False + self.minicpmo_config = init_minicpmo_config(self.processor, config) + self._download() + self._read_files_and_tokenize() + + def _download(self, use_origin_parquet=False): + from verl.utils.fs import copy_to_local + + data_files = self.data_files if not use_origin_parquet else self.original_data_files + for i, parquet_file in enumerate(data_files): + self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + def resume_dataset_state(self): + self.serialize_dataset = not hasattr(self, "original_data_files") + # resume dataframe if not it's serialized in data.pt + if not self.serialize_dataset: + self._download(use_origin_parquet=True) # download and resume from original parquet files + self._read_files_and_tokenize() + else: + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + + def __len__(self): + return len(self.dataframe) + + def _build_messages(self, example: dict): + return example.pop(self.prompt_key) + + def __getitem__(self, item): + """ + Note that we also return the raw_input_ids so that it can be combined with other chat template + """ + row_dict: dict = self.dataframe[item] + messages = self._build_messages(row_dict) + model_inputs = {} + + if self.processor is not None: + model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data( + row_dict, + messages, + self.tokenizer, + self.minicpmo_config, + self.image_key, + self.max_prompt_length, + self.truncation, + self.apply_chat_template_kwargs, + logger, + ) + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") + position_ids = model_inputs.pop("position_ids") + + # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature + row_dict["multi_modal_data"] = multi_modal_data + row_dict["multi_modal_inputs"] = dict(model_inputs) + else: + raw_prompt = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs + ) + model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) + input_ids = model_inputs.pop("input_ids") + attention_mask = model_inputs.pop("attention_mask") + position_ids = compute_position_id_with_mask(attention_mask) + + row_dict["input_ids"] = input_ids + row_dict["attention_mask"] = attention_mask + row_dict["position_ids"] = position_ids + + raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) + if len(raw_prompt_ids) > self.max_prompt_length: + if self.truncation == "left": + raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] + elif self.truncation == "right": + raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] + elif self.truncation == "middle": + left_half = self.max_prompt_length // 2 + right_half = self.max_prompt_length - left_half + raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] + elif self.truncation == "error": + raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") + + row_dict["raw_prompt_ids"] = raw_prompt_ids + # encode prompts without chat template + if self.return_raw_chat: + row_dict["raw_prompt"] = messages + + # get prompts with chat template + if self.return_full_prompt: + row_dict["full_prompts"] = raw_prompt # array of strings + + # add index for each prompt + index = row_dict.get("extra_info", {}).get("index", 0) + tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {}) + interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {}) + need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs) + if need_tools_kwargs and not tools_kwargs: + logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"]) + row_dict["index"] = index + row_dict["tools_kwargs"] = tools_kwargs + row_dict["interaction_kwargs"] = interaction_kwargs + return row_dict + + def __getstate__(self): + if not self.serialize_dataset: + state = self.__dict__.copy() + + if "dataframe" in state: + del state["dataframe"] + return state + + return self.__dict__.copy() diff --git a/ICL/DAPO/verl-recipe/prime/__init__.py b/ICL/DAPO/verl-recipe/prime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b76ea65c919de7f0b6544338c10026251d17100 --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/DAPO/verl-recipe/prime/prime_core_algos.py b/ICL/DAPO/verl-recipe/prime/prime_core_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..825671216ee12874d5eedf5900ae90de3298d968 --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/prime_core_algos.py @@ -0,0 +1,147 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +import verl +import verl.utils.torch_functional as verl_F + + +def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config): + # calculate rloo reward on different reward sources, and sum again + def masked_rloo(reward_tensor_original, mask_tensor): + reward_tensor = reward_tensor_original.clone() + reward_tensor[~mask_tensor] = 0 + for start_pos in range(0, reward_tensor.shape[0], n_samples): + cur_rewards_mean = torch.cat( + [ + reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True) + for pos in range(start_pos, start_pos + n_samples) + ], + dim=0, + ) + cur_rewards_sum = cur_rewards_mean.sum() + cur_reward_baseline = cur_rewards_sum / (n_samples - 1) + reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = ( + reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] + * (n_samples / (n_samples - 1)) + - cur_reward_baseline + ) + + return reward_tensor + + reward_tensors = [] + + with torch.no_grad(): + if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0: + reward_tensor = data.batch["rm_scores"] + reward_mask = response_mask.bool() + + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) + + if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0: + reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32) + reward_mask = torch.zeros_like(response_mask, dtype=torch.bool) + + prompt_ids = data.batch["prompts"] + prompt_length = prompt_ids.shape[-1] + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1) + + reward_mask[ + torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + valid_response_length - 1, + ] = True + reward_tensor[ + torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), + valid_response_length - 1, + ] = data.batch["acc"] + + reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef) + + final_reward_tensor = sum(reward_tensors) + + returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + + advantages = returns.clone() + advantages = verl_F.masked_whiten(advantages, response_mask) + + return advantages, returns + + +def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta): + cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid() + cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc) + return cur_dpo_loss + + +def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"): + # we always assume that the BoN size equals n_samples + # mode1: use acc as rm + # mode2: use Q as rm + cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta + other_Q = torch.zeros_like(cur_Q) + for i in range(token_level_scores.shape[0]): + Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]] + if len(Q_chosen) > 0: + other_Q[i] = Q_chosen.mean() * beta + else: + other_Q[i] = 0 + dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1))) + if bon_mode == "none": + dpo_loss = dpo_loss.mean() + else: + weight = torch.zeros_like(dpo_loss) + n_samples = acc_bc.shape[1] + if bon_mode == "bon_rm": + for i in range(token_level_scores.shape[0]): + weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1) + elif bon_mode == "bon_acc": + for i in range(token_level_scores.shape[0]): + weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1) + else: + raise NotImplementedError + dpo_loss = (dpo_loss * weight).sum() + + return dpo_loss + + +def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples): + dpo_acc = [] + for start_id in range(0, token_level_scores.shape[0], n_samples): + cur_scores = ( + token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples] + ).sum(dim=1) + + def get_upper_triangle(tensor_x): + diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) + upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1) + return diff_matrix[upper_tri_indices] + + cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1] + cur_score_diff = get_upper_triangle(cur_scores) # in R + cur_score_prediction = (cur_score_diff > 0).float() # in [0,1] + if cur_acc_diff.abs().sum() == 0: + cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 + else: + cur_acc = ( + ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs() + ).sum() / cur_acc_diff.abs().sum() + + dpo_acc.append(cur_acc.unsqueeze(0)) + + return torch.cat(dpo_acc, dim=0).mean() + + +def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples): + return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean() diff --git a/ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh b/ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh new file mode 100644 index 0000000000000000000000000000000000000000..e179c0858ab0f4819a4f9cd7ebf58cf5b7acd194 --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh @@ -0,0 +1,61 @@ +set -x + + +# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data +code_train_path=$HOME/data/code/train.parquet +code_test_path=$HOME/data/code/test.parquet + +train_files="['$code_train_path']" +test_files="['$code_test_path']" + +model_path=PRIME-RL/Eurus-2-7B-SFT +# model_path=Qwen/Qwen2.5-0.5B-Instruct + +python3 -m recipe.prime.main_prime \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=64 \ + data.val_batch_size=6312 \ + data.max_prompt_length=1024 \ + data.max_response_length=3072 \ + data.filter_overlong_prompts=True \ + data.filter_accuracy=True \ + data.accuracy_lower_bound=0.2 \ + data.accuracy_upper_bound=0.8 \ + data.oversample_factor=4 \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + algorithm.adv_estimator=rloo \ + algorithm.use_kl_in_reward=True \ + algorithm.kl_penalty=kl \ + algorithm.kl_ctrl.kl_coef=0.001 \ + reward_model.model.path=$model_path \ + reward_model.micro_batch_size_per_gpu=1 \ + reward_model.model.update=before \ + reward_model.model.beta_train=0.05 \ + reward_model.model.optim.lr=1e-6 \ + reward_model.model.optim.grad_clip=10.0 \ + reward_model.model.input_tokenizer=null \ + reward_model.mini_batch_size=64 \ + trainer.val_before_train=False \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='prime_example' \ + trainer.experiment_name='Eurus-2-7B-SFT-code' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=64 \ + trainer.test_freq=64 \ + trainer.total_epochs=15 $@ diff --git a/ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh b/ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1aa9edccc43ddd0b63d382cda59c92f143cd8f7 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh @@ -0,0 +1,33 @@ +MODEL_PATH=Qwen/DeepSeek-R1-Distill-Qwen-1.5B +DATA_PATH=/workspace/datasets/r1_bench + +# Eval Data Process +python3 -m recipe.r1.data_process \ + --local_dir $DATA_PATH \ + --tasks all + +# Generation +python3 -m verl.trainer.main_generation \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node=8 \ + data.path=$DATA_PATH/test.parquet \ + data.prompt_key=prompt \ + data.batch_size=1024 \ + data.n_samples=8 \ + data.output_path=$DATA_PATH/test-output-8.parquet \ + model.path=$MODEL_PATH \ + rollout.temperature=0.6 \ + rollout.top_p=0.95 \ + rollout.prompt_length=1024 \ + rollout.response_length=32768 \ + rollout.tensor_model_parallel_size=1 \ + rollout.gpu_memory_utilization=0.9 \ + rollout.max_num_batched_tokens=65536 + +# Evaluation +python3 -m recipe.r1.main_eval \ + data.path=$DATA_PATH/test-output-8.parquet \ + data.prompt_key=prompt \ + data.response_key=responses \ + custom_reward_function.path=recipe/r1/reward_score.py \ + custom_reward_function.name=reward_func diff --git a/ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3 b/ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3 new file mode 100644 index 0000000000000000000000000000000000000000..97f731326349be7f580855b132f33639b9facdeb --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3 @@ -0,0 +1,82 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM quay.io/ascend/cann:8.2.rc1-a3-openeuler22.03-py3.11 + +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +ARG COMPILE_CUSTOM_KERNELS=1 + +# Define environments +ENV DEBIAN_FRONTED=noninteractive +ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS} + +RUN yum install -y patch + +WORKDIR /workspace + +RUN pip config set global.index-url ${PIP_INDEX_URL} + +# Install torch and torch-npu +RUN python3 -m pip install torch==2.5.1 torch-npu==2.5.1.post1 + +# Compile/Install apex +RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + source /usr/local/Ascend/nnal/asdsip/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + git clone -b master https://gitcode.com/ascend/apex.git && \ + cd apex/ && bash scripts/build.sh --python=3.11 && \ + cd apex/dist/ && \ + python3 -m pip install --upgrade apex-0.1+ascend-*.whl + +# verl +RUN git clone https://github.com/volcengine/verl.git + +# MindSpeed +RUN git clone https://gitcode.com/Ascend/MindSpeed.git && \ + cd MindSpeed && \ + git checkout f6688 && \ + pip install -r requirements.txt && \ + cp -r mindspeed ../verl + +# Install vLLM +RUN git clone https://github.com/vllm-project/vllm.git && \ + cd vllm && \ + git checkout v0.9.1 && \ + cp -r vllm ../verl +# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it. +RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip uninstall -y triton && \ + python3 -m pip cache purge + +# Install vllm-ascend +RUN git clone https://github.com/vllm-project/vllm-ascend.git && \ + cd vllm-ascend && \ + git checkout 8c7bc45 && \ + cp -r vllm_ascend ../verl + +# Append `libascebd_hal.so` path (devlib) to LD_LIBRARY_PATH +RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \ + source /usr/local/Ascend/nnal/atb/set_env.sh && \ + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ + export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/include/c++/12:/usr/include/c++/12/`uname -i`-openEuler-linux && \ + python3 -m pip install -v -e /workspace/vllm-ascend/ --exists-action=i --extra-index https://download.pytorch.org/whl/cpu/ && \ + python3 -m pip cache purge + +# Install modelscope (for fast download) and ray (for multinode) and Megatron-LM and others +RUN python3 -m pip install modelscope ray cache purge "transformers<4.54.0" mathruler cbor2 && \ + pip install pybase64 fastapi zmq uvicorn openai msgspec blake3 py-cpuinfo gguf openai-harmony && \ + pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/r1_ascend/README.md b/ICL/DAPO/verl-recipe/r1_ascend/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f19f80372af8e9018d1c4aee8040bb073be1729d --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/README.md @@ -0,0 +1,119 @@ +# DeepSeek-R1-Zero on Ascend NPU +This recipe provides a sample for fine-tuning the Deepseek-V3-Base model using Reinforcement Learning from Human Feedback (RLHF) on Ascend NPUs, specifically utilizing the GRPO algorithm with rule-based rewards on the deepscaler dataset. + +## Implementation Details +To implement RL training for the DeepSeek model on Ascend NPUs, this example includes the following key code additions and modifications: +- We implemented a simple rule-based reward function in `deepscaler.py`, referencing `verl/utils/reward_score/gsm8k.py`. +- We provided a dataset file conversion script `json_to_parquet.py`, which adds a template to the prompts to stimulate model thinking during the data file format conversion. +- Due to potential incomplete memory offloading during sleep operations for vLLM on NPUs, we added patches to manually handle the offloading and onloading of the rollout model and KVcache on NPUs. The related code is in `vllm_rollout_spmd.py` and `megatron_workers.py`. +- To enable vLLM to utilize all ranks for expert parallelism, support for vLLM's data parallelism was necessary. For this purpose, we added patches to construct the correct data parallel communication group. The related code is in `vllm_parallel_state.py` and `vllm_rollout_spmd.py`. Additionally, the `VLLM_DP_SIZE` environment variable must be correctly set to `world_size / vllm_tp_size`. +- The MindSpeed training framework for NPUs invalidates torch.compile to avoid compilation failures during training, but this prevents its use for accelerating inference. To resolve this, we added patches that allow compilation during inference but not during training. The related code is in `megatron_workers.py`. +- During RL training, multiple KV cache scheduling operations in vLLM on NPUs could lead to inconsistent memory allocation causing memory trampling. The fix for this issue is patched in `engine_core.py`. + +By searching globally for `# NPU-ADAPTATION`, you can see the actual changes made by the patch code. + +For more technical details, please refer to [the Technical Report (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/docs/deepseek/deepseek_rl_train_optimization.md). + +## Training Details +### Hyperparameters +This example fine-tunes the DeepSeek-671B Base model on the deepscaler dataset using a combination of simple format rewards and answer accuracy rewards. The key hyperparameters are as follows: + +| iteration | learning rate | global batchsize | n_samples | temperature | kl-coef | prompt_max_len | response_max_len | rule reward | reward model | +|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| +| 70 | 1e-6 (constant) | 512 | 16 | 1.0 | 0.001 | 1024 | 2048 | format + acc | - | + +### Resource Allocation and Performance +This recipe was trained on an Ascend Atlas 800T A3 hyper-node server, utilizing 128 A3 NPUs, which is equivalent to 256 accelerator ranks. The specific deployment strategy is as follows: + +| Rollout Deployment | Actor Deployment | Reference Deployment | Offload Strategy | +|:----:|:----:|:----:|:----:| +| TP2 EP256 | EP32 PP8 | Same as Actor | Full offload, optimizer utilizes the [Mindspeed Swap Optimizer feature](https://gitee.com/ascend/MindSpeed/blob/master/docs/features/swap-optimizer.md) | + +The performance metrics for one training step are shown below (throughput varies with the model's response length during training): + +| step | prompt_len_mean | response_len_mean | timing_step (s) | throughput (tps/A3) | timing_gen (s) | timing_reward (s) | timing_old_prob (s) | timing_ref_prob (s) | timing_update (s) | +|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| +| 2 | 175.1 | 1385.0 | 1044.8 | 95.5 | 482.2 | 20.4 | 105.5 | 92.7 | 342.9 | + +### Training Metrics +
+ + + +
+ +## Quick Start + +### Environment Setup +For setting up the Ascend NPU environment for verl, please refer to [ascend_quick_start.rst (in Chinese)](../../docs/ascend_tutorial/ascend_quick_start.rst). + +Alternatively, you can use the provided Dockerfile to build the project's runtime environment locally: `docker build -f Dockerfile.vllm_ascend.mindspeed.deepseekV3 -t REPOSITORY:TAG ./` + +Prepare the source code with the following steps: +```bash +# Clone verl +git clone https://github.com/volcengine/verl.git + +# Clone and setup vLLM (v0.9.1) +git clone https://github.com/vllm-project/vllm.git +cd vllm +git checkout v0.9.1 +cp -r vllm ../verl +cd .. + +# Clone and setup vLLM-Ascend (commit 8c7bc45) +git clone https://github.com/vllm-project/vllm-ascend.git +cd vllm-ascend +git checkout 8c7bc45 +cp -r vllm_ascend ../verl +cd .. + +# Clone and setup MindSpeed (commit f6688) +git clone https://gitcode.com/Ascend/MindSpeed.git +cd MindSpeed +git checkout f6688 +cp -r mindspeed ../verl +cd .. + +# Install Megatron-LM.core and other dependencies +pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 +pip install mathruler +``` + +### Prepare the Training Dataset +This example uses the deepscaler dataset. Prepare it as follows: +- Download the dataset [JSON file](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset/blob/main/deepscaler.json). +- Generate the `train.parquet` and `test.parquet` files and place them in the `./data/deepscaler` directory: + ```bash + # Execute from the verl project directory + python recipe/r1_ascend/json_to_parquet.py --output_dir ./data/deepscaler --json_path path/to/deepscaler.json --train_data_ratio 0.9 + ``` + + The processed prompts used during training will include a specific template, for example: `A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Put your final answer within \boxed{}. <|User|>{problem}<|Assistant|>` + +### Prepare Model Weights +Prepare the DeepSeek-V3-Base model weights as follows: +- Place the model configuration files (excluding the weights) into the `./DeepSeek-V3-hf` directory. The `config.json` file needs to be replaced to remove quantization and MTP configurations. Refer to [this link (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for details. +- Download the FP8 model weights from [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base) or [ModelScope](https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3-Base). Ensure the target disk has over 650GB of free space. +- Convert the FP8 weights to BF16 weights. Refer to [this link (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for instructions. This step requires over 1300GB of free space on the target disk. + +This example uses pre-sharded distributed weights. Therefore, the following weight sharding step is also required: +- The distributed weights will be stored in `ckpts/DeepseekV3-dist-ckpts`. +- Use the script `verl/scripts/converter_hf_to_mcore.py` to shard the original BF16 weights into distributed weights. In practice, we found that 2TB of CPU RAM was insufficient for sharding the 671B model. Therefore, we adapted this script for expert parallelism and performed the weight sharding using a distributed strategy of EP8 PP8 across 64 NPUs. + +### Other Code Modifications +In practice, to achieve the above results for on-policy RL training, we need to replace the `old_log_prob = data["old_log_probs"]` code in `verl/workers/actor/megatron_actor.py` with: + +```python +on_policy = self.config.ppo_epochs == 1 +if on_policy: + old_log_prob = log_prob.detach() # guarantee exact numerical equality +else: + old_log_prob = data["old_log_probs"] +``` + +### Execute RL Fine-tuning +```bash +# Start the RL fine-tuning for DeepSeekV3 from the verl directory +bash ./recipe/r1_ascend/ray_start_grpo_npu.sh +``` diff --git a/ICL/DAPO/verl-recipe/r1_ascend/README_zh.md b/ICL/DAPO/verl-recipe/r1_ascend/README_zh.md new file mode 100644 index 0000000000000000000000000000000000000000..0216eef14005797f60623b628e20238bae654107 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/README_zh.md @@ -0,0 +1,119 @@ +# DeepSeek-R1-Zero on Ascend NPU +本recipe是基于Deepseek-V3-Base模型在NPU上进行RLHF后训练的样例,基于GRPO与规则奖励,使用deepscaler数据集。 + +## 实现细节 +为了在Ascend NPU上实现DeepSeek模型的RL训练,本样例中补充了一些代码,如下所示: +- 我们参考`verl/utils/reward_score/gsm8k.py`,在`deepscaler.py`中实现了一个简单的规则奖励函数。 +- 我们提供了数据集文件转换脚本`json_to_parquet.py`,在数据文件格式转换的同时给prompt增加了激发模型思考的模板。 +- NPU上vLLM的sleep可能存在内存卸载不干净的问题,因此添加了一些patch,手动实现NPU上Rollout模型与KVcache的卸载与加载。相关代码在`vllm_rollout_spmd.py`以及 `megatron_workers.py`中。 +- 为了实现vLLM利用所有卡进行专家并行,需要支持vLLM的数据并行。为此添加了一些patch构建正确的DP通信域。相关代码在`vllm_parallel_state.py`以及`vllm_rollout_spmd.py`中。此外还需要正确配置`VLLM_DP_SIZE`环境变量为`world_size / vllm_tp_size`。 +- NPU的MindSpeed训练框架会将torch.compile无效化来规避训练侧的compile失败,但这会使推理侧无法利用torch.compile加速。为了解决该问题,本样例添加了一些patch,使推理时可以compile,训练时不compile。相关代码`megatron_workers.py`中。 +- RL训练过程中,NPU上vLLM多次KVcache调度可能引发申请内存不一致导致内存踩踏问题,修复patch在`engine_core.py`中。 + +通过全局搜索`# NPU-ADAPTATION`,可以看到patch代码所做的实际改动。 + +更多技术细节可参考[技术报告](https://gitcode.com/cann/cann-recipes-train/blob/master/docs/deepseek/deepseek_rl_train_optimization.md)。 + +## 训练细节 +### 训练超参 + +本样例基于DeepSeek-671B Base模型在deepscaler数据集上训练,使用简单的格式奖励和结果准确率奖励,训练超参如下: + +| 迭代 | 学习率 | gbs | 采样数 | 温度 | kl-coef | 输入长度 | 输出长度 | 规则奖励 | 奖励模型 | +|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| +| 70 | 1e-6 (constant) | 512 | 16 | 1.0 | 0.001 | 1024 | 2048 | format + acc | - | + +### 训练资源与性能 +本样例在昇腾Atlas 800T A3超节点服务器上进行训练,使用了128张A3 NPU,等效于256张加速卡。具体的部署方式如下: + +| Rollout部署 | Actor部署 | Reference部署 | Offload策略 | +|:----:|:----:|:----:|:----:| +| TP2 EP256 | EP32 PP8 | 同Actor | 全offload,优化器使用[Mindspeed卸载特性](https://gitee.com/ascend/MindSpeed/blob/master/docs/features/swap-optimizer.md) | + +得到一步的训练性能如下(吞吐会随着训练中模型输出长度变化而改变): +| step | 平均问题长度 | 平均回复长度 | 单步总耗时(s) | 吞吐(tps/A3) | gen耗时(s) | reward耗时(s) | old_prob耗时(s) | ref_prob耗时(s) | update耗时(s) | +|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:| +| 2 | 175.1 | 1385.0 | 1044.8 | 95.5 | 482.2 | 20.4 | 105.5 | 92.7 | 342.9 | + +### 训练过程记录 +
+ + + +
+ +## 快速开始 + +### 环境准备 +verl上的NPU环境准备,可参考[ascend_quick_start.rst](../../docs/ascend_tutorial/ascend_quick_start.rst)进行配置。 + +此外,也可使用我们提供的Dockerfile在本地构建项目运行环境:`docker build -f Dockerfile.vllm_ascend.mindspeed.deepseekV3 -t REPOSITORY:TAG ./` + +本样准备源码的步骤如下: +```bash +# verl +git clone https://github.com/volcengine/verl.git + +# vLLM (v0.9.1) +git clone https://github.com/vllm-project/vllm.git +cd vllm +git checkout v0.9.1 +cp -r vllm ../verl +cd .. + +# vLLM-Ascend (v0.9.1-dev) +git clone https://github.com/vllm-project/vllm-ascend.git +cd vllm-ascend +git checkout 8c7bc45 +cp -r vllm_ascend ../verl +cd .. + +# MindSpeed (commit-id: f6688) +git clone https://gitcode.com/Ascend/MindSpeed.git +cd MindSpeed +git checkout f6688 +cp -r mindspeed ../verl +cd .. + +# Megatron-LM.core and others +pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 +pip install mathruler +``` + +### 准备训练数据集 +本样例使用deepscaler数据集。准备方式如下: +- 下载数据集[json文件](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset/blob/main/deepscaler.json)。 +- 获取`train.parquet`与`test.parquet`文件并放入`./data/deepscaler`路径: + + ```bash + # 在verl项目目录执行 + python recipe/r1_ascend/json_to_parquet.py --output_dir ./data/deepscaler --json_path path/to/deepscaler.json --train_data_ratio 0.9 + ``` + + 训练中经过处理的prompt将包含模板,例如:`A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . Put your final answer within \boxed{}. <|User|>{problem}<|Assistant|>` + +### 准备模型权重 +DeepSeek-V3-Base模型权重准备步骤如下: +- 需要将模型配置相关文件(不含权重)放入`./DeepSeek-V3-hf`目录,并且`config.json`需要进行替换以去除量化和MTP。该步骤可参考[此链接](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87)。 +- 模型FP8权重下载:[HuggingFace地址](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base),[ModelScope地址](https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3-Base)。此步骤需要目录所在磁盘有650GB以上空间。 +- 将FP8权重转为BF16权重,可参考[此链接](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87)。此步骤需要目录所在磁盘有1300GB以上空间。 + +本样例使用了预先切分的分布式权重,因此还要执行以下的切分权重操作: +- 分布式权重需存储至`ckpts/DeepseekV3-dist-ckpts`。 +- 使用`verl/scripts/converter_hf_to_mcore.py`对原始的BF16权重切分得到分布式权重。实践中我们发现2T的CPU内存不足以完成671B模型的权重切分处理,为此我们对该脚本进行了专家并行的适配,并在64块NPU上用EP8 PP8分布式策略对权重进行了切分。 + +### 其他代码修改 +实践中为了得到以上on-policy训练的结果,我们将 `verl/workers/actor/megatron_actor.py` 中的代码段 `old_log_prob = data["old_log_probs"]` 替换为如下代码: +```python +on_policy = self.config.ppo_epochs == 1 +if on_policy: + old_log_prob = log_prob.detach() # 确保二者数值完全相等 +else: + old_log_prob = data["old_log_probs"] +``` + +### 执行RL后训练 +```bash +# verl目录下启动DeepSeekV3的RL后训练 +bash ./recipe/r1_ascend/ray_start_grpo_npu.sh +``` \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh b/ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..1647226a18a2f24066eeacd637b855f9f419fc80 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh @@ -0,0 +1,82 @@ +ray stop --force + +export RAY_DEDUP_LOGS=0 # 0: disable ray's log folding 1: enable ray's log folding +export HYDRA_FULL_ERROR=1 # display the accurate error stack + +ulimit -n 32768 +mkdir logs + +NNODES=16 # number of nodes +NPUS_PER_NODE=16 # the number of npus for each node +export WORLD_SIZE=$(($NNODES*$NPUS_PER_NODE)) + +RAY_START_PORT=6766 +RAY_DASHBOARD_PORT=8260 + +MASTER_ADDR="IP FOR MASTER NODE" # modify it to correspond to the IP of the master node +SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE" # modify it to the communication network card of the current node +# obtain the current node IP +CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}') +export MASTER_PORT=29444 +export HCCL_IF_BASE_PORT=64247 +export TP_SOCKET_IFNAME=$SOCKET_IFNAME +export HCCL_SOCKET_IFNAME=$SOCKET_IFNAME +export GLOO_SOCKET_IFNAME=$SOCKET_IFNAME + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True" +export TASK_QUEUE_ENABLE=2 # enable level2 optimization of the sent queue of the ascend operator +export HCCL_BUFFSIZE=300 # the buffer size of HCCL + +export HCCL_CONNECT_TIMEOUT=600 +export HCCL_EXEC_TIMEOUT=600 + +export ASCEND_LAUNCH_BLOCKING=0 # debug usage, which seriously affects performance after use, but the error stack is accurate + +export VLLM_USE_V1=1 # use the V1 engine of vLLM +export VLLM_ENABLE_GRAPH_MODE=1 # enable vLLM graph mode +export HCCL_OP_EXPANSION_MODE=AIV # enable the communication mode of AIV +export VLLM_ENABLE_MC2=1 # enable MC2 communication +export VLLM_DP_SIZE=128 # configure the DP size of vLLM, this is related to the vLLM instance num + +# under the configuration of the vLLM log level of INFO, enable this configuration, print the time of prefill and decode +export VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE=0 + +if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then + # the master node starts + ray start --head --port=$RAY_START_PORT --dashboard-host=0.0.0.0 --node-ip-address=$CURRENT_IP --dashboard-port=$RAY_DASHBOARD_PORT --resources='{"NPU": '$NPUS_PER_NODE'}' + + while true; do + ray_status_output=$(ray status) + npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) + npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') + device_count=$((npu_count_int / $NPUS_PER_NODE)) + + # determine whether device_count is equal to NNODES + if [ "$device_count" -eq "$NNODES" ]; then + echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script." + ray status + bash ./recipe/r1_ascend/run_deepseekv3_671b_grpo_megatron_npu.sh + break + else + echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" + sleep 5 + fi + done +else + # the child node attempts to register ray with the master node until successful + while true; do + # try to connect to the Ray cluster + ray start --address="$MASTER_ADDR:$RAY_START_PORT" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP + + # check if the connection is successful + ray status + if [ $? -eq 0 ]; then + echo "Successfully connected to the Ray cluster!" + break + else + echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." + sleep 5 + fi + done +fi \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/r1_ascend/vllm_rollout_spmd.py b/ICL/DAPO/verl-recipe/r1_ascend/vllm_rollout_spmd.py new file mode 100644 index 0000000000000000000000000000000000000000..980fe94c7dac309ea9ed29ec3f1c2f7870087092 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/vllm_rollout_spmd.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from +# https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py + +import logging +import os +from typing import Generator + +import torch +import torch.distributed +from omegaconf import ListConfig +from torch.distributed.device_mesh import DeviceMesh +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationLevel + +from verl.third_party.vllm import VLLM_SLEEP_LEVEL +from verl.utils.device import get_device_name +from verl.utils.memory_utils import aggressive_empty_cache +from verl.workers.config import HFModelConfig, RolloutConfig +from verl.workers.rollout.vllm_rollout import vLLMRollout as vLLMRolloutBase + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class vLLMRollout(vLLMRolloutBase): + def __init__( + self, + config: RolloutConfig, + model_config: HFModelConfig, + device_mesh: DeviceMesh, + ): + self.config = config + self.model_config = model_config + self.device_mesh = device_mesh + # NPU-ADAPTATION: import vLLM-Ascend patch + from recipe.r1_ascend import engine_core # noqa: F401 + from vllm_ascend.patch import ( + platform, # noqa: F401 + worker, # noqa: F401 + ) + # NPU-ADAPTATION END + + if config.layered_summon: + self.sleep_level = 1 + else: + self.sleep_level = VLLM_SLEEP_LEVEL + + model_path = model_config.local_path + tokenizer = model_config.tokenizer + model_hf_config = model_config.hf_config + trust_remote_code = model_config.trust_remote_code + self.lora_kwargs = ( + {"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank} + if model_config.lora_rank > 0 + else {} + ) + + tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) + assert tensor_parallel_size <= torch.distributed.get_world_size(), ( + "tensor parallel size should be less than or equal to the world size" + ) + max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192) + + # NPU-ADAPTATION: VLLM_DP_SIZE is configured, the DP communication domain needs to be explicitly initialized + if int(os.environ.get("VLLM_DP_SIZE", "1")) > 1: + from recipe.r1_ascend.vllm_parallel_state import init_parallel_state + + init_parallel_state(tensor_parallel_size) + # NPU-ADAPTATION END + + rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) + if not rope_scaling_config: + max_position_embeddings = None + if hasattr(model_hf_config, "max_position_embeddings"): + max_position_embeddings = model_hf_config.max_position_embeddings + elif hasattr(model_hf_config, "llm_config") and hasattr( + model_hf_config.llm_config, "max_position_embeddings" + ): + max_position_embeddings = model_hf_config.llm_config.max_position_embeddings + elif hasattr(model_hf_config, "text_config") and hasattr( + model_hf_config.text_config, "max_position_embeddings" + ): + max_position_embeddings = model_hf_config.text_config.max_position_embeddings + if max_position_embeddings is None: + raise ValueError("max_position_embeddings not found in model_hf_config") + assert max_position_embeddings >= config.prompt_length + config.response_length, ( + "model context length should be greater than total sequence length" + ) + else: + # handle type where there's a length extend factor + # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support + # for using yarn as an example + rope_scaling_factor = rope_scaling_config.get("factor", 1.0) + + assert ( + model_hf_config.max_position_embeddings * rope_scaling_factor + >= config.prompt_length + config.response_length + ), ( + "model context length should be greater than total sequence length, " + + f"got rope_scaling_factor={rope_scaling_factor} and " + + f"max_position_embeddings={model_hf_config.max_position_embeddings}" + ) + + max_model_len = int(config.max_model_len or config.prompt_length + config.response_length) + + load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format + + # copy it to avoid secretly modifying the engine config + engine_kwargs = config.get("engine_kwargs", {}).get("vllm", {}) or {} + + # For each vLLM engine parameter, + # - `None` means not setting it, so we pop it, and leave it to vLLM default value + # (which can vary across different vLLM versions); + # - Otherwise it's the desired value we want to explicitly set. + engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} + if config.get("limit_images", None): # support for multi-image data + engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")} + + compilation_config = {} + + cudagraph_capture_sizes = config.get("cudagraph_capture_sizes") + # enforce_eager must be False to use cudagraph + if not config.enforce_eager and cudagraph_capture_sizes: + if isinstance(cudagraph_capture_sizes, ListConfig): + compilation_config["compilation_config"] = CompilationConfig( + level=CompilationLevel.PIECEWISE, cudagraph_capture_sizes=cudagraph_capture_sizes + ) + else: + logger.warning(f"cudagraph_capture_sizes must be a list, but got {cudagraph_capture_sizes}") + + VLLM_ENABLE_GRAPGH_MODE = int(os.environ.get("VLLM_ENABLE_GRAPH_MODE", "0")) + self.inference_engine = LLM( + model=model_path, + # NPU-ADAPTATION: Enable inference EP and disable sleep mode. + enable_sleep_mode=False, + enable_expert_parallel=True, + # NPU-ADAPTATION END + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="external_launcher", + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + disable_custom_all_reduce=True, + skip_tokenizer_init=False, + max_model_len=max_model_len, + max_num_seqs=config.max_num_seqs, + load_format=load_format, + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + enable_prefix_caching=False, + trust_remote_code=trust_remote_code, + seed=config.get("seed", 0), + # NPU-ADAPTATION: Enable graph mode and configure the parameters. + additional_config={ + "torchair_graph_config": { + "enabled": VLLM_ENABLE_GRAPGH_MODE, + "use_cached_graph": False, + "graph_batch_sizes_init": False, + "graph_batch_sizes": [config.max_num_seqs], + "enable_multistream_mla": False, + "enable_multistream_moe": False, + "enable_view_optimize": False, + "enable_kv_nz": False, + "enable_frozen_parameter": False, + }, + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + }, + # NPU-ADAPTATION END + **compilation_config, + **self.lora_kwargs, + **engine_kwargs, + ) + # NPU-ADAPTATION: Weight onload and offload, and initialization configurations such as kv_cache. + self.model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() + self.kv_cache_configs = None + self.cpu_model = {} + self.gpu_buffers = None + for name, params in self.model.named_parameters(): + self.cpu_model[name] = torch.empty_like(params, device="cpu") + # NPU-ADAPTATION END + + kwargs = dict( + n=1, + logprobs=0, # can be set to 0 and let actor to recompute + max_tokens=config.response_length, + repetition_penalty=config.get("repetition_penalty", 1.0), + ) + + kwargs["detokenize"] = False + + # supporting adding any sampling params from the config file + for k in config.keys(): + if hasattr(SamplingParams(), str(k)) and k != "seed": + kwargs[k] = config.get(k) + kwargs["n"] = 1 # already repeat in ray_trainer + logger.info(f"vllm sampling kwargs: {kwargs}") + self.sampling_params = SamplingParams(**kwargs) + + self.pad_token_id = tokenizer.pad_token_id + + # NPU-ADAPTATION: Weight onload and offload, kv_cache init and free function + # NOTE: Due to potential incomplete memory offloading during sleep operations for vLLM on NPUs, we add + # patches to manually handle the off/on loading of the rollout model and KVcache on NPUs. + def init_cache_engine(self): + if os.environ["VLLM_USE_V1"] == "1": + worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker + if not worker.model_runner.kv_caches: + # v1 use explicit initialization method + self.inference_engine.llm_engine.engine_core.engine_core.model_executor.initialize_from_config( + self.inference_engine.llm_engine.engine_core.engine_core.kv_cache_configs + ) + self.inference_engine.llm_engine.reset_prefix_cache() + else: + if self.inference_engine.llm_engine.model_executor.driver_worker.worker.cache_engine is None: + self.inference_engine.llm_engine.model_executor.driver_worker.worker._init_cache_engine() + + def onload_model_weights(self): + self.gpu_buffers = {} + for name, param in self.model.named_parameters(): + self.gpu_buffers[name] = torch.empty_like(param, device=get_device_name()) + for name, param in self.model.named_parameters(): + param.data = self.gpu_buffers[name] + + def offload_model_weights(self): + for name, params in self.model.named_parameters(): + params.data = self.cpu_model[name] + if hasattr(self.model.model.layers[0].self_attn, "mla_attn"): + for i in range(self.model.model.start_layer, self.model.model.end_layer): + mla = self.model.model.layers[i].self_attn.mla_attn.impl + if hasattr(mla, "w_kc"): + mla.w_kc = None + mla.w_vc = None + if hasattr(mla, "W_UV"): + mla.W_UV = None + mla.W_UK_T = None + + self.gpu_buffers = None + aggressive_empty_cache() + + def free_cache_engine(self): + if os.environ["VLLM_USE_V1"] == "1": + worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker + ctx = worker.model_runner.vllm_config.compilation_config.static_forward_context + else: + compilation_config = self.inference_engine.llm_engine.model_executor.driver_worker.worker.compilation_config + ctx = compilation_config.static_forward_context + from vllm.attention import AttentionType + + layer_need_kv_cache = [] + for layer_name in ctx: + if hasattr(ctx[layer_name], "attn_type") and ctx[layer_name].attn_type in ( + AttentionType.DECODER, + AttentionType.ENCODER_DECODER, + ): + layer_need_kv_cache.append(layer_name) + + pipeline_parallel_size = self.inference_engine.llm_engine.vllm_config.parallel_config.pipeline_parallel_size + for layer_name in layer_need_kv_cache: + kv_cache = [] + for _ in range(pipeline_parallel_size): + kv_cache.append(torch.tensor([])) + ctx[layer_name].kv_cache = kv_cache + if os.environ["VLLM_USE_V1"] == "1": + worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker + + worker.model_runner.kv_caches = [] + else: + self.inference_engine.llm_engine.model_executor.driver_worker.worker.cache_engine = None + self.inference_engine.llm_engine.model_executor.driver_worker.worker.gpu_cache = None + + if hasattr(self.model.model.layers[0].self_attn, "attn"): + for i in range(self.model.model.start_layer, self.model.model.end_layer): + attn_impl = self.model.model.layers[i].self_attn.attn.impl + if hasattr(attn_impl, "key_cache"): + attn_impl.key_cache = None + attn_impl.value_cache = None + + aggressive_empty_cache() + + def _process_mla(self, load_weight=False): + for i in range(self.model.model.start_layer, self.model.model.end_layer): + mla = self.model.model.layers[i].self_attn.mla_attn.impl + if hasattr(mla, "w_kc"): + mla.w_kc = None + mla.w_vc = None + if hasattr(mla, "W_UV"): + mla.W_UV = None + mla.W_UK_T = None + if load_weight: + mla.process_weights_after_loading(None) + + async def resume(self, tags: list[str]): + """Resume rollout weights or kv cache in NPU memory. + + Args: + tags: weights or kv_cache. + """ + if not self.config.free_cache_engine: + return + + if "weights" in tags: + self.onload_model_weights() + elif "kv_cache" in tags: + self.init_cache_engine() + + async def release(self): + """Release weights and kv cache in NPU memory.""" + if not self.config.free_cache_engine: + return + + self.free_cache_engine() + self.offload_model_weights() + + if hasattr(self.model.model.layers[0].self_attn, "mla_attn"): + self._process_mla() + + async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs): + """Update the weights of the rollout model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. + """ + await super().update_weights(weights, **kwargs) + + if hasattr(self.model.model.layers[0].self_attn, "mla_attn"): + self._process_mla(load_weight=True) + + # NPU-ADAPTATION END diff --git a/ICL/DAPO/verl-recipe/rep_exp/README.md b/ICL/DAPO/verl-recipe/rep_exp/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ea8fefcf04c21a8e5cfad6ba9a64bfbfefb24631 --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/README.md @@ -0,0 +1,71 @@ +
+ +# Representation-Based Exploration for Language Models:
From Test-Time to Post-Training + +[📄 arXiv](https://arxiv.org/abs/2510.11686)     [🌐 Website](https://rep-exp.github.io)     [🐦 Twitter / X ](https://x.com/JensTuyls/status/1978244454617128993) + +
+ +## Installation 🔌 + +Install the following commit of verl: +``` +pip install verl@git+https://github.com/volcengine/verl.git@b9bd00efba253ea90072555c45692054cf703de2 +``` + +The only other package to install is scikit-learn, which we'll use for applying a sparse projection. +```bash +pip install scikit-learn +``` + +## Running the Experiments 🚀 + +You can reproduce or extend our experiments by running the following commands: + +```bash +# General format +sh rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED + +# MATH +sh rep_exp/train_elliptical.sh math 32 0.01 42 + +# GSM8K +sh rep_exp/train_elliptical.sh gsm8k 32 0.01 42 + +# DAPO-WITH-AIME +sh rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42 +``` +where `$TASK` is the task name, `$SPARSE_DIM` is the sparse dimension, `$BETA` is the beta parameter, and `$SEED` is the seed. + +## Evaluation 📊 +Once done training, you can evaluate the model on the test set by following two steps. +1. Merge the model checkpoint. + +This is necessary because the model checkpoint is saved in multiple shards (depending on the nubmer of GPUs), and we need to merge them into a single checkpoint. + +```bash +sh rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev +``` + +2. Evaluate the merged model. + +```bash +sh rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf #where X is the global step of the checkpoint with the best pass@1 on dev +``` + +The results should be in a folder named `eval` and saved as a JSON file. + +## Citation 📝 + +```bibtex +@article{tuyls2025representation, + title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training}, + author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T}, + journal={arXiv preprint arXiv:2510.11686}, + year={2025} +} +``` + +## Contact 📬 + +If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu). \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/rep_exp/eval.sh b/ICL/DAPO/verl-recipe/rep_exp/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..282aa7e1cea430d2f8d191350b8664d3aabaf68e --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/eval.sh @@ -0,0 +1,83 @@ +TASK=${1} # math, gsm8k, dapo-with-aime24 + +# Custom model path for evaluation after training +MODEL_PATH=${2} # /path/to/global_step_X/actor/hf, where X is the global step of the checkpoint with the best pass@1 on dev + +# If you want to evaluate the base model before training +# MODEL_PATH=Qwen/Qwen2.5-7B-Instruct + +train_path=$HOME/data/${TASK}/train.parquet +train_files="['$train_path']" +CHECKPOINT_SAVE_CONTENTS='["model"]' + +if [ ${TASK} == "dapo-with-aime24" ]; then + MAX_PROMPT_LENGTH=$((1024 * 2)) + MAX_RESPONSE_LENGTH=$((1024 * 8)) + MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) + test_path=$HOME/data/${TASK}/dev.parquet +else + MAX_PROMPT_LENGTH=1024 + MAX_RESPONSE_LENGTH=1024 + MAX_NUM_BATCHED_TOKENS=8192 + test_path=$HOME/data/${TASK}/test.parquet +fi + +test_files="['$test_path']" + +# If you're on a cluster with no internet access, set to OFFLINE=True +OFFLINE=False + +PYTHONUNBUFFERED=1 WANDB_MODE=disabled TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m rep_exp.main_rep_exp \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=$MAX_PROMPT_LENGTH \ + data.max_response_length=$MAX_RESPONSE_LENGTH \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.val_batch_size=128 \ + actor_rollout_ref.model.path="$MODEL_PATH" \ + actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_SAVE_CONTENTS \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ppo_epochs=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.45 \ + actor_rollout_ref.rollout.val_kwargs.n=256 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + reward_model.model.path="$MODEL_PATH" \ + reward_model.model.use_remove_padding=False \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=32 \ + reward_model.model.input_tokenizer=null \ + actor_rollout_ref.actor.use_kl_loss=False \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","json_eval"]' \ + trainer.project_name='rep-exp' \ + trainer.experiment_name="${TASK}_eval" \ + trainer.n_gpus_per_node=1 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=1 \ + trainer.total_epochs=100 \ + trainer.val_only=True \ + trainer.resume_mode=disable \ + trainer.resume_from_path='' + +exit 0 diff --git a/ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py b/ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7068f12c116747a7c5be2cf4154988a7b27e65 --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py @@ -0,0 +1,483 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os +import socket +import warnings + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import is_cuda_available +from verl.utils.import_utils import load_extern_type + +from .rep_exp_trainer import RayRepExpTrainer + + +@hydra.main(config_path="config", config_name="rep_exp_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config, task_runner_class=None) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + if config.transfer_queue.enable: + # Add runtime environment variables for transfer queue + runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) + runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" + runtime_env_kwargs["env_vars"] = runtime_env_vars + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + + # use new model engine implementation + if use_legacy_worker_impl == "disable": + from verl.workers.engine_workers import ActorRolloutRefWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker, + # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker. + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + if config.actor_rollout_ref.rollout.mode == "sync": + warnings.warn("spmd rollout mode is deprecated and will be removed in v0.6.2", stacklevel=2) + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + self.mapping[Role.ActorRollout] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + def add_critic_worker(self, config): + """Add critic worker to role mapping.""" + if config.critic.strategy in {"fsdp", "fsdp2"}: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + from verl.workers.megatron_workers import CriticWorker + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) + self.mapping[Role.Critic] = "global_pool" + + def init_resource_pool_mgr(self, config): + """Initialize resource pool manager.""" + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_worker(self, config): + """Add reward model worker if enabled.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.reward_model.enable: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + if config.reward_model.elliptical: + from .workers.elliptical_reward_model_worker import ( + EllipticalRewardModelWorker as RewardModelWorker, + ) + else: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import RewardModelWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + if config.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Add reference policy worker if KL loss or KL reward is used.""" + from verl.trainer.ppo.ray_trainer import Role + + # Ref policy has been fused into ActorRolloutRefWorker in new model engine, + # we don't need to add a separate ref policy worker goup. + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl == "disable": + return + + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) + self.mapping[Role.RefPolicy] = "global_pool" + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Make sure the elliptical reward manager is registered + from .reward_manager.elliptical_reward_manager import EllipticalRewardManager # noqa: F401 + + # Load the reward manager for training and validation. + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_fn = load_reward_manager( + config, + tokenizer, + num_examine=0, + **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}), + ) + val_reward_fn = load_reward_manager( + config, + tokenizer, + num_examine=1, + **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}), + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayRepExpTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + # Dynamically load the custom dataset class + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from " + f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) + elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + # If a data generation strategy is specified, use the DynamicGenDataset class + from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset + + dataset_cls = DynamicGenDataset + print("Using DynamicGenDataset for data generation.") + else: + # Use the default RLHFDataset class if no custom class is specified + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + max_samples=max_samples, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import SequentialSampler + + # torch.utils.data.RandomSampler could not recover properly + from torchdata.stateful_dataloader.sampler import RandomSampler + + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + seed = data_config.get("seed") + if seed is not None: + train_dataloader_generator.manual_seed(seed) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/rep_exp/metric_utils.py b/ICL/DAPO/verl-recipe/rep_exp/metric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..519a5b6c0f8a36396f18ba01f59e847ac4003ce0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/metric_utils.py @@ -0,0 +1,382 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Metrics related to the RepExp trainer. +""" + +from collections import defaultdict +from functools import partial +from typing import Any + +import numpy as np +import torch + +from verl import DataProto +from verl.trainer.ppo.metric_utils import _compute_response_info, bootstrap_metric, calc_maj_val + + +def _compute_three_case_stats(data: DataProto, extrinsic_reward_tensor: torch.Tensor) -> dict: + """ + Compute the fraction of samples that have no rollouts correct, some rollouts correct, and all rollouts correct. + + Args: + data (DataProto): The data proto containing the batch data. + extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor. + + Returns: + dict[str, float]: A dictionary containing the fraction of samples that have no rollouts correct, + some rollouts correct, and all rollouts correct. + """ + no_rollouts_correct = 0 + some_rollouts_correct = 0 + all_rollouts_correct = 0 + + visited_uids = set() + for uid in data.non_tensor_batch["uid"]: + if uid in visited_uids: + continue + + visited_uids.add(uid) + mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid) + + # Split into three cases + if extrinsic_reward_tensor[mask].sum() == 0: + no_rollouts_correct += 1 + elif extrinsic_reward_tensor[mask].sum() == mask.sum(): + all_rollouts_correct += 1 + elif extrinsic_reward_tensor[mask].sum() > 0 and extrinsic_reward_tensor[mask].sum() < mask.sum(): + some_rollouts_correct += 1 + else: + raise ValueError(f"Invalid extrinsic reward tensor: {extrinsic_reward_tensor[mask].sum()}") + + # Sanity checks + assert len(visited_uids) == no_rollouts_correct + some_rollouts_correct + all_rollouts_correct + + return { + "no_rollouts_correct_frac": no_rollouts_correct / len(visited_uids), + "some_rollouts_correct_frac": some_rollouts_correct / len(visited_uids), + "all_rollouts_correct_frac": all_rollouts_correct / len(visited_uids), + } + + +def compute_data_metrics(batch: DataProto, use_critic: bool = True, elliptical: bool = False) -> dict[str, Any]: + """ + Computes various metrics from a batch of data for PPO training. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. + use_critic: Whether to include critic-specific metrics. Defaults to True. + elliptical: Whether to include elliptical-specific metrics. Defaults to False. + + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) + - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + - num_turns/mean, max, min: Statistics about the number of multi-turn conversations + """ + sequence_score = batch.batch["token_level_scores"].sum(-1) + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + + if elliptical: + sequence_intrinsic_reward = batch.non_tensor_batch["intrinsic_reward"].sum(-1) + sequence_beta_scaled_intrinsic_reward = batch.non_tensor_batch["beta_scaled_intrinsic_reward"].sum(-1) + sequence_extrinsic_reward = batch.non_tensor_batch["extrinsic_reward"].sum(-1) + sequence_total_reward = batch.non_tensor_batch["total_reward"].sum(-1) + sequence_raw_bonuses = batch.non_tensor_batch["raw_bonuses"].sum(-1) + + three_case_stats = _compute_three_case_stats(batch, batch.non_tensor_batch["extrinsic_reward"]) + + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["response_mask"].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + aborted_mask = (response_length == 0).bool() + non_aborted_mask = ~aborted_mask + + non_aborted_sequence_score = sequence_score[non_aborted_mask] + non_aborted_sequence_reward = sequence_reward[non_aborted_mask] + + score_mean = torch.mean(non_aborted_sequence_score).detach().item() + score_max = torch.max(non_aborted_sequence_score).detach().item() + score_min = torch.min(non_aborted_sequence_score).detach().item() + + reward_mean = torch.mean(non_aborted_sequence_reward).detach().item() + reward_max = torch.max(non_aborted_sequence_reward).detach().item() + reward_min = torch.min(non_aborted_sequence_reward).detach().item() + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + # Aborted samples and non-aborted response length statistics + # response_length_non_aborted/*: statistics computed on non-aborted samples only + aborted_ratio = torch.mean(aborted_mask.float()).detach().item() + + non_aborted_response_length = response_length[non_aborted_mask] + if non_aborted_response_length.numel() > 0: + non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item() + non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item() + non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item() + non_aborted_response_length_clip_ratio = ( + torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item() + ) + else: + raise ValueError("All samples are aborted, this should not happen.") + + metrics = { + # score + "critic/score/mean": score_mean, + "critic/score/max": score_max, + "critic/score/min": score_min, + # reward + "critic/rewards/mean": reward_mean, + "critic/rewards/max": reward_max, + "critic/rewards/min": reward_min, + # adv + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + **( + { + # raw bonuses + "critic/raw_bonuses/mean": np.mean(sequence_raw_bonuses).item(), + "critic/raw_bonuses/max": np.max(sequence_raw_bonuses).item(), + "critic/raw_bonuses/min": np.min(sequence_raw_bonuses).item(), + "critic/raw_bonuses/std": np.std(sequence_raw_bonuses).item(), + # intrinsic_reward + "critic/intrinsic_reward/mean": np.mean(sequence_intrinsic_reward).item(), + "critic/intrinsic_reward/max": np.max(sequence_intrinsic_reward).item(), + "critic/intrinsic_reward/min": np.min(sequence_intrinsic_reward).item(), + "critic/intrinsic_reward/std": np.std(sequence_intrinsic_reward).item(), + # beta_scaled_intrinsic_reward + "critic/beta_scaled_intrinsic_reward/mean": np.mean(sequence_beta_scaled_intrinsic_reward).item(), + "critic/beta_scaled_intrinsic_reward/max": np.max(sequence_beta_scaled_intrinsic_reward).item(), + "critic/beta_scaled_intrinsic_reward/min": np.min(sequence_beta_scaled_intrinsic_reward).item(), + "critic/beta_scaled_intrinsic_reward/std": np.std(sequence_beta_scaled_intrinsic_reward).item(), + # extrinsic_reward + "critic/extrinsic_reward/mean": np.mean(sequence_extrinsic_reward).item(), + "critic/extrinsic_reward/max": np.max(sequence_extrinsic_reward).item(), + "critic/extrinsic_reward/min": np.min(sequence_extrinsic_reward).item(), + "critic/extrinsic_reward/std": np.std(sequence_extrinsic_reward).item(), + # three_case_stats + "critic/extrinsic_reward/no_rollouts_correct_frac": three_case_stats["no_rollouts_correct_frac"], + "critic/extrinsic_reward/some_rollouts_correct_frac": three_case_stats["some_rollouts_correct_frac"], + "critic/extrinsic_reward/all_rollouts_correct_frac": three_case_stats["all_rollouts_correct_frac"], + # total_reward + "critic/total_reward/mean": np.mean(sequence_total_reward).item(), + "critic/total_reward/max": np.max(sequence_total_reward).item(), + "critic/total_reward/min": np.min(sequence_total_reward).item(), + "critic/total_reward/std": np.std(sequence_total_reward).item(), + } + if elliptical + else {} + ), + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), + # response length (non-aborted only) + # These statistics exclude aborted samples to avoid skew from zeros + "response_length_non_aborted/mean": non_aborted_response_length_mean, + "response_length_non_aborted/max": non_aborted_response_length_max, + "response_length_non_aborted/min": non_aborted_response_length_min, + "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio, + # aborted ratio + # Fraction of samples whose response length is zero + "response/aborted_ratio": aborted_ratio, + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + + # multi-turn conversation + if "__num_turns__" in batch.non_tensor_batch: + num_turns = batch.non_tensor_batch["__num_turns__"] + metrics["num_turns/min"] = num_turns.min() + metrics["num_turns/max"] = num_turns.max() + metrics["num_turns/mean"] = num_turns.mean() + + if "tool_call_counts" in batch.non_tensor_batch: + tool_call_counts = batch.non_tensor_batch["tool_call_counts"] + metrics["tool_call_counts/min"] = tool_call_counts.min() + metrics["tool_call_counts/max"] = tool_call_counts.max() + metrics["tool_call_counts/mean"] = tool_call_counts.mean() + + return metrics + + +def comb_estimator(n: int, c: int, k: int) -> float: + """Calculates 1 - comb(n - c, k) / comb(n, k).""" + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + +def process_validation_metrics( + data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 +) -> dict[str, dict[str, dict[str, float]]]: + """ + Process validation metrics into a structured format with statistical analysis. + + This function organizes validation metrics by data source and prompt, then computes + various statistical measures including means, standard deviations, best/worst values, + and majority voting results. It also performs bootstrap sampling to estimate statistics + for different sample sizes. + + Args: + data_sources: List of data source identifiers for each sample. + sample_uids: List of sample uids corresponding to each sample. + infos_dict: Dictionary mapping variable names to lists of values for each sample. + seed: Random seed for bootstrap sampling. Defaults to 42. + + Returns: + A nested dictionary with the structure: + { + data_source: { + variable_name: { + metric_name: value + } + } + } + + Where metric_name includes: + - "mean@N": Mean value across N samples + - "std@N": Standard deviation across N samples + - "best@N/mean": Mean of the best values in bootstrap samples of size N + - "best@N/std": Standard deviation of the best values in bootstrap samples + - "worst@N/mean": Mean of the worst values in bootstrap samples + - "worst@N/std": Standard deviation of the worst values in bootstrap samples + - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists) + - "maj@N/std": Standard deviation of majority voting results (if "pred" exists) + + Example: + >>> data_sources = ["source1", "source1", "source2"] + >>> sample_uids = ["uid1", "uid1", "uid2"] + >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]} + >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict) + >>> # result will contain statistics for each data source and variable + """ + # Group metrics by data source, prompt and variable + data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for sample_idx, data_source in enumerate(data_sources): + uid = sample_uids[sample_idx] + var2vals = data_src2uid2var2vals[data_source][uid] + for var_name, var_vals in infos_dict.items(): + var2vals[var_name].append(var_vals[sample_idx]) + + # Calculate metrics for each group + data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, uid2var2vals in data_src2uid2var2vals.items(): + for uid, var2vals in uid2var2vals.items(): + for var_name, var_vals in var2vals.items(): + if isinstance(var_vals[0], str): + continue + + metric = {} + n_resps = len(var_vals) + metric[f"mean@{n_resps}"] = np.mean(var_vals) + metric["pass@1/mean"] = comb_estimator(n_resps, np.sum(var_vals), 1) + + if n_resps > 1: + metric[f"std@{n_resps}"] = np.std(var_vals) + + ns = [] + n = 2 + while n < n_resps: + ns.append(n) + n *= 2 + ns.append(n_resps) + + for n in ns: + # [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( + # data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed + # ) + # metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std + # metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std + metric[f"pass@{n}/mean"] = comb_estimator(n_resps, np.sum(var_vals), n) + if var2vals.get("pred", None) is not None: + vote_data = [ + {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True) + ] + [(maj_n_mean, maj_n_std)] = bootstrap_metric( + data=vote_data, + subset_size=n, + reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")], + seed=seed, + ) + metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std + + data_src2uid2var2metric[data_source][uid][var_name] = metric + + # Aggregate metrics across uids + data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, uid2var2metric in data_src2uid2var2metric.items(): + for uid, var2metric in uid2var2metric.items(): + for var_name, metric in var2metric.items(): + for metric_name, metric_val in metric.items(): + data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val) + + data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) + for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items(): + for var_name, metric2uid_vals in var2metric2uid_vals.items(): + for metric_name, uid_vals in metric2uid_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals) + + return data_src2var2metric2val diff --git a/ICL/DAPO/verl-recipe/rep_exp/model_merge.sh b/ICL/DAPO/verl-recipe/rep_exp/model_merge.sh new file mode 100644 index 0000000000000000000000000000000000000000..b7c673ebe8ce95b306f375a8dc04307d6223b7a1 --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/model_merge.sh @@ -0,0 +1,6 @@ +CHECKPOINT_PATH=${1} # /path/to/global_step_X/actor, where X is the global step of the checkpoint with the best pass@1 on dev + +python3 -m verl.model_merger merge \ + --backend fsdp \ + --local_dir $CHECKPOINT_PATH \ + --target_dir $CHECKPOINT_PATH/hf \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py b/ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py new file mode 100644 index 0000000000000000000000000000000000000000..eea2011bab05c39d3637f52eb79dbf39e3057fef --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py @@ -0,0 +1,241 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Code to plot the pass@k results for the RepExp RL training results. +""" + +import json +import os +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +import scipy.stats as stats +import seaborn as sns +from matplotlib.lines import Line2D + +# Content configuration +EVAL_FOLDER = "./eval" +TASKS = ["math"] # ["math", "gsm8k", "dapo-with-aime24"] +SEEDS = [41, 42, 43] +ALGORITHMS = ["elliptical"] # ["grpo", "elliptical", "untrained", "unlikely"] +LOG_AXES = True + +# Plot configuration +FACE_COLOR = "#F7F7FF" +MARKER = "o" +LINEWIDTH = 1.275 +MARKERSIZE = 6 +MARKEREDGEWIDTH = 0.9 +LABEL_FONT_SIZE = 10 +TITLE_FONT_SIZE = 11 +TICK_LABEL_FONT_SIZE = 8 +LEGEND_FONT_SIZE = 8 + +TASK_TO_NICE_NAME = { + "math": "MATH", + "gsm8k": "GSM8K", + "dapo-with-aime24": "AIME 2024", + "countdown-4": "Countdown", +} + +ALGO_TO_COLOR = { + "grpo": sns.color_palette("deep")[-1], + "untrained": sns.color_palette("deep")[7], + "elliptical": sns.color_palette("colorblind")[2], + "unlikely": sns.color_palette("deep")[1], +} + +ALGO_TO_NICE_NAME = { + "grpo": "GRPO", + "untrained": "Base Model", + "elliptical": r"RepExp (ours)", + "unlikely": "Unlikeliness", +} + + +def process_data(data: list[dict[str, float]], algorithm: str) -> tuple[dict[int, float], dict[int, float]]: + """ + Process the pass@k data generated by a given algorithm. + + Args: + data (List[Dict]): The data to process. + algorithm (str): Algorithm that generated the data. + + Returns: + Tuple[Dict[int, float], Dict[int, float]]: + pass_at_k - The mean pass@k values. + pass_at_k_sem - The standard error of the pass@k values. + """ + pass_at_k = defaultdict(list) + for d in data: + for key, v in d.items(): + for k in [1, 2, 4, 8, 16, 32, 64, 128, 256]: + if key.endswith(f"reward/pass@{k}/mean"): + pass_at_k[k].append(v) + + # NOTE: we only use a single seed for untrained since there is only one checkpoint for it + if algorithm != "untrained": + for k in pass_at_k.keys(): + assert len(pass_at_k[k]) == len(SEEDS) + + pass_at_k_sem = {k: stats.sem(v) for k, v in pass_at_k.items()} if algorithm != "untrained" else None + pass_at_k = {k: np.mean(v) for k, v in pass_at_k.items()} + + return pass_at_k, pass_at_k_sem + + +def main(): + # Get all top-level folders in EVAL_FOLDER + eval_folders = os.listdir(EVAL_FOLDER) + + # Figure setup + sns.set_style("whitegrid") + fig, axs = plt.subplots(1, len(TASKS), figsize=(3 * len(TASKS), 3)) + + for i, task in enumerate(TASKS): + ax = axs[i] if len(TASKS) > 1 else axs + algo_to_xs = {} + algo_to_ys = {} + + for algorithm in ALGORITHMS: + # Get all eval folders for the current task and algorithm + folders = [f for f in eval_folders if f.startswith(f"{task}_{algorithm}")] + if len(folders) == 0: + continue + + data = [] + for folder in folders: + if algorithm == "untrained": + with open(os.path.join(EVAL_FOLDER, folder, "eval.json")) as f: + data.append(json.load(f)) + else: + # walk all files recursively in folder + for root, dirs, files in os.walk(os.path.join(EVAL_FOLDER, folder)): + for file in files: + if file.endswith("eval.json"): + with open(os.path.join(root, file)) as f: + data.append(json.load(f)) + break + + pass_at_k, pass_at_k_sem = process_data(data, algorithm) + + xs = np.array(list(pass_at_k.keys())) + ys = np.array([pass_at_k[k] for k in xs]) + algo_to_xs[algorithm] = xs + algo_to_ys[algorithm] = ys + + # Plot the current task - algorithm data + ax.plot( + xs, + ys, + color=ALGO_TO_COLOR[algorithm], + label=algorithm, + markeredgecolor=FACE_COLOR, + marker=MARKER, + linewidth=LINEWIDTH, + markersize=MARKERSIZE, + markeredgewidth=MARKEREDGEWIDTH, + alpha=1.0 if algorithm != "untrained" else 0.8, + ) + + # Plot the standard error in shaded bands + if algorithm != "untrained": + sems = np.array([pass_at_k_sem[k] for k in xs]) + ax.fill_between(xs, ys - sems, ys + sems, alpha=0.2, color=ALGO_TO_COLOR[algorithm]) + + # Set y-axis limits + if task == "math": + y_min = 0.7 + ax.set_ylim(top=0.95, bottom=y_min) + elif task == "gsm8k": + y_min = 0.925 + ax.set_ylim(top=0.995, bottom=y_min) + elif task == "dapo-with-aime24": + y_min = 0.1 + ax.set_ylim(bottom=y_min, top=0.63) + + # Set x-axis limits + if LOG_AXES: + ax.set_xlim(left=2 ** (-0.2), right=2 ** (8.2)) + else: + ax.set_xlim(left=-10, right=266) + + # Set x-axis scale and ticks + if LOG_AXES: + ax.set_xscale("log", base=2) + x_ticks = [2**i for i in range(int(np.log2(max(xs))) + 1)] + x_tick_labels = [f"$2^{{{i}}}$" for i in range(int(np.log2(max(xs))) + 1)] + else: + # set every 64 + x_ticks = [1, 32, 64, 96, 128, 160, 192, 224, 256] + x_tick_labels = ["1", "32", "64", "96", "128", "160", "192", "224", "256"] + ax.set_xticks(x_ticks, x_tick_labels) + + # Set axes labels + ax.set_xlabel("k", fontsize=LABEL_FONT_SIZE) + if i == 0: + ax.set_ylabel("Pass@k", fontsize=LABEL_FONT_SIZE) + + # Set title + ax.set_title(f"{TASK_TO_NICE_NAME[task]}", fontsize=TITLE_FONT_SIZE) + + # Set font size for tick labels + for _label in ax.get_xticklabels(): + _label.set_fontsize(TICK_LABEL_FONT_SIZE) + for _label in ax.get_yticklabels(): + _label.set_fontsize(TICK_LABEL_FONT_SIZE) + + # Create legend handles + legend_handles = [ + Line2D( + [0], + [0], + color=ALGO_TO_COLOR[algo], + marker=MARKER, + linestyle="-", + linewidth=LINEWIDTH, + markersize=MARKERSIZE, + markeredgewidth=MARKEREDGEWIDTH, + markeredgecolor=FACE_COLOR, + label=ALGO_TO_NICE_NAME[algo], + ) + for algo in ALGORITHMS + ] + + # Create legend + legend = fig.legend( + handles=legend_handles, + loc="lower center", + ncol=len(ALGORITHMS), + bbox_to_anchor=(0.5, -0.07), + fontsize=LEGEND_FONT_SIZE, + ) + + plt.tight_layout() + + os.makedirs("figures", exist_ok=True) + # Save figure + plt.savefig( + os.path.join("figures", f"rl_pass_at_k_{TASKS}_{'' if LOG_AXES else '_linear_axes'}.pdf"), + bbox_extra_artists=(legend,), + bbox_inches="tight", + ) + + # Close figure + plt.close() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py b/ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c23b848d782ac6bb4b0511372c12aeaa73e1a1 --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py @@ -0,0 +1,739 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import json +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_throughout_metrics, + compute_timing_metrics, +) +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.metric import reduce_metrics +from verl.utils.rollout_skip import RolloutSkip + +from .metric_utils import compute_data_metrics, process_validation_metrics + + +class RayRepExpTrainer(RayPPOTrainer): + """Distributed RepExp trainer using Ray for scalable reinforcement learning. + + See RayPPOTrainer parent class for more details. + """ + + def _save_checkpoint(self): + super()._save_checkpoint() + + # Write best metric to global steps + local_best_metric_to_global_step = os.path.join( + self.config.trainer.default_local_dir, "best_metric_to_global_step.json" + ) + with open(local_best_metric_to_global_step, "w") as f: + json.dump(self.best_dev_pass_at_k_to_global_step, f) + + def _update_best_pass_at(self, val_metrics: dict[str, float], pass_at_k: int) -> bool: + """ + Save checkpoint if the validation metrics are the best. + + Args: + val_metrics: The validation metrics. + pass_at_k: The pass@k to use for determining whether to save the checkpoint. + """ + for k in val_metrics.keys(): + if k.endswith(f"reward/pass@{pass_at_k}/mean"): + if val_metrics[k] > self.best_dev_pass_at_k[pass_at_k]: + self.best_dev_pass_at_k[pass_at_k] = val_metrics[k] + self.best_dev_pass_at_k_to_global_step[pass_at_k] = self.global_steps + return True + + return False + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_gts = [] + sample_scores = [] + sample_turns = [] + sample_uids = [] + + for test_data in tqdm(self.val_dataloader, desc="Validating ..."): + test_batch = DataProto.from_single_dict(test_data) + + if "uid" not in test_batch.non_tensor_batch: + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + ) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) + + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + + test_gen_batch = self._get_gen_batch(test_batch) + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + "global_steps": self.global_steps, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + size_divisor = ( + self.actor_rollout_wg.world_size + if not self.async_rollout_mode + else self.config.actor_rollout_ref.rollout.agent.num_workers + ) + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + test_batch.meta_info["validate"] = True + + # evaluate using reward_function + if self.val_reward_fn is None: + raise ValueError("val_reward_fn must be provided for validation.") + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + + # collect num_turns of each prompt + if "__num_turns__" in test_batch.non_tensor_batch: + sample_turns.append(test_batch.non_tensor_batch["__num_turns__"]) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + if len(sample_turns) > 0: + sample_turns = np.concatenate(sample_turns) + metric_dict["val-aux/num_turns/min"] = sample_turns.min() + metric_dict["val-aux/num_turns/max"] = sample_turns.max() + metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() + + return metric_dict + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + val_only = self.config.trainer.get("val_only", False) + + # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[actor_role], + config=self.config.actor_rollout_ref, + role=str(actor_role), + ) + self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic and not val_only: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + # create reference policy if needed + if self.use_reference_policy and not val_only: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm and not val_only: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg[str(Role.Critic)] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + + self.rm_wg = None + # initalization of rm_wg will be deprecated in the future + if self.use_rm: + self.rm_wg = all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg[str(actor_role)] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + from verl.experimental.agent_loop import AgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = AgentLoopManager( + config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg + ) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from .utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + # global vars to track during training + self.global_steps = 0 + + self.best_dev_pass_at_k = { + 1: 0, + } + self.best_dev_pass_at_k_to_global_step = { + 1: 0, + } + + # load checkpoint before doing anything + self._load_checkpoint() + + current_epoch = self.global_steps // len(self.train_dataloader) + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + + # Initialize the best validation metrics for pass@k before training + self._update_best_pass_at(val_metrics, 1) + val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1] + + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(batch) + batch = batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + if self.config.reward_model.elliptical.enable: + hidden_states = self.rm_wg.compute_hidden_states(batch) + batch = batch.union(hidden_states) + reward_tensor = self.rm_wg.compute_rm_score(batch) + else: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote( + data=batch, config=self.config, tokenizer=self.tokenizer + ) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction + + apply_rollout_correction( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, + ) + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss( + loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode + ) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + rollout_config = self.config.actor_rollout_ref.rollout + batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable + # TODO: Make "temperature" single source of truth from generation. + batch.meta_info["temperature"] = rollout_config.temperature + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + + # Initialize the best validation metrics for pass@k before training + self._update_best_pass_at(val_metrics, 1) + val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1] + + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update( + compute_data_metrics( + batch=batch, + use_critic=self.use_critic, + elliptical=self.config.reward_model.elliptical.enable, + ) + ) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/ICL/DAPO/verl-recipe/spin/core_algos.py b/ICL/DAPO/verl-recipe/spin/core_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..c48027e54106ab496c09ddb80107fb7df210f2b6 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/core_algos.py @@ -0,0 +1,206 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target_kl, horizon): + self.value = init_kl_coef + self.target = target_kl + self.horizon = horizon + + def update(self, current_kl, n_steps): + target = self.target + proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current_kl, n_steps): + pass + + +def get_kl_controller(kl_ctrl): + if kl_ctrl.type == "fixed": + return FixedKLController(kl_coef=kl_ctrl.kl_coef) + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" + return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + else: + raise NotImplementedError + + +def compute_onlinedpo_pref( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, +) -> torch.Tensor: + """ + Computes preferences between pairs of sequences based on summed rewards + and returns a mask aligned with the interleaved batch. + + Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...] + + Args: + token_level_rewards: Tensor of shape [batch_size * 2, seq_len] + response_mask: Tensor of shape [batch_size * 2, seq_len] + + Returns: + torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates + the corresponding entry is the chosen response for its pair. + Example: [True, False, False, True, ...] means for prompt 0, + response 1 was chosen; for prompt 1, response 2 was chosen. + """ + # print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----") + if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0: + raise ValueError( + f"Input tensor batch dimension must be even for pair comparison, got shapes: " + f"{token_level_rewards.shape}, {response_mask.shape}" + ) + if token_level_rewards.shape != response_mask.shape: + raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}") + + # 1. Calculate Sequence Scores + scores = (token_level_rewards * response_mask).sum(dim=-1) + # print(f" Calculated sequence scores shape: {scores.shape}") # [batch_size * 2] + + # 2. Reshape scores to group pairs: [batch_size, 2] + try: + score_pairs = scores.view(-1, 2) + except RuntimeError as e: + print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}") + raise e + print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2] + + # 3. Compare scores to find which index (0 or 1) is the winner within each pair + # winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1 + winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max + # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max) + # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1] + # print(f" Winner indices shape: {winner_indices.shape}") # [batch_size] + # print(f" Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}") # Counts number of 1s + + # 4. Create the final [batch_size * 2] mask + num_pairs = score_pairs.shape[0] + full_batch_size = num_pairs * 2 + # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1] + # full_indices = torch.arange(full_batch_size, device=scores.device) + # Create indices corresponding to the winner within each pair's original index + # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2] + # winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4] + pair_indices = torch.arange(num_pairs, device=scores.device) + winner_global_indices = (pair_indices * 2) + winner_indices + + # Create boolean mask - True at the winner's position + output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device) + output_preference_mask[winner_global_indices] = True + + # print(f" Output preference mask shape: {output_preference_mask.shape}") # Should be [batch_size * 2] + # print(f" Output mask True count (Chosen): {output_preference_mask.sum().item()}") # Should be batch_size + # print(f" Output mask False count (Rejected): {(~output_preference_mask).sum().item()}") # Should be batch_size + # print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----") + + return output_preference_mask + + +def compute_online_dpo_loss( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta: float, + label_smoothing: float = 0.0, + loss_type: str = "sigmoid", + reference_free: bool = False, +) -> torch.Tensor: + import torch.nn.functional as F + + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + if reference_free: + ref_logratios = torch.zeros_like(pi_logratios) + + logits = pi_logratios - ref_logratios + + if loss_type == "sigmoid": + losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing + elif loss_type == "ipo": + losses = (logits - 1 / (2 * beta)) ** 2 + else: + raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.") + + return losses.mean() + + +def get_batch_logps( + logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False +) -> torch.FloatTensor: + """ + Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`). + Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits and labels must have the same shape[:-1]") + + # Ensure labels are contiguous and on the same device as logits + labels = labels.contiguous().to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Calculate per token log probability + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none") + per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + per_token_logps = per_token_logps.view( + shift_logits.size(0), shift_logits.size(1) + ) # Reshape back to (batch_size, seq_len-1) + + # Create a mask for the labels that are not -100 + loss_mask = shift_labels != -100 + + # Apply the mask to the per token log probabilities + masked_logps = per_token_logps * loss_mask + + # Calculate the sum or average log probability per sequence + sequence_logps = masked_logps.sum(dim=-1) + + if average_log_prob: + # Avoid division by zero for sequences with no valid tokens + num_valid_tokens = loss_mask.sum(dim=-1) + return sequence_logps / torch.clamp(num_valid_tokens, min=1) + else: + return sequence_logps diff --git a/ICL/DAPO/verl-recipe/spin/main_spin.py b/ICL/DAPO/verl-recipe/spin/main_spin.py new file mode 100644 index 0000000000000000000000000000000000000000..465742360d3be4c796c94f657d70bc6c2fe4953f --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/main_spin.py @@ -0,0 +1,168 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import hydra +import ray +from recipe.spin.spin_trainer import RaySPINTrainer +from recipe.spin.utils import validate_config + +from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.trainer.ppo.utils import need_reference_policy + + +@hydra.main(config_path="config", config_name="spin_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices + # isolation, will solve in the future + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if not ray.is_initialized(): + # this is for local ray cluster + ray.init( + runtime_env={ + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + } + ) + + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + # from recipe.spin.fsdp_workers import ActorRolloutRefWorker + from recipe.spin.fsdp_workers import SPINRolloutRefWorker + + from verl.single_controller.ray import RayWorkerGroup + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from recipe.spin.spin_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + # Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.ActorRollout: ray.remote(SPINRolloutRefWorker), + # Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + # Role.Critic: global_pool_id, + } + + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from recipe.spin.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # use reference model + # if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + # role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(role_worker_mapping), + use_critic=False, + ) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + from verl.workers.reward_manager import get_reward_manager_cls + + # Note(haibin.lin): please make sure custom reward managers are imported and + # registered via `verl.workers.reward_manager.register` + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) + + compute_score = get_custom_reward_fn(config) + reward_kwargs = dict(config.reward_model.get("reward_kwargs", {})) + reward_fn = reward_manager_cls( + tokenizer=tokenizer, + num_examine=0, + compute_score=compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + # Note that we always use function-based RM for validation + val_reward_fn = reward_manager_cls( + tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RaySPINTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) + trainer.init_workers() + trainer.fit_dpo() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/spin/spin_trainer.py b/ICL/DAPO/verl-recipe/spin/spin_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..41ec06cc49bc4479cd7dd56cec93f17893035140 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/spin_trainer.py @@ -0,0 +1,1312 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import traceback +import uuid +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field +from pprint import pprint +from typing import Any, Optional + +import numpy as np +import ray +import torch +from codetiming import Timer +from omegaconf import OmegaConf, open_dict +from recipe.spin import core_algos +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.metric_utils import compute_throughout_metrics, compute_timing_metrics, process_validation_metrics +from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.metric import reduce_metrics +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.torch_functional import masked_mean +from verl.utils.tracking import ValidationGenerationsLogger + + +@dataclass +class ResourcePoolManager: + """ + Define a resource pool specification. Resource pool will be initialized first. + Mapping + """ + + resource_pool_spec: dict[str, list[int]] + mapping: dict[Role, str] + resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) + + def create_resource_pool(self): + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool + # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. + # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different + # WorkerGroup for different models + resource_pool = RayResourcePool( + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + ) + self.resource_pool_dict[resource_pool_name] = resource_pool + + self._check_resource_available() + + def get_resource_pool(self, role: Role) -> RayResourcePool: + """Get the resource pool of the worker_cls""" + return self.resource_pool_dict[self.mapping[role]] + + def get_n_gpus(self) -> int: + """Get the number of gpus in this cluster.""" + return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) + + def _check_resource_available(self): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray._private.state.available_resources_per_node() + node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum( + [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] + ) + if total_available_gpus < total_required_gpus: + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) + + # check each resource pool can be satisfied, O(#resource_pools * #nodes) + for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): + num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) + for node, available_gpus in node_available_gpus.items(): + if available_gpus >= num_gpus: + node_available_gpus[node] -= num_gpus + num_nodes -= 1 + if num_nodes == 0: + break + if num_nodes > 0: + raise ValueError( + f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this " + f"ray cluster" + ) + + +def _compute_response_info(batch: DataProto) -> dict[str, Any]: + """Placeholder: Computes prompt and response lengths.""" + try: + # Assuming 'prompts' and 'responses' keys exist after generation/union + prompt_len = batch.batch["prompts"].shape[1] + resp_len = batch.batch["responses"].shape[1] + # This is simplified - real implementation might use attention masks + # to get actual lengths per sample. + batch_size = batch.batch.batch_size[0] + prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device) + response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device) + + # Try getting actual lengths from attention mask if possible (more accurate) + if "response_mask" in batch.batch: + response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float() + # if "attention_mask" in batch.batch and "response_mask" in batch.batch: + # full_mask = batch.batch["attention_mask"] + # resp_mask = batch.batch["response_mask"] + # Infer prompt mask length based on where response mask starts or total length + # This logic depends heavily on how your masks are constructed. + # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor + # Fallback to using prompt shape if mask logic is complex: + prompt_lengths_tensor = torch.tensor( + [batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device + ) + + return { + "prompt_length": prompt_lengths_tensor, + "response_length": response_lengths_tensor, + "max_response_length": resp_len, + "max_prompt_length": prompt_len, # Or from config if fixed padding + } + except KeyError as e: + print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.") + # Return default/dummy values if keys are missing + b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1 + max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0 + max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0 + return { + "prompt_length": torch.zeros(b_size), + "response_length": torch.zeros(b_size), + "max_response_length": max_resp, + "max_prompt_length": max_prompt, + } + + +# --- Modified Metric Function --- +def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: + """ + Computes and returns metrics relevant for the DPO-like process. + Assumes 'batch' contains results after generation and preference marking, + potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc. + Removes PPO-specific advantage/return/critic metrics. + """ + print("---- [DEBUG] Computing DPO Data Metrics ----") + metrics = {} + try: + # --- Scores and Rewards (from reward_fn) --- + if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None: + sequence_score = batch.batch["token_level_scores"].sum(-1) + metrics.update( + { + "reward/score/mean": torch.mean(sequence_score).item(), + "reward/score/max": torch.max(sequence_score).item(), + "reward/score/min": torch.min(sequence_score).item(), + } + ) + else: + print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.") + + if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None: + sequence_reward = batch.batch["token_level_rewards"].sum(-1) + metrics.update( + { + "reward/rewards/mean": torch.mean(sequence_reward).item(), + "reward/rewards/max": torch.max(sequence_reward).item(), + "reward/rewards/min": torch.min(sequence_reward).item(), + } + ) + else: + print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.") + + # --- DPO Specific Metrics (if stored previously) --- + if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None: + metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item() + else: + print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.") + + if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None: + metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item() + else: + print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.") + + if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None: + metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item() + else: + print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.") + + # Add metrics based on the 'preferences' mask if available + # if "preferences" in batch.batch and batch.batch["preferences"] is not None: + # prefs_mask = batch.batch["preferences"] # Shape [batch_size * n] + # Calculate accuracy based on RM scores (assuming higher score -> True in mask) + # Requires chosen/rejected scores to be available or recalculated + # This is complex here, better calculated in the main loop or update function + + # --- Length Metrics --- + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + max_response_length = response_info["max_response_length"] + max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config + + metrics.update( + { + "response_length/mean": torch.mean(response_length).item(), + "response_length/max": torch.max(response_length).item(), + "response_length/min": torch.min(response_length).item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(), + "prompt_length/mean": torch.mean(prompt_length).item(), + "prompt_length/max": torch.max(prompt_length).item(), + "prompt_length/min": torch.min(prompt_length).item(), + # Prompt clip ratio might need adjustment based on how max_prompt_length is defined + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(), + } + ) + + except KeyError as e: + print(f"ERROR in compute_dpo_data_metrics: Missing key {e}") + except Exception as e: + print(f"ERROR in compute_dpo_data_metrics: {e}") + traceback.print_exc() + + print(f"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----") + return metrics + + +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): + responses = data.batch["responses"] + response_length = responses.size(1) + token_level_scores = data.batch["token_level_scores"] + batch_size = data.batch.batch_size[0] + attention_mask = data.batch["attention_mask"] + response_mask = attention_mask[:, -response_length:] + + # compute kl between ref_policy and current policy + # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. + kld = core_algos.kl_penalty( + data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty + ) # (batch_size, response_length) + kld = kld * response_mask + beta = kl_ctrl.value + + token_level_rewards = token_level_scores - beta * kld + + current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence + current_kl = torch.mean(current_kl, dim=0).item() + + # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 + kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) + data.batch["token_level_rewards"] = token_level_rewards + + metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} + + return data, metrics + + +def compute_response_mask(data: DataProto): + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_onlineDPO_pref(data: DataProto): + """ + Wrapper to compute DPO preference and add it to the DataProto batch. + Includes debugging prints. + """ + # print(f"\n---- [DEBUG] Entering compute_onlineDPO_pref ----") + # print(f" Input batch keys: {list(data.batch.keys())}") + + # Check inputs + rewards_tensor = data.batch.get("token_level_rewards") + mask_tensor = data.batch.get("response_mask") + + if rewards_tensor is None or mask_tensor is None: + print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!") + # Handle error case - maybe return original data or raise? + # Returning original data for now to potentially allow skipping + return data + + try: + preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor) + # Store the result + data.batch["preferences"] = preferences + + except AttributeError: + print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!") + # Assign dummy value or raise error + data.batch["preferences"] = None # Indicate failure + except Exception as e_pref: + print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}") + import traceback + + traceback.print_exc() + data.batch["preferences"] = None # Indicate failure + + # print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----") + return data + + +@contextmanager +def _timer(name: str, timing_raw: dict[str, float]): + with Timer(name=name, logger=None) as timer: + yield + timing_raw[name] = timer.last + + +class RaySPINTrainer: + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + # assert get_torch_device().is_available(), 'cuda must be available on driver' + + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(role_worker_mapping) + self.use_rm = need_reward_model(role_worker_mapping) + self.use_critic = False + self.ray_worker_group_cls = ray_worker_group_cls + self.validation_generations_logger = ValidationGenerationsLogger() + self.async_rollout_mode = False + self.device_name = device_name if device_name else self.config.trainer.device + + # define in-reward KL control + # kl loss control currently not suppoorted + if config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + + if train_dataset is None: + train_dataset = create_rl_dataset( + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("train_max_samples", -1), + ) + if val_dataset is None: + val_dataset = create_rl_dataset( + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("val_max_samples", -1), + ) + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=self.config.data.get("dataloader_num_workers", 8), + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=self.config.data.get("dataloader_num_workers", 8), + shuffle=False, + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, " + f"Size of val dataloader: {len(self.val_dataloader)}" + ) + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def _maybe_log_val_generations(self, inputs, outputs, scores): + """Log a table of validation samples to the configured logger (wandb or swanlab)""" + + generations_to_log = self.config.trainer.log_val_generations + + if generations_to_log == 0: + return + + import numpy as np + + # Create tuples of (input, output, score) and sort by input text + samples = list(zip(inputs, outputs, scores, strict=True)) + samples.sort(key=lambda x: x[0]) # Sort by input text + + # Use fixed random seed for deterministic shuffling + rng = np.random.RandomState(42) + rng.shuffle(samples) + + # Take first N samples after shuffling + samples = samples[:generations_to_log] + + # Log to each configured logger + self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + + def _validate(self): + data_source_lst = [] + reward_extra_infos_dict: dict[str, list] = defaultdict(list) + + # Lists to collect samples for the table + sample_inputs = [] + sample_outputs = [] + sample_scores = [] + + for test_data in self.val_dataloader: + test_batch = DataProto.from_single_dict(test_data) + + # repeat test batch + test_batch = test_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True + ) + + # we only do validation on rule-based rm + if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": + return {} + + # Store original inputs + input_ids = test_batch.batch["input_ids"] + # TODO: Can we keep special tokens except for padding tokens? + input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] + sample_inputs.extend(input_texts) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_inputs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + test_gen_batch.meta_info = { + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + "recompute_log_prob": False, + "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, + "validate": True, + } + print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") + + # pad to be divisible by dp_size + test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) + + # unpad + test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) + print("validation generation end") + + # Store generated outputs + output_ids = test_output_gen_batch.batch["responses"] + output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] + sample_outputs.extend(output_texts) + + test_batch = test_batch.union(test_output_gen_batch) + + # evaluate using reward_function + result = self.val_reward_fn(test_batch, return_dict=True) + reward_tensor = result["reward_tensor"] + scores = reward_tensor.sum(-1).cpu().tolist() + sample_scores.extend(scores) + + reward_extra_infos_dict["reward"].extend(scores) + if "reward_extra_info" in result: + for key, lst in result["reward_extra_info"].items(): + reward_extra_infos_dict[key].extend(lst) + + data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) + + self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) + + # dump generations + val_data_dir = self.config.trainer.get("validation_data_dir", None) + if val_data_dir: + sample_gts = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + self._dump_generations( + inputs=sample_inputs, + outputs=sample_outputs, + gts=sample_gts, + scores=sample_scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=val_data_dir, + ) + + for key_info, lst in reward_extra_infos_dict.items(): + assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" + + data_sources = np.concatenate(data_source_lst, axis=0) + print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print + print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print + + data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + print( + f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}" + ) # Added Print + metric_dict = {} + for data_source, var2metric2val in data_src2var2metric2val.items(): + core_var = "acc" if "acc" in var2metric2val else "reward" + for var_name, metric2val in var2metric2val.items(): + n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) + for metric_name, metric_val in metric2val.items(): + if ( + (var_name == core_var) + and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) + and (f"@{n_max}" in metric_name) + ): + metric_sec = "val-core" + else: + metric_sec = "val-aux" + pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" + metric_dict[pfx] = metric_val + + return metric_dict + + def init_workers(self): + """Init resource pool and worker group""" + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role="actor_rollout", + ) + self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) + self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref" + ) + self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different + # parallel size, + # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to + # different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + self.wg_dicts = [] + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 + self.wg_dicts.append(wg_dict) + + if self.use_critic: + self.critic_wg = all_wg["critic"] + self.critic_wg.init_model() + + if self.use_reference_policy: + self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = all_wg["rm"] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg.init_model() + + def _save_checkpoint(self): + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and " + "max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + ) + self.critic_wg.save_checkpoint( + critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep + ) + + # save dataloader + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.global_steps)) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, "critic") + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + else: + print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") + + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # reorder based on index. The data will be automatically equally partitioned by dispatch function + global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) + + def fit_dpo(self): # Renamed for clarity as standard PPO loop + """ + The training loop of Online DPO using a periodically updated reference model. + The driver process calls worker groups for computation. + Advantage computation is replaced by DPO logic. + """ + import traceback # Ensure traceback is imported + + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + # Initialize logger + logger = None + try: + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False), + ) + except Exception as e: + print(f"Warning: Failed to initialize logger: {e}") + + self.global_steps = 0 + # Load checkpoint before doing anything + loaded_step = self._load_checkpoint() + self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1 + print( + f"Starting Online DPO training from global step {self.global_steps}. " + f"Total steps: {self.total_training_steps}" + ) + print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}") + + # Check if reference policy is configured correctly for this mode + if not self.use_reference_policy: + print( + "WARNING: 'use_reference_policy' is False. Periodic reference model update requires a " + "reference policy worker. DPO updates might fail or use incorrect logic." + ) + # Consider raising an error if strict adherence is required: + # raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True " + # "and a configured reference worker.") + + # Perform validation before training + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + print("Running validation before Online DPO training...") + val_metrics = self._validate() + pprint(f"Initial validation metrics: {val_metrics}") + if logger and val_metrics: + logger.log(data=val_metrics, step=max(0, self.global_steps - 1)) + if self.config.trainer.get("val_only", False): + print("Validation only mode enabled. Exiting training.") + if logger and hasattr(logger, "finish"): + logger.finish() + return + + # Add tqdm progress bar + progress_bar = tqdm( + total=self.total_training_steps, + initial=self.global_steps, + desc="Online DPO Training Progress", + position=0, + leave=True, + ) + + last_val_metrics = None + should_stop = False + + for epoch in range(self.config.trainer.total_epochs): + if should_stop: + break + print(f"--- Starting Online DPO Epoch {epoch} ---") + try: + train_iterator = iter(self.train_dataloader) + except TypeError: + print("Warning: Dataloader is not iterable.") + train_iterator = self.train_dataloader # Fallback attempt + + for batch_idx, batch_dict in enumerate(train_iterator): + if self.global_steps > self.total_training_steps: + should_stop = True + break + + metrics = {} + timing_raw = {} + step_timer = Timer(logger=None) + ref_log_prob_computed = False # Flag to track if ref log probs were computed + + try: # Outer try-except for the whole step + step_timer.start() + with _timer("step", timing_raw): + batch: DataProto = DataProto.from_single_dict(batch_dict) + current_batch_size = batch.batch.batch_size[0] + print( + f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: " + f"{current_batch_size}" + ) + + # --- Reference Model Update --- + ref_update_freq = self.config.trainer.get("ref_update_freq", -1) + if ( + self.use_reference_policy + and ref_update_freq > 0 + and self.global_steps % ref_update_freq == 0 + ): + print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...") + try: + # --- This requires careful implementation with FSDP --- + # 1. Save actor state dict (potentially to CPU memory or disk) + # This needs to be done collectively across actor worker ranks. + # The checkpoint_manager might be adaptable, or use FSDP APIs directly. + # Example placeholder using a conceptual save/load mechanism: + actor_state_path = "/tmp/actor_state_mid" # Temporary path + self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic + + # 2. Load the state dict onto the reference model worker group + # This also needs collective loading on the ref worker ranks. + self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic + + print(f"[Step {self.global_steps}] Reference Model Weights Updated.") + # Optionally remove the temporary state file + # os.remove(actor_state_path) # Needs rank-aware removal or shared storage + + except Exception as sync_e: + print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}") + traceback.print_exc() + + # Pop keys for generation + pop_batch_keys = ["input_ids", "attention_mask"] + if "position_ids" in batch.batch: + pop_batch_keys.append("position_ids") + pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else [] + if "multi_modal_inputs" in batch.non_tensor_batch.keys(): + pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"]) + original_non_tensor_data = batch.non_tensor_batch + gen_batch = batch.pop( + batch_keys=pop_batch_keys, + non_tensor_batch_keys=pop_non_tensor_keys, + ) + gen_batch = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + # (Add Debug prints for gen_batch if needed) + + # Generate sequences (chosen/rejected pairs) + with _timer("gen", timing_raw): + try: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + # (Add Debug prints for gen_batch_output if needed) + except Exception as gen_e: + print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!") + print(gen_e) + traceback.print_exc() + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + step_timer.stop() + continue + + # Combine original prompts with generated sequences + batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object + ) + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + # (Add Debug prints after union if needed) + + # Compute response mask (needed for ref logprob calc and DPO prep) + batch.batch["response_mask"] = compute_response_mask(batch) + + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef + # fallback) --- + # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed + # unless used for other metrics or a fallback. Keep it for now. + with _timer("policy_log_prob", timing_raw): + policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs' + # (Debug prints for old_log_probs) + + # --- Compute Log Probs using the EXTERNAL Reference Model --- + if self.use_reference_policy: + with _timer("ref_log_prob_dpo", timing_raw): + # print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----") + try: + # 'batch' contains interleaved chosen/rejected sequences + ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob( + batch + ) # Returns DataProto with 'ref_log_prob' + batch = batch.union( + ref_log_prob_output + ) # Adds 'ref_log_prob' key [batch_size * n, seq_len] + ref_log_prob_computed = True # Mark success + # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: " + # f"{batch.batch['ref_log_prob'].shape} ----") + except Exception as ref_e: + print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}") + traceback.print_exc() + batch.batch["ref_log_prob"] = None # Mark as failed + ref_log_prob_computed = False + else: + print( + "Warning: Skipping external reference log prob calculation as use_reference_policy " + "is False." + ) + # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor + + # --- Compute Rewards/Scores (used to determine preference) --- + with _timer("reward_calc", timing_raw): + # (Reward calculation logic using RM or reward_fn as before) + # ... Ensure this calculates 'token_level_rewards' or similar ... + if self.use_rm: + reward_tensor_rm = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor_rm) # Adds 'rm_scores' + + reward_extra_infos_dict = {} + try: + if self.reward_fn is None: + # print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! " + # f"Using dummy rewards. ----") + # Use rm_scores if available, otherwise zeros + reward_tensor = batch.batch.get( + "rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) + ) + else: + reward_result = self.reward_fn(batch, return_dict=True) + reward_tensor = reward_result["reward_tensor"] # Final combined reward + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + + except Exception: + # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. ' + # f'Using dummy rewards. ----') + traceback.print_exc() + reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) + reward_extra_infos_dict = {} + + # Use 'token_level_rewards' as the key for preference calculation + batch.batch["token_level_rewards"] = reward_tensor + if reward_extra_infos_dict: + batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + # --- Determine Preferences --- + # Uses 'token_level_rewards' to determine chosen/rejected based on score + batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key + + # --- Prepare DPO Batch --- + dpo_update_batch_proto = None # Initialize + with _timer("prepare_dpo_batch", timing_raw): + try: + if "preferences" not in batch.batch or batch.batch["preferences"] is None: + raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.") + + # Check if reference log probs were computed successfully (if needed) + if self.use_reference_policy and not ref_log_prob_computed: + raise ValueError("Reference log probs required but failed to compute.") + + # Check required base keys + required_keys = ["input_ids", "attention_mask", "response_mask"] + for rk in required_keys: + if rk not in batch.batch or batch.batch[rk] is None: + raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.") + + preferences_mask = batch.batch["preferences"] # Shape [batch_size * n] + not_preferences_mask = ~preferences_mask + + # Gather Chosen/Rejected Base Tensors + chosen_input_ids = batch.batch["input_ids"][preferences_mask] + chosen_attention_mask = batch.batch["attention_mask"][preferences_mask] + rejected_input_ids = batch.batch["input_ids"][not_preferences_mask] + rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask] + chosen_position_ids = ( + batch.batch.get("position_ids")[preferences_mask] + if "position_ids" in batch.batch + else None + ) + rejected_position_ids = ( + batch.batch.get("position_ids")[not_preferences_mask] + if "position_ids" in batch.batch + else None + ) + + # Create Labels + print("WARNING: Creating DPO labels using configured max_prompt_length...") + prompt_len = self.config.data.max_prompt_length + chosen_labels = chosen_input_ids.clone() + chosen_labels[:, :prompt_len] = -100 + rejected_labels = rejected_input_ids.clone() + rejected_labels[:, :prompt_len] = -100 + + # Calculate and Gather Reference Log Probs (Sequence Level) + if self.use_reference_policy: + ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len] + response_mask_full = batch.batch[ + "response_mask" + ] # Response mask [bsz * n, seq_len] + ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum( + dim=-1 + ) # Sequence level [bsz * n] + reference_chosen_logps = ref_sequence_logps[preferences_mask] + reference_rejected_logps = ref_sequence_logps[not_preferences_mask] + else: + # If not using external ref, DPO needs ActorAsRef logic in dp_actor + # We won't add the keys here, dp_actor will handle it (or fail if not modified) + print( + "Info: Not adding explicit reference logps to DPO batch " + "(use_reference_policy=False)." + ) + reference_chosen_logps = None # Explicitly None + reference_rejected_logps = None + + # Package Tensors + dpo_tensors = { + "chosen_input_ids": chosen_input_ids, + "chosen_attention_mask": chosen_attention_mask, + "chosen_labels": chosen_labels, + "rejected_input_ids": rejected_input_ids, + "rejected_attention_mask": rejected_attention_mask, + "rejected_labels": rejected_labels, + } + # Conditionally add reference logps if computed + if reference_chosen_logps is not None: + dpo_tensors["reference_chosen_logps"] = reference_chosen_logps + if reference_rejected_logps is not None: + dpo_tensors["reference_rejected_logps"] = reference_rejected_logps + # Add position ids if they exist + if chosen_position_ids is not None: + dpo_tensors["chosen_position_ids"] = chosen_position_ids + if rejected_position_ids is not None: + dpo_tensors["rejected_position_ids"] = rejected_position_ids + + # Prepare Meta Info + dpo_meta = { + "dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1), + "dpo_loss_type": OmegaConf.select( + self.config.algorithm, "dpo_loss_type", default="sigmoid" + ), + "dpo_label_smoothing": OmegaConf.select( + self.config.algorithm, "dpo_label_smoothing", default=0.0 + ), + "use_reference_policy": self.use_reference_policy, + "reference_free": not self.use_reference_policy, # False if using external ref + "global_step": self.global_steps, + } + + dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta) + # print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----") + # print(f" Keys: {list(dpo_update_batch_proto.batch.keys())}") + # print(f" Meta Info: {dpo_meta}") + + except Exception as e_prep: + print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}") + traceback.print_exc() + dpo_update_batch_proto = None # Skip update on error + + # --- Actor Update Step --- + actor_output = None + if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto: + with _timer("update_actor", timing_raw): + # Pass the batch containing reference log probs (if computed) + # The modified update_actor_dpo expects them if reference_free=False + actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto) + if actor_output and "metrics" in actor_output.meta_info: + metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) + elif dpo_update_batch_proto is None: + print( + f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error." + ) + + # --- Validation and Saving --- + test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1) + is_last_step = self.global_steps >= self.total_training_steps + if ( + self.val_reward_fn is not None + and test_freq > 0 + and (is_last_step or self.global_steps % test_freq == 0) + ): + print(f"\nRunning DPO validation at step {self.global_steps}...") + val_timing_raw = {} + with _timer("testing", val_timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + if val_metrics: + metrics["time/validation_run"] = val_timing_raw.get("testing", 0) + metrics.update(val_metrics) + else: + print("Validation skipped or returned no metrics.") + + save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) + if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0): + print(f"\nSaving DPO checkpoint at step {self.global_steps}...") + with _timer("save_checkpoint", timing_raw): + self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere) + metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0) + + # --- End main step timer context --- + + # --- Metrics calculation AFTER the 'step' timer block --- + metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + n_gpus = self.resource_pool_manager.get_n_gpus() + if "step" in timing_raw: + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + else: + print( + f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. " + f"Skipping throughput." + ) + + step_timer.stop() + metrics["time/step"] = step_timer.last + + # Log metrics + log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1) + if logger and self.global_steps % log_freq == 0: + log_payload = metrics.copy() + # Add learning rate to log payload + if actor_output and "actor/lr" in metrics: + log_payload["actor/lr"] = metrics["actor/lr"] + + print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}") + try: + logger.log(data=log_payload, step=self.global_steps) + except Exception as e: + print(f"Logging failed at step {self.global_steps}: {e}") + + # Update progress bar + postfix_metrics = { + k: f"{v:.3f}" if isinstance(v, float) else v + for k, v in metrics.items() + if isinstance(v, int | float) + } + progress_bar.set_postfix(postfix_metrics) + + except Exception as step_e: + print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!") + print(f"Caught Exception: {step_e}") + traceback.print_exc() + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + step_timer.stop() + should_stop = True + break + + if is_last_step or should_stop: + print(f"Stopping DPO training at step {self.global_steps}.") + break + + self.global_steps += 1 + progress_bar.update(1) + + # End of epoch handling + if hasattr(self.train_dataloader, "reset"): + try: + self.train_dataloader.reset() + except Exception as e: + print(f"Warning: Failed to reset train dataloader state: {e}") + if should_stop: + break + + # --- Final cleanup and logging --- + progress_bar.close() + final_step = max(0, self.global_steps - 1) + print(f"Online DPO Training finished at step {final_step}.") + # Save final checkpoint + save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) + if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0): + print(f"Saving final DPO checkpoint at step {final_step}...") + self._save_checkpoint() + + # Final validation run + if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False): + print("Running final validation...") + last_val_metrics = self._validate() + if last_val_metrics and logger: + last_val_metrics["final_validation"] = True + try: + logger.log(data=last_val_metrics, step=final_step) + except Exception as e: + print(f"[Final Val Metrics Log Error]: {e}") + + pprint(f"Final validation metrics: {last_val_metrics}") + if logger and hasattr(logger, "finish"): + logger.finish() + print("Online DPO Training Run Complete.") diff --git a/ICL/LV/code/README.md b/ICL/LV/code/README.md new file mode 100644 index 0000000000000000000000000000000000000000..702952597cfdaaf0ce4dbf478149e01ea535c038 --- /dev/null +++ b/ICL/LV/code/README.md @@ -0,0 +1,66 @@ +Unified Multi-Model VQA Codebase + +目的 +- 这一套代码是“模型无关”的通用评测/数据/提示构造层;所有模型仅通过“适配器”接入。 +- 通用输入是 OpenAI 扁平内容序列(image→text;示例用 [REQUEST]/[RESPONSE];查询的 [RESPONSE] 为空)。 + +目录 +- core/ + - prompting/openai_segments.py 扁平序列构造与落盘适配 + - datasets/m3it_reader.py M3IT 统一读取 & base64 图片缓存 + - metrics/metrics.py Token‑F1、BERTScore‑F1 等 + - eval/ 与模型无关的评测脚本(调用 adapters) + - zero_shot_vqa.py / random_k_shot_vqa.py + - eval_textual_retriever_vqa.py / eval_visual_retriever_vqa.py / eval_multimodal_retriever_vqa.py + - order 评测(统一缓存 + 独立指标脚本): + - order_eval_core.py(内部调用) / _modal_order.py(内部调用) + - eval_order_caption_bertscore.py / eval_order_caption_cider.py + - eval_order_classification_accuracy.py / eval_order_classification_f1.py + - eval_order_reasoning_accuracy.py / eval_order_reasoning_ras.py + - eval_order_vqa_bertscore.py / eval_order_vqa_tokenf1.py +- adapters/ + - idefics2_adapter.py + - qwen_vl_adapter.py + - qwen3vl_adapter.py + - gemma3_adapter.py + +使用 +- 例:零样本(Idefics2) + python3 -m core.eval.zero_shot_vqa \ + --adapter idefics2 \ + --model-path /path/to/idefics2-8b \ + --dataset-root /path/to/M3IT \ + --split test --total-samples 500 \ + --instruction-image "C:\\Users\\you\\instruction.png" --dump-first 2 + +- 例:随机 few‑shot(Qwen‑VL) + python3 -m core.eval.random_k_shot_vqa \ + --adapter qwen-vl \ + --model-path /path/to/Qwen-VL \ + --dataset-root /path/to/M3IT \ + --split test --k-shots 3 --total-samples 500 \ + --use-paper-instruction --instruction-image "C:\\Users\\you\\instruction.png" + +- 例:模态顺序评测(以 VQA Token-F1 为例) + python3 -m core.eval.eval_order_vqa_tokenf1 \ + --adapter idefics2 \ + --model-path /path/to/idefics2-8b \ + --dataset-root /path/to/M3IT \ + --retriever-model-path /path/to/BridgeTower-or-CLIP \ + --orders image-text,text-image,text-image-text \ + --k-shots 3 --total-samples 500 --split val + --adapter qwen-vl \ + --model-path /path/to/Qwen-VL \ + --dataset-root /path/to/M3IT \ + --split test --k-shots 3 --total-samples 500 \ + --use-paper-instruction --instruction-image "C:\\Users\\you\\instruction.png" + +约定 +- 适配器接口见 adapters/*.py: + - create(model_path: str) -> Adapter + - Adapter.generate_from_segments(segs: List[dict], temperature: float, top_p: float, max_new_tokens: int) -> str + - 可选:Adapter.generate_single(image_path: str, prompt: str, ...) + +说明 +- 适配器与通用源码彻底分离;你可以只替换 adapters/xxx_adapter.py 即可对接新模型。 +- Windows 路径/BASE64/data:URL 的图片由 prompting/openai_segments.py 自动兼容。 diff --git a/ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc b/ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6efdbc852380baab960f6c5b0f3d12890c24e5d4 Binary files /dev/null and b/ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc differ diff --git a/ICL/LV/code/SFT/build_icl_eval_sharegpt.py b/ICL/LV/code/SFT/build_icl_eval_sharegpt.py new file mode 100644 index 0000000000000000000000000000000000000000..53f1e331a00dd832e09a1dced212f926b4a078f6 --- /dev/null +++ b/ICL/LV/code/SFT/build_icl_eval_sharegpt.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +""" +Build a prompt-only ShareGPT-style eval set for deciding vs . + +Prompt format is aligned with build_icl_dataset.py: + instruction + + "Question: ...\\nAction:" + +But for evaluation we keep ONLY the initial human turn in `conversations` to avoid leaking labels. +Gold labels are stored outside the prompt: + - expected_first_tag: "" or "" (NOT included in conversations) + - answer: used for offline checking (NOT included in conversations) + - shots: for RET samples only, used for the follow-up step after model outputs + +Important: + - Never use train split for eval; recommend val/test/dev. + - Optionally excludes any uid already present in an existing training jsonl to avoid overlap. +""" + +import argparse +import json +import random +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Set + + +# Add code root to PYTHONPATH for core/ imports +CODE_ROOT = Path(__file__).resolve().parents[1] +if str(CODE_ROOT) not in sys.path: + sys.path.insert(0, str(CODE_ROOT)) + +from core.datasets.m3it_reader import iter_m3it_samples, load_instructions # noqa: E402 + + +@dataclass(frozen=True) +class PoolItem: + image_path: str + description: str + subdir: str + + +@dataclass(frozen=True) +class QueryItem: + image_path: str + question: str + answer: str + subdir: str + uid: str + + +def _extract_uid(raw: Dict, fallback: str) -> str: + if isinstance(raw, dict): + for k in ("id", "image_id"): + v = raw.get(k) + if isinstance(v, (str, int)): + return str(v) + meta = raw.get("meta") if isinstance(raw.get("meta"), dict) else {} + for k in ("img_id", "id", "image_id"): + v = meta.get(k) + if isinstance(v, (str, int)): + return str(v) + return fallback + + +def discover_subdirs(dataset_root: Path, category: str) -> List[str]: + base = dataset_root / "data" / category + if not base.exists(): + return [] + out: List[str] = [] + for p in sorted(base.iterdir()): + if p.is_dir(): + out.append(f"{category}/{p.name}") + return out + + +def pick_instruction(insts: List[str], rng: random.Random) -> str: + if insts: + s = rng.choice(insts) + if isinstance(s, str) and s.strip(): + return s.strip() + return "Please answer the question based on the image." + + +def load_exclude_uids(path: Optional[str]) -> Set[str]: + if not path: + return set() + p = Path(path) + if not p.exists(): + return set() + + out: Set[str] = set() + with p.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except Exception: + continue + if not isinstance(obj, dict): + continue + uid = obj.get("uid") + if isinstance(uid, (str, int)): + out.add(str(uid)) + continue + sid = obj.get("id") + if isinstance(sid, (str, int)): + out.add(str(sid)) + return out + + +def to_rel(path: str, root: Path) -> str: + try: + return str(Path(path).relative_to(root)) + except Exception: + return path + + +def build_pool_for_subdir( + *, + dataset_root: Path, + subdir: str, + split: str, + cache_dir: Path, + target_n: int, + max_samples_scan: int, +) -> List[PoolItem]: + items: List[PoolItem] = [] + try: + iterable = iter_m3it_samples( + str(dataset_root), + subdir, + split=split, + cache_dir=str(cache_dir), + max_samples=None, + ) + except FileNotFoundError: + return [] + + for idx, smp in enumerate(iterable): + if max_samples_scan > 0 and idx >= max_samples_scan: + break + if not smp.answers: + continue + desc = (smp.answers[0] or "").strip() + if not desc: + continue + items.append(PoolItem(smp.image_path, desc, subdir)) + if target_n > 0 and len(items) >= target_n: + break + return items + + +def collect_query_pool( + *, + dataset_root: Path, + subdirs: List[str], + split: str, + cache_dir: Path, + exclude_uids: Set[str], + target_n: int, + seed: int, + max_samples_per_subdir: int, +) -> List[QueryItem]: + rng = random.Random(seed) + subdirs = list(subdirs) + rng.shuffle(subdirs) + + seen: Set[str] = set() + out: List[QueryItem] = [] + + for subdir in subdirs: + taken = 0 + try: + iterable = iter_m3it_samples( + str(dataset_root), + subdir, + split=split, + cache_dir=str(cache_dir), + max_samples=None, + ) + except FileNotFoundError: + continue + + for i, smp in enumerate(iterable): + if max_samples_per_subdir > 0 and taken >= max_samples_per_subdir: + break + q = (smp.text or "").strip() + if not q: + continue + if not smp.answers: + continue + ans = (smp.answers[0] or "").strip() + if not ans: + continue + uid = _extract_uid(smp.raw, f"{subdir}:{i:08d}") + if uid in exclude_uids: + continue + if uid in seen: + continue + seen.add(uid) + out.append(QueryItem(smp.image_path, q, ans, subdir, uid)) + taken += 1 + if target_n > 0 and len(out) >= target_n: + return out + return out + + +def select_shots( + pool: List[PoolItem], + k: int, + rng: random.Random, + exclude_image: Optional[str] = None, +) -> List[PoolItem]: + if not pool or k <= 0: + return [] + cand = [p for p in pool if p.image_path != exclude_image] + if not cand: + cand = pool + if len(cand) >= k: + return rng.sample(cand, k=k) + return [rng.choice(cand) for _ in range(k)] + + +def write_jsonl(path: Path, records: List[Dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + for r in records: + f.write(json.dumps(r, ensure_ascii=False) + "\n") + + +def build_prompt_only_record( + *, + uid: str, + instruction: str, + image_rel: str, + question: str, + expected_first_tag: str, + answer: str, + category: str, + subdir: str, + shots: List[Dict], + k_shot: int, +) -> Dict: + human = [] + if instruction: + human.append(instruction.strip()) + human.append("") + human.append(f"Question: {question}\nAction:") + human_value = "\n".join([x for x in human if x]).strip() + + return { + "id": uid, + "images": [image_rel], + "conversations": [ + {"from": "human", "value": human_value}, + ], + "expected_first_tag": expected_first_tag, + "answer": answer, + "k_shot": k_shot, + "shots": shots, + "category": category, + "subdir": subdir, + "instruction": instruction, + "query": {"image": image_rel, "question": question}, + } + + +def main() -> int: + ap = argparse.ArgumentParser(description="Build prompt-only eval set (ShareGPT jsonl) for / decision.") + ap.add_argument("--dataset-root", default="/workspace/M3IT") + ap.add_argument("--output-dir", default="/workspace/M3IT_new/ICL_eval") + ap.add_argument("--category", default="vqa") + ap.add_argument("--split", default="val", help="Never use train; recommend val/test/dev.") + ap.add_argument("--pool-split", default="val", help="Never use train; recommend val/test/dev.") + ap.add_argument("--seed", type=int, default=42) + + ap.add_argument("--total", type=int, default=100) + ap.add_argument("--ret-ratio", type=float, default=0.5) + ap.add_argument("--query-pool-size", type=int, default=1000, help="How many queries to collect before sampling.") + ap.add_argument("--max-samples-per-subdir", type=int, default=2000, help="Scan cap per subdir when collecting queries.") + ap.add_argument("--pool-size-per-subdir", type=int, default=2000, help="Max pool size to build per subdir (for shots).") + ap.add_argument("--pool-scan-per-subdir", type=int, default=4000, help="Scan cap per subdir when building pools.") + ap.add_argument("--shot-k-min", type=int, default=1) + ap.add_argument("--shot-k-max", type=int, default=3) + + ap.add_argument( + "--exclude-uids-from", + default="/workspace/M3IT_new/ICL/vqa/merged_shuffled_sharegpt.jsonl", + help="Optional jsonl to exclude uids/ids (to avoid overlap with training).", + ) + ap.add_argument("--overwrite", action="store_true") + ap.add_argument("--output", default=None, help="Default: {output_dir}/{category}/eval_sharegpt_{total}.jsonl") + args = ap.parse_args() + + if args.split.strip().lower() == "train" or args.pool_split.strip().lower() == "train": + raise ValueError("split/pool-split=train is not allowed for eval set") + if args.total <= 0: + raise ValueError("total must be > 0") + if not (0.0 <= args.ret_ratio <= 1.0): + raise ValueError("ret-ratio must be in [0, 1]") + if args.shot_k_min <= 0 or args.shot_k_max < args.shot_k_min: + raise ValueError("invalid shot-k range") + + dataset_root = Path(args.dataset_root) + output_dir = Path(args.output_dir) + cache_dir = output_dir / "_image_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + + out_path = Path( + args.output + if args.output + else str(output_dir / args.category / f"eval_sharegpt_{args.total}.jsonl") + ) + if out_path.exists() and not args.overwrite: + raise FileExistsError(f"Output exists: {out_path} (use --overwrite to replace)") + + subdirs = discover_subdirs(dataset_root, args.category) + if not subdirs: + raise FileNotFoundError(f"No subdirs found under {dataset_root}/data/{args.category}") + + exclude_uids = load_exclude_uids(args.exclude_uids_from) + + # Load instructions once per subdir. + inst_map: Dict[str, List[str]] = {sd: load_instructions(dataset_root, sd) for sd in subdirs} + + query_pool_target = max(args.total, args.query_pool_size) + queries = collect_query_pool( + dataset_root=dataset_root, + subdirs=subdirs, + split=args.split, + cache_dir=cache_dir, + exclude_uids=exclude_uids, + target_n=query_pool_target, + seed=args.seed, + max_samples_per_subdir=args.max_samples_per_subdir, + ) + if len(queries) < args.total: + raise RuntimeError(f"Not enough queries collected: got {len(queries)}/{args.total}. " + f"Try increasing --max-samples-per-subdir or changing --split.") + + rng = random.Random(args.seed) + ret_n = int(round(args.total * args.ret_ratio)) + ret_n = max(0, min(args.total, ret_n)) + ans_n = args.total - ret_n + + labels = ["RET"] * ret_n + ["ANS"] * ans_n + rng.shuffle(labels) + + # Pools are built lazily per subdir for RET samples only. + pool_map: Dict[str, List[PoolItem]] = {} + + records: List[Dict] = [] + used_uids: Set[str] = set() + + for label in labels: + for _try in range(2000): + q = rng.choice(queries) + if q.uid in used_uids: + continue + + inst = pick_instruction(inst_map.get(q.subdir, []), rng) + image_rel = to_rel(q.image_path, output_dir) + + if label == "ANS": + used_uids.add(q.uid) + records.append( + build_prompt_only_record( + uid=q.uid, + instruction=inst, + image_rel=image_rel, + question=q.question, + expected_first_tag="", + answer=q.answer, + category=args.category, + subdir=q.subdir, + shots=[], + k_shot=0, + ) + ) + break + + # RET case: attach hidden shots for the follow-up step. + if q.subdir not in pool_map: + pool_map[q.subdir] = build_pool_for_subdir( + dataset_root=dataset_root, + subdir=q.subdir, + split=args.pool_split, + cache_dir=cache_dir, + target_n=args.pool_size_per_subdir, + max_samples_scan=args.pool_scan_per_subdir, + ) + pool = pool_map.get(q.subdir, []) + if not pool: + continue + + k = rng.randint(args.shot_k_min, args.shot_k_max) + shots_items = select_shots(pool, k, rng, exclude_image=q.image_path) + shots = [ + {"image": to_rel(s.image_path, output_dir), "description": s.description} + for s in shots_items + ] + used_uids.add(q.uid) + records.append( + build_prompt_only_record( + uid=q.uid, + instruction=inst, + image_rel=image_rel, + question=q.question, + expected_first_tag="", + answer=q.answer, + category=args.category, + subdir=q.subdir, + shots=shots, + k_shot=k, + ) + ) + break + else: + raise RuntimeError(f"Failed to sample enough records for label={label}. " + f"Try increasing --query-pool-size or relaxing --exclude-uids-from.") + + rng.shuffle(records) + write_jsonl(out_path, records) + + # Lightweight summary to stdout. + ret_cnt = sum(1 for r in records if r.get("expected_first_tag") == "") + ans_cnt = len(records) - ret_cnt + print(f"[OK] wrote={len(records)} ret={ret_cnt} ans={ans_cnt} -> {out_path}") + print(f"[INFO] image_root (for eval): {output_dir}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ICL/LV/code/SFT/check_kshot_ret_ans.py b/ICL/LV/code/SFT/check_kshot_ret_ans.py new file mode 100644 index 0000000000000000000000000000000000000000..8567d189a38ad14f32996a6a29308022a8ae0f87 --- /dev/null +++ b/ICL/LV/code/SFT/check_kshot_ret_ans.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Check whether the model outputs or under different shot settings. + +0-shot: only query image + question. +K-shot (K>=1): after the model outputs , append 1 shot (image+description) and + ask again. If it still outputs , append another shot, and so on. + We record whether it outputs / at each step. +""" + +import argparse +import json +import os +import random +import re +from typing import Any, Dict, List, Optional, Tuple + +import torch +from PIL import Image +from transformers import AutoProcessor, Qwen3VLForConditionalGeneration + + +TAG_RE = re.compile(r"(|)") + + +def _extract_tag(text: str) -> Optional[str]: + match = TAG_RE.search(text) + return match.group(1) if match else None + + +def _resolve_path(root: str, maybe_rel: str) -> str: + if os.path.isabs(maybe_rel): + return maybe_rel + return os.path.normpath(os.path.join(root, maybe_rel)) + + +def _split_user_text_with_images( + text: str, image_paths: List[str] +) -> Tuple[List[Dict[str, str]], List[str]]: + parts = text.split("") + content: List[Dict[str, str]] = [] + used: List[str] = [] + for i, part in enumerate(parts): + part = part.strip() + if part: + content.append({"type": "text", "text": part}) + if i < len(parts) - 1: + if not image_paths: + raise ValueError("用户文本里 数量 > images 列表长度") + img_path = image_paths.pop(0) + used.append(img_path) + content.append({"type": "image", "image": img_path}) + return content, used + + +def _append_human_turn( + *, + messages: List[Dict[str, Any]], + pil_images: List[Image.Image], + image_root: str, + text: str, + images_all: List[str], + image_cursor: int, +) -> int: + n_placeholders = text.count("") + img_paths = images_all[image_cursor : image_cursor + n_placeholders] + if len(img_paths) != n_placeholders: + raise ValueError("images 列表长度 < 占位符数量") + + user_content, used_paths = _split_user_text_with_images(text, img_paths.copy()) + for p in used_paths: + p = _resolve_path(image_root, p) if not os.path.isabs(p) else p + if not os.path.exists(p): + raise FileNotFoundError(p) + with Image.open(p) as img: + pil_images.append(img.convert("RGB")) + + messages.append({"role": "user", "content": user_content}) + return image_cursor + n_placeholders + + +def _append_shot_turn( + *, + messages: List[Dict[str, Any]], + pil_images: List[Image.Image], + image_root: str, + image_path: str, + description: str, +) -> None: + img_path = _resolve_path(image_root, image_path) + if not os.path.exists(img_path): + raise FileNotFoundError(img_path) + with Image.open(img_path) as im: + pil_images.append(im.convert("RGB")) + messages.append( + { + "role": "user", + "content": [ + {"type": "image", "image": img_path}, + {"type": "text", "text": f"Description: {description}"}, + ], + } + ) + + +def _build_base_messages(obj: Dict[str, Any], image_root: str) -> Tuple[List[Dict[str, Any]], List[Image.Image]]: + conversations = obj.get("conversations") + if not isinstance(conversations, list) or not conversations: + raise ValueError("样本缺少 conversations") + images_rel = obj.get("images") or [] + if not isinstance(images_rel, list): + raise ValueError("images 字段不是 list") + + messages: List[Dict[str, Any]] = [] + pil_images: List[Image.Image] = [] + image_cursor = 0 + + # Use the first human turn as query prompt. + human = None + for t in conversations: + if t.get("from") == "human": + human = t + break + if human is None: + raise ValueError("没有 human turn") + + image_cursor = _append_human_turn( + messages=messages, + pil_images=pil_images, + image_root=image_root, + text=str(human.get("value", "")), + images_all=images_rel, + image_cursor=image_cursor, + ) + return messages, pil_images + + +def _pick_shots_from_pool( + pool: List[Dict[str, str]], + k: int, + rng: random.Random, + exclude_image: Optional[str], +) -> List[Dict[str, str]]: + if k <= 0 or not pool: + return [] + cand = [p for p in pool if p.get("image") != exclude_image] + if not cand: + cand = pool + if len(cand) >= k: + return rng.sample(cand, k=k) + return [rng.choice(cand) for _ in range(k)] + + +def main() -> int: + ap = argparse.ArgumentParser(description="Check / outputs under 0/1/2/3-shot settings.") + ap.add_argument("--model", required=True, help="HF model dir") + ap.add_argument( + "--data", + default="/workspace/M3IT_new/ICL_eval/vqa/eval_sharegpt_100.jsonl", + help="Prompt-only eval jsonl (has conversations/images/shots).", + ) + ap.add_argument("--image-root", default="/workspace/M3IT_new/ICL_eval") + ap.add_argument("--num-samples", type=int, default=20) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--k-list", default="0,1,2,3", help="Comma-separated shot counts to report.") + ap.add_argument("--max-new-tokens", type=int, default=128) + ap.add_argument("--device", default="cuda:0") + ap.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16") + ap.add_argument("--print-samples", action="store_true", help="Print each sample input/output.") + args = ap.parse_args() + + rng = random.Random(args.seed) + k_list = [int(x.strip()) for x in args.k_list.split(",") if x.strip()] + if not k_list: + raise ValueError("k-list is empty") + max_k = max(k_list) + + # Load dataset + data: List[Dict[str, Any]] = [] + with open(args.data, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + data.append(json.loads(line)) + if not data: + raise ValueError("empty data") + + # Build global shot pool + pool: List[Dict[str, str]] = [] + for obj in data: + shots = obj.get("shots") or [] + if isinstance(shots, list): + for s in shots: + if not isinstance(s, dict): + continue + img = s.get("image") + desc = s.get("description") + if isinstance(img, str) and isinstance(desc, str) and img and desc: + pool.append({"image": img, "description": desc}) + if not pool: + print("[WARN] shot pool is empty, k-shot tests may be skipped") + + # Sample records + samples = rng.sample(data, k=min(args.num_samples, len(data))) + + dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16 + device = torch.device(args.device) + processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True) + model = Qwen3VLForConditionalGeneration.from_pretrained( + args.model, dtype=dtype, trust_remote_code=True + ).to(device) + model.eval() + + summary: Dict[int, Dict[str, int]] = { + k: {"RET": 0, "ANS": 0, "NONE": 0, "REACHED": 0} for k in k_list + } + + for obj in samples: + uid = obj.get("id") or obj.get("uid") or "unknown" + query = obj.get("query") or {} + query_image = query.get("image") + + messages, pil_images = _build_base_messages(obj, args.image_root) + + # Pre-sample shots for this sample (use first N as we go). + shots_all = _pick_shots_from_pool(pool, max_k, rng, exclude_image=query_image) + if max_k > 0 and len(shots_all) < max_k: + continue + + step = 0 # step 0 = query only + while True: + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + inputs = processor(text=prompt, images=pil_images, padding=True, return_tensors="pt") + inputs = {k2: v.to(device) for k2, v in inputs.items()} + + with torch.inference_mode(): + out_ids = model.generate( + **inputs, do_sample=False, max_new_tokens=args.max_new_tokens + ) + in_len = int(inputs["input_ids"].shape[1]) + pred = processor.batch_decode(out_ids[:, in_len:], skip_special_tokens=True)[0].strip() + tag = _extract_tag(pred) + + # Append model output to the conversation + messages.append({"role": "assistant", "content": [{"type": "text", "text": pred}]}) + + if step in summary: + summary[step]["REACHED"] += 1 + if tag == "": + summary[step]["RET"] += 1 + elif tag == "": + summary[step]["ANS"] += 1 + else: + summary[step]["NONE"] += 1 + + if args.print_samples: + print("=" * 80) + print(f"uid={uid} | step={step} | pred_tag={tag}") + for m in messages: + role = m.get("role") + if role == "user": + parts = [] + for c in m.get("content", []): + if c.get("type") == "text": + parts.append(c.get("text", "")) + elif c.get("type") == "image": + parts.append(f" {c.get('image','')}") + print("[输入]") + print("\n".join([p for p in parts if p]).strip()) + elif role == "assistant": + parts = [] + for c in m.get("content", []): + if c.get("type") == "text": + parts.append(c.get("text", "")) + print("[输入-助手]") + print("\n".join([p for p in parts if p]).strip()) + print("[输出]") + print(pred) + + # Stop if not or reached max_k shots + if tag != "": + break + if step >= max_k: + break + + # Append next shot and ask again. + shot = shots_all[step] if step < len(shots_all) else None + if not shot: + break + _append_shot_turn( + messages=messages, + pil_images=pil_images, + image_root=args.image_root, + image_path=shot["image"], + description=shot["description"], + ) + # Ask for decision again after each shot. + messages.append({"role": "user", "content": [{"type": "text", "text": "Action:"}]}) + step += 1 + + print("=== summary ===") + for k in k_list: + s = summary[k] + reached = s["REACHED"] + if reached == 0: + print(f"k={k}: no samples") + continue + print( + f"k={k} | RET={s['RET']} | ANS={s['ANS']} | NONE={s['NONE']} | reached={reached}" + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb b/ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb new file mode 100644 index 0000000000000000000000000000000000000000..d02294184b32b91eb1a7d348406ebaa0f3218bdd Binary files /dev/null and b/ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb differ diff --git a/ICL/LV/code/SFT/prepare_dataset.py b/ICL/LV/code/SFT/prepare_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..15142c235b571b22ad01aad89139a66f3c8c7793 --- /dev/null +++ b/ICL/LV/code/SFT/prepare_dataset.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +""" +预先生成数据集缓存 +运行一次后,训练时直接加载缓存,避免超时 +""" + +import os +import sys +import pickle +from pathlib import Path + +# 添加当前目录到路径 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from config import get_config +from dataset import SFTDataset + + +def main(): + config = get_config() + + cache_path = Path(config.training.output_dir) / ".dataset_cache.pkl" + ready_flag = Path(config.training.output_dir) / ".dataset_ready" + + # 创建输出目录 + os.makedirs(config.training.output_dir, exist_ok=True) + + print("=" * 60) + print("预生成 SFT 数据集缓存") + print("=" * 60) + print(f"数据目录: {config.data.sft_data_dir}") + print(f"缓存路径: {cache_path}") + print("=" * 60) + + # 加载数据集 + print("\n开始加载数据集...") + dataset = SFTDataset(config.data, split="train") + + print(f"\n数据集加载完成!共 {len(dataset)} 个样本") + + # 保存缓存 + print(f"\n保存缓存到: {cache_path}") + with open(cache_path, "wb") as f: + pickle.dump(dataset.samples, f) + + # 创建就绪标记 + ready_flag.touch() + + print(f"就绪标记: {ready_flag}") + print("\n" + "=" * 60) + print("缓存生成完成!现在可以运行 bash run_train.sh") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/ICL/LV/code/adapters/gemma3_adapter.py b/ICL/LV/code/adapters/gemma3_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..520c57004d2a07656ac8ca3c2a2ce90dbe2c04e6 --- /dev/null +++ b/ICL/LV/code/adapters/gemma3_adapter.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import List, Dict + +try: + from adapters._runners.gemma3_infer import Gemma3Runner +except Exception: + Gemma3Runner = None # type: ignore + + +class Adapter: + def __init__(self, model_path: str): + if Gemma3Runner is None: + raise RuntimeError('Gemma3Runner unavailable. Ensure gemma3-code is on PYTHONPATH or install its runner.') + self.runner = Gemma3Runner(model_path) + + def generate_from_segments(self, segs: List[Dict[str, str]], *, + temperature: float, top_p: float, max_new_tokens: int) -> str: + gen = getattr(self.runner, 'generate_from_qwen_segs', None) + if gen is None: + raise RuntimeError('Gemma3Runner missing generate_from_qwen_segs') + return gen(segs, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens) + + +def create(model_path: str) -> Adapter: + return Adapter(model_path) + diff --git a/ICL/LV/code/adapters/qwen3vl_adapter.py b/ICL/LV/code/adapters/qwen3vl_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..e913ae1de77dd8ea8aab6f39c3deac493af5be5b --- /dev/null +++ b/ICL/LV/code/adapters/qwen3vl_adapter.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from typing import List, Dict + +try: + from adapters._runners.qwen3_vl_infer import Qwen3VLRunner +except Exception: + Qwen3VLRunner = None # type: ignore + + +class Adapter: + def __init__(self, model_path: str): + if Qwen3VLRunner is None: + raise RuntimeError('Qwen3VLRunner unavailable. Ensure QWEN3VL-code is on PYTHONPATH or install its runner.') + self.runner = Qwen3VLRunner(model_path) + + def generate_from_segments(self, segs: List[Dict[str, str]], *, + temperature: float, top_p: float, max_new_tokens: int) -> str: + gen = getattr(self.runner, 'generate_from_segments', None) + if gen is None: + raise RuntimeError('Qwen3VLRunner missing generate_from_segments') + return gen(segs, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens) + + +def create(model_path: str) -> Adapter: + return Adapter(model_path) + diff --git a/ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc b/ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53543caabea31b4e706599fe7358a4dfb58ee8a3 Binary files /dev/null and b/ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc differ diff --git a/ICL/LV/code/attn map/attn map/attn map/qwen3vl_step_attention_map.py b/ICL/LV/code/attn map/attn map/attn map/qwen3vl_step_attention_map.py new file mode 100644 index 0000000000000000000000000000000000000000..55437c65756e15ff968b8e2a2b026dfe13a61d35 --- /dev/null +++ b/ICL/LV/code/attn map/attn map/attn map/qwen3vl_step_attention_map.py @@ -0,0 +1,199 @@ +""" +Step-wise attention visualization for Qwen3-VL. + +Greedily generates tokens, collects each generated token's attention to visual tokens +across all layers/heads, averages them, and overlays a heatmap on the image. + +Defaults match your setup: +- model: /workspace/Qwen3-VL-8B-Instruct +- image: /z_data/syxin/code/runs/shot_sweep_allmetrics_qwen3-vl/shot0/captioning_bertscore/_image_cache/captioning_coco/COCO_val2014_000000000661.jpg.jpg +- prompt: 解释一下这张图 +""" +from __future__ import annotations + +import argparse +import os +from typing import Optional, Tuple, List + +import cv2 +import numpy as np +import torch +from PIL import Image +from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer + + +def find_vision_token_range(tokenizer: AutoTokenizer, input_ids: torch.Tensor) -> Tuple[Optional[int], Optional[int]]: + ids = input_ids[0].tolist() + start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>") + end_id = tokenizer.convert_tokens_to_ids("<|vision_end|>") + try: + start = ids.index(start_id) + 1 + end = ids.index(end_id) + return start, end + except ValueError: + return None, None + + +def prepare_inputs(processor: AutoProcessor, image: Image.Image, prompt: str, device: torch.device) -> dict: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": prompt}, + ], + } + ] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt") + return inputs.to(device) + + +def decode_generation(tokenizer: AutoTokenizer, generated_ids: torch.Tensor, prompt_len: int) -> str: + return tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False) + + +def reshape_attention_map(attn_vec: np.ndarray, image_grid_thw: Optional[torch.Tensor], spatial_merge_size: int = 2) -> np.ndarray: + num_tokens = len(attn_vec) + if image_grid_thw is not None: + t, h, w = image_grid_thw[0].tolist() + merged_h = h // spatial_merge_size + merged_w = w // spatial_merge_size + expected = t * merged_h * merged_w + if num_tokens == expected: + attn_map = attn_vec.reshape(t, merged_h, merged_w) + attn_map = attn_map.mean(axis=0) if t > 1 else attn_map[0] + return attn_map + side = int(np.sqrt(num_tokens)) + if side * side == num_tokens: + return attn_vec.reshape(side, side) + for h in range(side, 0, -1): + if num_tokens % h == 0: + w = num_tokens // h + return attn_vec.reshape(h, w) + return attn_vec.reshape(1, -1) + + +def overlay_heatmap(image: Image.Image, attn_map: np.ndarray, save_path: str) -> str: + base = np.array(image) + attn_norm = attn_map - attn_map.min() + attn_norm = attn_norm / attn_norm.max() if attn_norm.max() > 0 else attn_norm + attn_resized = cv2.resize(attn_norm, (base.shape[1], base.shape[0]), interpolation=cv2.INTER_CUBIC) + heatmap = cv2.applyColorMap(np.uint8(255 * attn_resized), cv2.COLORMAP_JET) + overlay = cv2.addWeighted(base[:, :, ::-1], 0.5, heatmap, 0.5, 0) # base to BGR for cv2 + cv2.imwrite(save_path, overlay) + return save_path + + +def main(): + parser = argparse.ArgumentParser(description="Qwen3-VL step-wise attention to visual tokens") + parser.add_argument( + "--model", + default="/workspace/Qwen3-VL-8B-Instruct", + help="Model path or HF id", + ) + parser.add_argument( + "--image", + default="/z_data/syxin/code/runs/shot_sweep_allmetrics_qwen3-vl/shot0/captioning_bertscore/_image_cache/captioning_coco/COCO_val2014_000000000661.jpg.jpg", + help="Image path", + ) + parser.add_argument("--prompt", default="解释一下这张图", help="Prompt text") + parser.add_argument("--max-new-tokens", type=int, default=128, help="Max new tokens to generate") + parser.add_argument("--save", default="step_attention_overlay.png", help="Output overlay path") + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + model = AutoModelForVision2Seq.from_pretrained( + args.model, + device_map="cuda" if device.type == "cuda" else None, + torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32, + trust_remote_code=True, + ).eval() + if device.type != "cuda": + model = model.to(device) + if hasattr(model, "set_attn_implementation"): + try: + model.set_attn_implementation("eager") + except Exception as exc: + print(f"Warning: failed to set attn implementation to eager: {exc}") + elif hasattr(model.config, "attn_implementation"): + model.config.attn_implementation = "eager" + + image = Image.open(args.image).convert("RGB") + inputs = prepare_inputs(processor, image, args.prompt, device) + + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)) + pixel_values = inputs.get("pixel_values") + image_grid_thw = inputs.get("image_grid_thw") + prompt_len = input_ids.shape[1] + + vision_start, vision_end = find_vision_token_range(tokenizer, input_ids) + if vision_start is None or vision_end is None: + raise ValueError("Could not locate vision token range in input_ids.") + num_vision_tokens = vision_end - vision_start + + generated_ids = input_ids.clone() + attn_per_token: List[torch.Tensor] = [] + generated_tokens: List[str] = [] + + eos_token_id = tokenizer.eos_token_id + if isinstance(eos_token_id, list): + eos_set = set(eos_token_id) + else: + eos_set = {eos_token_id} + + with torch.no_grad(): + for step in range(args.max_new_tokens): + outputs = model( + input_ids=generated_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_attentions=True, + return_dict=True, + ) + + attns = outputs.attentions + if attns is None: + raise ValueError("Model did not return attentions. Ensure eager attention is enabled.") + + last_token_attn = [] + for layer_attn in attns: + # layer_attn: (batch, heads, seq, seq) + attn_slice = layer_attn[0, :, -1, vision_start:vision_end] + last_token_attn.append(attn_slice.float().cpu()) + stacked = torch.stack(last_token_attn, dim=0) # (layers, heads, num_vis_tokens) + avg_attn = stacked.mean(dim=(0, 1)) # (num_vis_tokens,) + attn_per_token.append(avg_attn) + + next_token_logits = outputs.logits[:, -1, :] + next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True) + generated_tokens.append(tokenizer.decode(next_token_id[0])) + + if next_token_id.item() in eos_set: + generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) + break + + generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) + attention_mask = torch.cat( + [attention_mask, torch.ones((1, 1), device=attention_mask.device, dtype=attention_mask.dtype)], dim=-1 + ) + + generated_text = decode_generation(tokenizer, generated_ids, prompt_len) + print(f"Generated text:\n{generated_text}\n") + if not attn_per_token: + raise ValueError("No attentions collected; generation ended immediately.") + + all_attn = torch.stack(attn_per_token, dim=0).mean(dim=0).numpy() + attn_map = reshape_attention_map(all_attn, image_grid_thw) + + save_path = os.path.abspath(args.save) + final_path = overlay_heatmap(image, attn_map, save_path) + print(f"Saved attention overlay to: {final_path}") + + +if __name__ == "__main__": + main() diff --git a/ICL/LV/code/attn map/attn map/attn map/token_attention_utils.py b/ICL/LV/code/attn map/attn map/attn map/token_attention_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d29d067c06458658167aaeb7dee603eb82cb7823 --- /dev/null +++ b/ICL/LV/code/attn map/attn map/attn map/token_attention_utils.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from pathlib import Path +from typing import List + + +def save_token_attention_artifacts(attn_sums: List[float], tokens: List[str], out_dir: Path) -> None: + """Write per-token vision attention sums (tsv + plot).""" + out_dir.mkdir(parents=True, exist_ok=True) + table_lines = ["idx\ttoken\tsum_attn_vision"] + for idx, (tok, w) in enumerate(zip(tokens, attn_sums)): + tok_clean = tok.replace("\n", "\\n").strip() or "" + table_lines.append(f"{idx}\t{tok_clean}\t{w:.6f}") + (out_dir / "token_attention.tsv").write_text("\n".join(table_lines) + "\n", encoding="utf-8") + + try: + import matplotlib.pyplot as plt + except Exception as exc: + print(f"[WARN] matplotlib not available, skip token attention plot: {exc}") + return + + fig_w = max(6.0, len(attn_sums) * 0.6) + plt.figure(figsize=(fig_w, 4)) + plt.plot(attn_sums, marker="o") + plt.xticks( + range(len(tokens)), + [t.replace("\n", "\\n").strip() or "" for t in tokens], + rotation=60, + ha="right", + ) + plt.ylabel("sum(attn to vision tokens)") + plt.xlabel("generated token index") + plt.tight_layout() + plot_path = out_dir / "token_attention.png" + plt.savefig(plot_path, dpi=200) + plt.close() + print(f"[ok] Saved token attention plot -> {plot_path}") diff --git a/ICL/LV/code/attn map/modeling_qwen3vl.py b/ICL/LV/code/attn map/modeling_qwen3vl.py new file mode 100644 index 0000000000000000000000000000000000000000..cba7462558bbe3498d901aa7b569b61e82b37ac7 --- /dev/null +++ b/ICL/LV/code/attn map/modeling_qwen3vl.py @@ -0,0 +1,1578 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen3_vl/modular_qwen3_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen3_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig + + +class Qwen3VLVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True) + self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) + + +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLVisionRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen3VLVisionPatchMerger(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3VLVisionAttention(nn.Module): + def __init__(self, config: Qwen3VLVisionConfig) -> None: + super().__init__() + self.dim = config.hidden_size + self.num_heads = config.num_heads + self.head_dim = self.dim // self.num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True) + self.proj = nn.Linear(self.dim, self.dim) + self.scaling = self.head_dim**-0.5 + self.config = config + self.attention_dropout = 0.0 + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + query_states, key_states, value_states = ( + self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + ) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.config._attn_implementation == "flash_attention_2": + # Flash Attention 2: Use cu_seqlens for variable length attention + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + attn_output, _ = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + cu_seq_lens_q=cu_seqlens, + cu_seq_lens_k=cu_seqlens, + max_length_q=max_seqlen, + max_length_k=max_seqlen, + is_causal=False, + **kwargs, + ) + else: + # Other implementations: Process each chunk separately + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + attention_interface( + self, + q, + k, + v, + attention_mask=None, + scaling=self.scaling, + dropout=0.0 if not self.training else self.attention_dropout, + is_causal=False, + **kwargs, + )[0] + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + + attn_output = attn_output.reshape(seq_length, -1).contiguous() + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen3VLVisionBlock(GradientCheckpointingLayer): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6) + self.attn = Qwen3VLVisionAttention(config=config) + self.mlp = Qwen3VLVisionMLP(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3VLTextConfig, device=None): + super().__init__() + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", "default") + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3VLTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen3VLTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3VLTextAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! + self.k_norm = Qwen3VLTextRMSNorm( + self.head_dim, eps=config.rms_norm_eps + ) # thus post q_norm does not need reshape + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3VLTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3VLTextDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3VLTextMLP(config) + self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, attention_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + if output_attentions: + return hidden_states, attention_weights + return hidden_states + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Llava outputs, with hidden states and attentions. + """ +) +class Qwen3VLModelOutputWithPast(ModelOutput): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +@auto_docstring +class Qwen3VLPreTrainedModel(PreTrainedModel): + config: Qwen3VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3VLTextDecoderLayer, + "attentions": Qwen3VLTextAttention, + } + + +class Qwen3VLVisionModel(Qwen3VLPreTrainedModel): + config: Qwen3VLVisionConfig + _no_split_modules = ["Qwen3VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed( + config=config, + ) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=False, + ) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + config=config, + use_postshuffle_norm=True, + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + **kwargs, + ) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = self.merger(hidden_states) + + return hidden_states, deepstack_feature_lists + + +@auto_docstring( + custom_intro=( + "Text part of Qwen3VL, " + "not a pure text-only model, as DeepStack integrates visual features into the early hidden states." + ) +) +class Qwen3VLTextModel(Qwen3VLPreTrainedModel): + config: Qwen3VLTextConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer"] + + def __init__(self, config: Qwen3VLTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs() + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + r""" + visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*): + The mask of the visual positions. + deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): + The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). + The feature is extracted from the different visual encoder layers, and fed to the decoder + hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334). + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=text_position_ids, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_layer_attention = [] + # decoder layers + for layer_idx, decoder_layer in enumerate(self.layers): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=text_position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) + if output_attentions: + hidden_states, layer_attention = layer_outputs + all_layer_attention.append(layer_attention) + else: + hidden_states = layer_outputs + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + attentions=all_layer_attention + ) + + def _deepstack_process( + self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor + ): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + +@auto_docstring +class Qwen3VLModel(Qwen3VLPreTrainedModel): + base_model_prefix = "" + _checkpoint_conversion_mapping = {} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen3VLVisionModel._from_config(config.vision_config) + self.language_model = Qwen3VLTextModel._from_config(config.text_config) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + # Same implementation as for images + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + @auto_docstring + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + image_mask = None + video_mask = None + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + output_attentions=output_attentions, + **kwargs, + ) + + return Qwen3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + rope_deltas=self.rope_deltas, + attentions=outputs.attentions, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Qwen3VL causal language model (or autoregressive) outputs. + """ +) +class Qwen3VLCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = {} + _tied_weights_keys = ["lm_head.weight"] + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + config: Qwen3VLConfig + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3VLModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + return self.model.get_video_features(pixel_values_videos, video_grid_thw) + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + return self.model.get_image_features(pixel_values, image_grid_thw) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def visual(self): + return self.model.visual + + @check_model_inputs() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + output_attentions: bool = False, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + + Example: + TODO: Add example + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + output_attentions=output_attentions, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size) + + return Qwen3VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + rope_deltas=outputs.rope_deltas, + attentions=outputs.attentions + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_cache=use_cache, + **kwargs, + ) + + # Qwen3VL position_ids are prepareed with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + if inputs_embeds is not None: + vision_start_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + image_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + video_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + )[..., 0] + else: + vision_start_mask = input_ids == vision_start_token_id + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> tuple[torch.LongTensor, dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums( + input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None) + ) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size + ) + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = [ + "Qwen3VLVisionModel", + "Qwen3VLForConditionalGeneration", + "Qwen3VLModel", + "Qwen3VLPreTrainedModel", + "Qwen3VLTextModel", +] \ No newline at end of file diff --git a/ICL/LV/code/attn map/vlm_attention_viz.py b/ICL/LV/code/attn map/vlm_attention_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd6f7f38a63a9f4ab24fa7941446a341b8e76e8 --- /dev/null +++ b/ICL/LV/code/attn map/vlm_attention_viz.py @@ -0,0 +1,326 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +from transformers import AutoProcessor +import argparse +import os +from modeling_qwen3vl import Qwen3VLForConditionalGeneration + + +def load_model(model_path="Qwen/Qwen3-VL-4B-Instruct", device="cuda"): + print(f"Loading model from {model_path}...") + + model = Qwen3VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + attn_implementation="eager", + device_map="auto", + ) + + processor = AutoProcessor.from_pretrained(model_path) + + print("Model loaded successfully!") + return model, processor + + +def get_image_token_positions(input_ids, processor): + vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") + vision_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_end|>") + + input_ids_list = input_ids[0].tolist() + + vision_start_pos = None + vision_end_pos = None + + for i, token_id in enumerate(input_ids_list): + if token_id == vision_start_token_id: + vision_start_pos = i + 1 + elif token_id == vision_end_token_id: + vision_end_pos = i + break + + if vision_start_pos is None or vision_end_pos is None: + image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") + if image_token_id is not None: + positions = [i for i, tid in enumerate(input_ids_list) if tid == image_token_id] + if positions: + vision_start_pos = positions[0] + vision_end_pos = positions[-1] + 1 + + return vision_start_pos, vision_end_pos + + +def generate_with_attention(model, processor, image_path, prompt="Describe this image in detail.", max_new_tokens=512): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": prompt}, + ], + } + ] + + inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + inputs = inputs.to(model.device) + + input_ids = inputs["input_ids"] + input_length = input_ids.shape[1] + + vision_start, vision_end = get_image_token_positions(input_ids, processor) + print(f"Vision token positions: {vision_start} to {vision_end}") + print(f"Number of vision tokens: {vision_end - vision_start if vision_start and vision_end else 'N/A'}") + + generated_ids = input_ids.clone() + attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)) + all_attentions = [] + generated_tokens = [] + + pixel_values = inputs.get("pixel_values") + image_grid_thw = inputs.get("image_grid_thw") + + model.eval() + with torch.no_grad(): + for step in range(max_new_tokens): + outputs = model( + input_ids=generated_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_attentions=True, + return_dict=True, + ) + + next_token_logits = outputs.logits[:, -1, :] + next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + eos_token_id = processor.tokenizer.eos_token_id + if isinstance(eos_token_id, list): + if next_token_id.item() in eos_token_id: + break + elif next_token_id.item() == eos_token_id: + break + + if outputs.attentions is not None and vision_start is not None and vision_end is not None: + layer_attentions = [] + for layer_attn in outputs.attentions: + attn_to_vision = layer_attn[0, :, -1, vision_start:vision_end] + layer_attentions.append(attn_to_vision.float().cpu()) + + stacked_attn = torch.stack(layer_attentions, dim=0) + avg_attn = stacked_attn.mean(dim=(0, 1)) + all_attentions.append(avg_attn) + + token_text = processor.tokenizer.decode([next_token_id.item()]) + generated_tokens.append(token_text) + + generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) + attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=attention_mask.device, dtype=attention_mask.dtype)], dim=-1) + + if (step + 1) % 10 == 0: + print(f"Generated {step + 1} tokens...") + + generated_text = processor.tokenizer.decode( + generated_ids[0, input_length:], + skip_special_tokens=True + ) + + num_vision_tokens = vision_end - vision_start if vision_start and vision_end else 0 + + return { + "generated_text": generated_text, + "generated_tokens": generated_tokens, + "attentions": all_attentions, + "num_vision_tokens": num_vision_tokens, + "image_grid_thw": image_grid_thw, + } + + +def compute_attention_heatmap(attentions, image_grid_thw, spatial_merge_size=2): + if not attentions: + return None + + stacked = torch.stack(attentions, dim=0) + avg_attention = stacked.mean(dim=0).numpy() + + if image_grid_thw is not None: + t, h, w = image_grid_thw[0].tolist() + merged_h = h // spatial_merge_size + merged_w = w // spatial_merge_size + expected_tokens = t * merged_h * merged_w + + print(f"Image grid (t, h, w): ({t}, {h}, {w})") + print(f"Merged grid (t, h, w): ({t}, {merged_h}, {merged_w})") + print(f"Expected vision tokens: {expected_tokens}, Actual: {len(avg_attention)}") + + if len(avg_attention) == expected_tokens: + attention_map = avg_attention.reshape(t, merged_h, merged_w) + if t > 1: + attention_map = attention_map.mean(axis=0) + else: + attention_map = attention_map[0] + return attention_map + + num_tokens = len(avg_attention) + side = int(np.sqrt(num_tokens)) + if side * side == num_tokens: + return avg_attention.reshape(side, side) + + for h in range(side, 0, -1): + if num_tokens % h == 0: + w = num_tokens // h + return avg_attention.reshape(h, w) + + return avg_attention.reshape(1, -1) + + +def visualize_attention(image_path, attention_map, generated_text, output_path="attention_heatmap.png"): + original_image = Image.open(image_path).convert("RGB") + img_width, img_height = original_image.size + + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + + axes[0].imshow(original_image) + axes[0].set_title("Original Image", fontsize=12) + axes[0].axis("off") + + im = axes[1].imshow(attention_map, cmap="hot", interpolation="nearest") + axes[1].set_title("Attention Heatmap (Raw)", fontsize=12) + axes[1].axis("off") + plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04) + + attention_resized = Image.fromarray( + (attention_map * 255 / attention_map.max()).astype(np.uint8) + ).resize((img_width, img_height), Image.BILINEAR) + attention_resized = np.array(attention_resized) / 255.0 + + axes[2].imshow(original_image) + axes[2].imshow(attention_resized, cmap="jet", alpha=0.5) + axes[2].set_title("Attention Overlay", fontsize=12) + axes[2].axis("off") + + display_text = generated_text[:200] + "..." if len(generated_text) > 200 else generated_text + fig.suptitle(f"Generated Caption:\n{display_text}", fontsize=10, wrap=True) + + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + print(f"Attention visualization saved to {output_path}") + + +def visualize_token_attention(image_path, attentions, generated_tokens, image_grid_thw, + output_path="token_attention.png", num_tokens_to_show=10, spatial_merge_size=2): + if not attentions or not generated_tokens: + print("No attention data to visualize") + return + + num_tokens = min(num_tokens_to_show, len(attentions)) + + cols = 5 + rows = (num_tokens + cols - 1) // cols + fig, axes = plt.subplots(rows, cols, figsize=(15, 3 * rows)) + axes = axes.flatten() if num_tokens > 1 else [axes] + + original_image = Image.open(image_path).convert("RGB") + + for i in range(num_tokens): + attn = attentions[i].numpy() + token = generated_tokens[i] + + if image_grid_thw is not None: + t, h, w = image_grid_thw[0].tolist() + merged_h = h // spatial_merge_size + merged_w = w // spatial_merge_size + if len(attn) == t * merged_h * merged_w: + attn_map = attn.reshape(t, merged_h, merged_w) + if t > 1: + attn_map = attn_map.mean(axis=0) + else: + attn_map = attn_map[0] + else: + side = int(np.sqrt(len(attn))) + attn_map = attn.reshape(side, -1) + else: + side = int(np.sqrt(len(attn))) + attn_map = attn.reshape(side, -1) + + axes[i].imshow(attn_map, cmap="hot") + axes[i].set_title(f"Token {i+1}: '{token}'", fontsize=8) + axes[i].axis("off") + + for i in range(num_tokens, len(axes)): + axes[i].axis("off") + + plt.suptitle("Per-Token Attention to Image", fontsize=12) + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + print(f"Token attention visualization saved to {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Qwen3-VL Image Caption with Attention Visualization") + parser.add_argument("--image", type=str, required=True, help="Path to input image") + parser.add_argument("--model", type=str, default="/workspace/Qwen3-VL-8B-Instruct", + help="Model path or name") + parser.add_argument("--prompt", type=str, default="Describe this image in detail.", + help="Prompt for image captioning") + parser.add_argument("--output", type=str, default="attention_heatmap.png", + help="Output path for attention heatmap") + parser.add_argument("--output-tokens", type=str, default="token_attention.png", + help="Output path for per-token attention") + + args = parser.parse_args() + + if not os.path.exists(args.image): + print(f"Error: Image not found at {args.image}") + return + + model, processor = load_model(args.model) + + print(f"\nProcessing image: {args.image}") + print(f"Prompt: {args.prompt}\n") + + result = generate_with_attention(model, processor, args.image, args.prompt) + + print(f"\nGenerated Caption:\n{result['generated_text']}\n") + print(f"Number of generated tokens: {len(result['generated_tokens'])}") + print(f"Number of vision tokens: {result['num_vision_tokens']}") + + attention_map = compute_attention_heatmap( + result["attentions"], + result["image_grid_thw"] + ) + + if attention_map is not None: + print(f"Attention map shape: {attention_map.shape}") + + visualize_attention( + args.image, + attention_map, + result["generated_text"], + args.output + ) + + visualize_token_attention( + args.image, + result["attentions"], + result["generated_tokens"], + result["image_grid_thw"], + args.output_tokens + ) + else: + print("Could not compute attention heatmap") + + +if __name__ == "__main__": + main() diff --git a/ICL/LV/code/core/__init__.py b/ICL/LV/code/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/core/eval/eval_order_caption_bertscore.py b/ICL/LV/code/core/eval/eval_order_caption_bertscore.py new file mode 100644 index 0000000000000000000000000000000000000000..435cfb56cdff7baff047bfac07e288714e8ec178 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_caption_bertscore.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +"""Captioning BERTScore-F1 per modal order.""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import bertscore_f1 + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-num-layers', type=int, default=-1) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + ap.add_argument('--strict-bertscore', action='store_true', help='Do not fallback silently; raise when BERTScore fails/unavailable') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'captioning_bertscore' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['captioning'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + summary = {} + for order in orders: + p, r = preds[order]['captioning'] + keep_idx = [] + for i, ri in enumerate(r): + first = (ri[0] if (isinstance(ri, (list, tuple)) and ri) else ri) + first = first if isinstance(first, str) else '' + if first.strip(): + keep_idx.append(i) + if keep_idx: + p = [p[i] for i in keep_idx] + r = [[(r[i][0] if r[i] else '')] for i in keep_idx] + else: + p, r = [], [] + + nl = None if args.bertscore_num_layers < 0 else int(args.bertscore_num_layers) + # Enable strict behavior via env for the metrics helper + import os + if args.strict_bertscore: + os.environ['BERTSCORE_STRICT'] = '1' + score = bertscore_f1( + p, + r, + model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + num_layers=nl, + lang=(args.bertscore_lang or None), + ) if p else None + if args.strict_bertscore and p and score is None: + raise RuntimeError('BERTScore returned None and --strict-bertscore forbids fallback') + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': 'captioning_bertscore_f1', 'score': (None if score is None else float(score))} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'[Caption BERTScore-F1] order={order} score=' + (f'{score:.2f}' if score is not None else 'NA')) + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[Caption BERTScore-F1] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/metrics/metrics.py b/ICL/LV/code/core/metrics/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d4dc3eac79c05e91de0bae69ba6d718d0e88d546 --- /dev/null +++ b/ICL/LV/code/core/metrics/metrics.py @@ -0,0 +1,210 @@ +import re +from typing import Iterable, List, Optional, Tuple + +_ARTICLES = {'a', 'an', 'the'} + + +def normalize_text(s: str) -> str: + s = s.lower().strip() + # Remove common ASCII punctuation and collapse spaces + s = re.sub(r"[\.,!?;:\-\(\)\[\]\{\}\'\"/\\]", " ", s) + s = re.sub(r"\s+", " ", s).strip() + # Remove articles + toks = [t for t in s.split() if t not in _ARTICLES] + return " ".join(toks) + + +def exact_match(pred: str, refs: List[str]) -> bool: + p = normalize_text(pred) + return any(p == normalize_text(r) for r in refs) + + +def _tokenize(s: str) -> List[str]: + return normalize_text(s).split() + + +def _f1_score(pred: str, refs: List[str]) -> float: + # SQuAD-style token F1: best over references + p_tokens = _tokenize(pred) + if not refs: + return 0.0 + best = 0.0 + for r in refs: + g_tokens = _tokenize(r) + if not p_tokens and not g_tokens: + best = max(best, 1.0) + continue + if not p_tokens or not g_tokens: + best = max(best, 0.0) + continue + # Count overlap + from collections import Counter + pc = Counter(p_tokens) + gc = Counter(g_tokens) + num_same = sum((pc & gc).values()) + if num_same == 0: + f1 = 0.0 + else: + precision = num_same / len(p_tokens) + recall = num_same / len(g_tokens) + f1 = 2 * precision * recall / (precision + recall) + best = max(best, f1) + return best + + +_CHOICES = [chr(ord('a')+i) for i in range(26)] + + +def parse_choice_letter(text: str) -> Optional[str]: + """Parse A-F option letter from free text (English/Chinese patterns). + Tries: (C)、(C)、answer is C、answer: C、Option C、Choice C、答案是C、选择C、故选C、开头字母、'对/错' 映射到 A/B。 + """ + t = (text or '').strip().lower() + # (c) or (c) fullwidth + m = re.search(r"[\((]([a-z])[\))]", t) + if m: + return m.group(1) + # answer is c / answer: c + m = re.search(r"\banswer\s*(?:is|:)\s*([a-z])\b", t) + if m: + return m.group(1) + # option c / choice c + m = re.search(r"\b(?:option|choice)\s*([a-z])\b", t) + if m: + return m.group(1) + # Chinese: 答案是C / 答案:C / 正确答案是C / 答案为C / 最终答案是C + m = re.search(r"(?:最终)?答案\s*(?:是|为|:|:)\s*([a-z])", t) + if m: + return m.group(1) + m = re.search(r"正确答案\s*(?:是|:|:)\s*([a-z])", t) + if m: + return m.group(1) + # Chinese: 选项C / 选择C / 选择的是C / 选择为C / C选项 + m = re.search(r"(?:选项|选择)(?:的?是|为)?\s*([a-z])", t) + if m: + return m.group(1) + m = re.search(r"([a-z])\s*选项", t) + if m: + return m.group(1) + # Chinese short: 故选C / 因此选C / 所以选C / 选C + m = re.search(r"(?:故|因此|所以)?\s*选[择项]?\s*([a-z])", t) + if m: + return m.group(1) + # leading letter + m = re.search(r"^\s*([a-z])\b", t) + if m: + return m.group(1) + # map boolean words to A/B + if re.search(r"\btrue\b", t) or ('对' in t) or ('正确' in t): + return 'a' + if re.search(r"\bfalse\b", t) or ('错' in t) or ('错误' in t): + return 'b' + return None + + +def accuracy(preds: List[str], refs: List[List[str]]) -> float: + n = len(preds) + correct = 0 + for p, r in zip(preds, refs): + if not r: + continue + if exact_match(p, r): + correct += 1 + return 100.0 * correct / max(1, n) + + +def token_f1(preds: List[str], refs: List[List[str]]) -> float: + if not preds: + return 0.0 + total = 0.0 + for p, r in zip(preds, refs): + total += _f1_score(p, r) + return 100.0 * total / len(preds) + + +def choice_accuracy(preds: List[str], refs: List[str]) -> float: + n = len(preds) + correct = 0 + for p, r in zip(preds, refs): + pl = parse_choice_letter(p) + if pl is None: + continue + if pl == r.lower().strip(): + correct += 1 + return 100.0 * correct / max(1, n) + + +# Optional: BERTScore if available +try: + from bert_score import score as bert_score +except Exception: # pragma: no cover + bert_score = None + + +def bertscore_f1( + preds: List[str], + refs: List[List[str]], + model_type: str = 'roberta-large', + rescale_with_baseline: bool = True, + batch_size: int = 32, + num_layers: Optional[int] = None, + lang: Optional[str] = None, +) -> Optional[float]: + """Compute BERTScore-F1. + + - model_type can be a HF model id or a local directory (e.g., '/z_data/pretrained/syxin/roberta-large'). + - rescale_with_baseline=False avoids baseline lookup if running fully offline with a non-standard model path. + """ + import os + strict = (os.environ.get('BERTSCORE_STRICT', '').strip() == '1') + if bert_score is None: + if strict: + raise ImportError('bert-score package not available in current Python environment') + return None + # Simplified: pick first reference per sample for batch scoring + ref_first = [(r[0] if r else "") for r in refs] + kwargs = dict( + model_type=model_type, + verbose=False, + rescale_with_baseline=rescale_with_baseline, + batch_size=batch_size, + ) + # Pass language only when baseline rescaling is requested + if rescale_with_baseline and lang: + kwargs['lang'] = lang + # If using a local directory as model_type, bert-score expects num_layers. + if num_layers is None: + try: + import os, json + if os.path.isdir(model_type): + cfg = os.path.join(model_type, 'config.json') + if os.path.exists(cfg): + with open(cfg, 'r', encoding='utf-8') as f: + cfgj = json.load(f) + nl = cfgj.get('num_hidden_layers') + if isinstance(nl, int): + num_layers = nl + except Exception: + pass + if num_layers is not None: + kwargs['num_layers'] = int(num_layers) + try: + P, R, F1 = bert_score( + preds, + ref_first, + **kwargs, + ) + except AssertionError as e: + # Common offline error: "Need to specify Language when rescaling with baseline". + # Auto-disable baseline rescaling and retry for robustness. + msg = str(e) + if 'rescaling with baseline' in msg or 'Language' in msg: + kwargs['rescale_with_baseline'] = False + P, R, F1 = bert_score( + preds, + ref_first, + **kwargs, + ) + else: + raise + return 100.0 * float(F1.mean()) diff --git a/ICL/LV/code/experiment.md b/ICL/LV/code/experiment.md new file mode 100644 index 0000000000000000000000000000000000000000..9dfb15382e7a55a59d96b644025aa5e5207332e9 --- /dev/null +++ b/ICL/LV/code/experiment.md @@ -0,0 +1,274 @@ +# 实验手册(Unified Multi‑Model) + +本手册用于在统一代码库 E:\xiaobin\code 下复现实验(零样本 / 随机 few‑shot / 文本/视觉/多模态检索 few‑shot / 模态顺序消融 + 单项指标)。代码已对接多模型适配器,输入统一为 OpenAI 扁平片段(image→text;示例用 [REQUEST]/[RESPONSE],查询的 [RESPONSE] 留空)。 + +## 目录结构与约定 +- core/ + - prompting/openai_segments.py 扁平序列构造与图片落盘(支持 Windows 路径、BASE64、data:URL) + - datasets/m3it_reader.py M3IT 统一读取与 base64 缓存(自动探测 test/val/validation 等) + - metrics/metrics.py Token‑F1、BERTScore‑F1、解析 A/B/C… 选项字母 + - eval/ + - zero_shot_vqa.py / random_k_shot_vqa.py + - eval_textual_retriever_vqa.py / eval_visual_retriever_vqa.py / eval_multimodal_retriever_vqa.py + - order_eval_core.py(内部共享逻辑,仅被下列脚本调用) + - _modal_order.py(内部,构造三种示例内部顺序) + - eval_order_caption_bertscore.py / eval_order_caption_cider.py + - eval_order_classification_accuracy.py / eval_order_classification_f1.py + - eval_order_reasoning_accuracy.py / eval_order_reasoning_ras.py + - eval_order_vqa_bertscore.py / eval_order_vqa_tokenf1.py +- adapters/ + - idefics2_adapter.py / qwen_vl_adapter.py / qwen3vl_adapter.py / gemma3_adapter.py + - _runners/idefics2_infer.py / qwen_vl_infer.py / qwen3_vl_infer.py / gemma3_infer.py + +约定: +- 适配器统一接口:`create(model_path)` 和 `generate_from_segments(segs, temperature, top_p, max_new_tokens)`。 +- 指令由两部分组成: + 1) 内置的任务引导语(paper instruction) + 2) 数据集 `data///instructions.json`(若存在) + 两者将合并,且可选 `--instruction-image` 参与示例前的片段序列。 +- 片段序列(示例 + 查询): + - {"type":"text","text":"instruction..."}(可选) + - {"type":"image_url","image_url":""}(可选,作为引导图) + - 重复 k 次: + - {"type":"image_url","image_url":""} + - {"type":"text","text":"[REQUEST]\nQ_i\n[RESPONSE]\nA_i"} + - 查询: + - {"type":"image_url","image_url":""} + - {"type":"text","text":"[REQUEST]\nQ?\n[RESPONSE]"} + +## 环境依赖 +- Python 3.8+(推荐 3.10+) +- 必选:`torch`、`transformers`、`Pillow` +- 可选:`bert-score`(BERTScore‑F1 与 RAS 备用)、`pycocoevalcap`(CIDEr) +- 可选:roscoe(若使用 `--ras-backend roscoe`) + +安装示例(联网): +``` +pip install torch transformers pillow +pip install bert-score +pip install git+https://github.com/salaniz/pycocoevalcap # 可选 +``` + +## 数据与模型路径示例 +- 数据根:`E:\datasets\M3IT`(目录下应有 `data/`) +- 生成模型:`E:\models\idefics2-8b`、`E:\models\Qwen-VL`、`E:\models\Qwen3-VL-8B-Instruct`、`E:\models\gemma-3-4b-it` 等 +- 检索模型: + - 文本:`E:\models\roberta-large` + - 视觉:`E:\models\CLIP-ViT-L-14` + - 多模态:`E:\models\bridgetower-large-itm-mlm-itc` + +## 快速开始 +以下命令均在目录 `E:\xiaobin\code` 下执行,按需替换路径;`--adapter` 选择 `idefics2 | qwen-vl | qwen3-vl | gemma3`。 + +### 1) 零样本 VQA +``` +python3 -m core.eval.zero_shot_vqa \ + --adapter idefics2 \ + --model-path E:\models\idefics2-8b \ + --dataset-root E:\datasets\M3IT \ + --split test --total-samples 500 \ + --use-paper-instruction \ + --instruction-image "C:\\Users\\you\\instruction.png" \ + --temperature 0.2 --top-p 1.0 --max-new-tokens 32 \ + --dump-first 2 +``` + +### 2) 随机 few‑shot VQA(k 可调) +``` +python3 -m core.eval.random_k_shot_vqa \ + --adapter qwen-vl \ + --model-path E:\models\Qwen-VL \ + --dataset-root E:\datasets\M3IT \ + --split val --k-shots 3 --total-samples 500 \ + --use-paper-instruction \ + --temperature 0.6 --top-p 1.0 --max-new-tokens 32 \ + --dump-first 2 +``` + +### 3) 检索器 few‑shot VQA +- 文本检索(RoBERTa): +``` +python3 -m core.eval.eval_textual_retriever_vqa \ + --adapter idefics2 \ + --model-path E:\models\idefics2-8b \ + --dataset-root E:\datasets\M3IT \ + --retriever-model-path E:\models\roberta-large \ + --k-shots 3 --total-samples 500 --split val \ + --use-paper-instruction --temperature 0.6 --top-p 1.0 --max-new-tokens 32 +``` +- 视觉检索(CLIP): +``` +python3 -m core.eval.eval_visual_retriever_vqa \ + --adapter qwen-vl \ + --model-path E:\models\Qwen-VL \ + --dataset-root E:\datasets\M3IT \ + --retriever-model-path E:\models\CLIP-ViT-L-14 \ + --k-shots 3 --total-samples 500 --split val \ + --no-instruction --temperature 0.6 --top-p 1.0 --max-new-tokens 32 +``` +- 多模态检索(BridgeTower): +``` +python3 -m core.eval.eval_multimodal_retriever_vqa \ + --adapter gemma3 \ + --model-path E:\models\gemma-3-4b-it \ + --dataset-root E:\datasets\M3IT \ + --retriever-model-path E:\models\bridgetower-large-itm-mlm-itc \ + --k-shots 3 --total-samples 500 --split val \ + --no-instruction --temperature 0.6 --top-p 1.0 --max-new-tokens 32 +``` + +输出文件: +- runs/m3it_textual_retriever_vqa/vqa_textual_retriever_3shot.json +- runs/m3it_visual_retriever_vqa/vqa_visual_retriever_3shot.json +- runs/m3it_multimodal_retriever_vqa/vqa_multimodal_retriever_3shot.json + +### 4) 模态顺序消融(示例内部顺序) +顺序含义: +- image-text:每条示例先图片,再统一文本块 `[REQUEST]` 问题 + `[RESPONSE]` 答案 +- text-image:每条示例先统一文本块,再图片 +- text-image-text:每条示例 `[REQUEST]` → 图片 → `[RESPONSE]` + +所有脚本均固定查询为“图片→[REQUEST] 问题→[RESPONSE](空占位)”。 + +- VQA Token‑F1: +``` +python3 -m core.eval.eval_order_vqa_tokenf1 \ + --adapter idefics2 \ + --model-path E:\models\idefics2-8b \ + --dataset-root E:\datasets\M3IT \ + --retriever-model-path E:\models\bridgetower-large-itm-mlm-itc \ + --orders image-text,text-image,text-image-text \ + --k-shots 3 --total-samples 500 --split val \ + --temperature 0.6 --top-p 1.0 --max-new-tokens 32 +``` +- VQA BERTScore‑F1: +``` +python3 -m core.eval.eval_order_vqa_bertscore \ + --adapter qwen3-vl \ + --model-path E:\models\Qwen3-VL-8B-Instruct \ + --dataset-root E:\datasets\M3IT \ + --retriever-model-path E:\models\bridgetower-large-itm-mlm-itc \ + --orders image-text,text-image,text-image-text \ + --k-shots 3 --total-samples 500 --split val \ + --bertscore-model roberta-large --bertscore-batch-size 32 --bertscore-num-layers -1 +``` +- Captioning(BERTScore/CIDEr): +``` +python3 -m core.eval.eval_order_caption_bertscore --adapter idefics2 --model-path ... --dataset-root ... --retriever-model-path ... --orders image-text,text-image,text-image-text --k-shots 3 --split val --max-new-tokens 128 +python3 -m core.eval.eval_order_caption_cider --adapter idefics2 --model-path ... --dataset-root ... --retriever-model-path ... --orders image-text,text-image,text-image-text --k-shots 3 --split val --max-new-tokens 128 +``` +- Classification(Accuracy / Macro‑F1): +``` +python3 -m core.eval.eval_order_classification_accuracy --adapter qwen-vl --model-path ... --dataset-root ... --retriever-model-path ... --orders image-text,text-image,text-image-text --k-shots 3 --split val --max-new-tokens 128 +python3 -m core.eval.eval_order_classification_f1 --adapter qwen-vl --model-path ... --dataset-root ... --retriever-model-path ... --orders image-text,text-image,text-image-text --k-shots 3 --split val --max-new-tokens 128 +``` +- Reasoning(Accuracy / RAS): +``` +python3 -m core.eval.eval_order_reasoning_accuracy --adapter gemma3 --model-path ... --dataset-root ... --retriever-model-path ... --orders image-text,text-image,text-image-text --k-shots 3 --split val --max-new-tokens 128 +python3 -m core.eval.eval_order_reasoning_ras --adapter gemma3 --model-path ... --dataset-root ... --retriever-model-path ... --orders image-text,text-image,text-image-text --k-shots 3 --split val --max-new-tokens 128 --ras-backend auto +``` + +结果输出与缓存: +- 指标结果:runs/order_metrics///result.json;汇总:runs/order_metrics//summary.json +- 预测缓存:runs/order_metrics//_cache/__.jsonl,含每条样本的 `pred/answers/ref_text/meta{task,image_path,text,inputs,gold_choice,order}` + +## 常见问题与说明 +- 任务自动检测:加 `--auto-detect` 时,脚本仅对本地存在的子任务运行。 +- 图像/BASE64:`core/prompting/openai_segments.py` 会把 Windows 路径转 WSL、BASE64/data:URL 落到运行目录下 `_image_cache/`,推理使用该缓存文件。 +- 示例选择去重:避开与查询同 `id/img_id`、同 base64 指纹、同像素签名(合理阈值)、同文本;不足 k 时随机回填。 +- 生成安全:适配器统一 `generate_from_segments`;不做静默 text-only 回退,避免丢失示例图片造成偏差。 +- BERTScore:离线模型目录作为 `--bertscore-model` 时,若未能自动推断层数,可用 `--bertscore-num-layers` 指定(如 roberta-large=24)。 +- ROSCOE:`--ras-backend roscoe` 时需本地可导入的模块与函数(可用 `--roscoe-module/--roscoe-func/--roscoe-model-path` 指定),否则自动降级到 bert-score/token‑F1。 +- 随机性:`--seed` 控制检索与抽样。 + +## 片段序列(核对) +- instruction(可选文本+可选图片) +- k 条示例:每条按所选顺序(image-text / text-image / text-image-text)组织;文本块内采用 `[REQUEST]` 问题 与 `[RESPONSE]` 答案 +- 查询:固定为“图片 → `[REQUEST]` 问题 → 空 `[RESPONSE]` 占位” + +如需新增任务或适配新的基座模型,仅需: +1) 在 adapters/ 新增 {xxx}_adapter.py,实现 `create()` 与 `generate_from_segments()`; +2) 使用本手册命令,替换 `--adapter xxx` 与 `--model-path` 即可。 + +## 批量评测脚本总览(.sh) +以下脚本用于批量评测 8 个指标,默认仅评测单一模态顺序 `image-text`(可用 `--orders` 覆盖为 `image-text,text-image,text-image-text`)。所有脚本都需要传 `--model-path`,其余参数有默认值。 + +- 公共特性 + - 指标:vqa_tokenf1, vqa_bertscore, captioning_bertscore, captioning_cider, classification_accuracy, classification_f1, reasoning_accuracy, reasoning_ras + - 输出:`runs/shot_sweep_allmetrics_{adapter}/shot{K}/{metric}/` 或 `runs/order_{adapter}/`(multi_gpu 系列) + - 日志:`runs/logs/*.log` + - 0‑shot 与 1..7‑shot 完全一致,传 `--k-shots 0` 即可(无单独特殊脚本) + +### A. Shot Sweep(k=0..7) +- `run_all_metrics_shot_sweep_0_7.sh` + - 用途:基础 0..7 Shot Sweep 跑法;默认 GPU i 跑 k=i,可通过 `GPU_K_MAP="0:1,1:0,..."` 改映射 + - GPU:0..7 并行,每块 GPU 依次串行 8 指标 + - 结果:末尾打印每个 k 的 8 指标与均值汇总 + +- `run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh` + - 用途:在 GPU0–3 上两波跑完 0..7(第一波 k=0..3,第二波 k=4..7) + - GPU:0..3;每波 4 个 k 并行 + +- `run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh` + - 用途:在 GPU4–7 上两波跑完 0..7(第一波 k=0..3,第二波 k=4..7) + - GPU:4..7;每波 4 个 k 并行 + +### B. 型号 Wrapper(自动设定 `--adapter/--model-path/--output-base`) +- Qwen‑VL: + - `run_all_metrics_shot_sweep_0_7_qwenvl.sh`(GPU0–3 两波,转发到 A-0..3 脚本) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwenvl.sh`(同上,显式命名) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwenvl.sh`(GPU4–7 两波) + +- Gemma3: + - `run_all_metrics_shot_sweep_0_7_gemma3.sh`(GPU0–3 两波) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_gemma3.sh`(同上,显式命名) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3.sh`(GPU4–7 两波) + +- Qwen3‑VL: + - `run_all_metrics_shot_sweep_0_7_qwen3vl.sh`(基础 0..7 脚本 + 轻微 GPU 映射,默认 0↔2 置换) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwen3vl.sh`(GPU0–3 两波) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwen3vl.sh`(GPU4–7 两波) + +- Idefics2: + - `run_all_metrics_shot_sweep_0_7_idefics2.sh`(基础 0..7 脚本 + 轻微 GPU 映射,默认 0↔1 置换) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_idefics2.sh`(GPU0–3 两波) + - `run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_idefics2.sh`(GPU4–7 两波) + +用法示例(Qwen‑VL / GPU0–3): +``` +bash run_all_metrics_shot_sweep_0_7_qwenvl.sh \ + --model-path E:\\models\\Qwen-VL \ + --dataset-root E:\\datasets\\M3IT \ + --total-samples 4000 --split val --max-new-tokens 128 +``` + +### C. Multi‑GPU 单一 k‑shots(8 指标并行) +以下脚本在 8 块 GPU 上同时跑 8 个指标(一个 `K_SHOTS`,默认 3),默认 `--orders image-text,text-image,text-image-text`,并在结束后写 `overall.json` 及汇总。 + +- `run_all_multi_gpu_0_7_qwenvl.sh` + - 适配器:qwen‑vl;输出:`runs/order_qwen-vl/` + - 可调:`K_SHOTS`、`TOTAL_SAMPLES`、以及各任务配额 `TOT_VQA/TOT_CAP/TOT_CLS/TOT_RS` + +- `run_all_multi_gpu_0_7_qwen3vl.sh` + - 适配器:qwen3‑vl;输出:`runs/order_qwen3-vl/` + +- `run_all_multi_gpu_0_7_idefics2.sh` + - 适配器:idefics2;输出:`runs/order_idefics2/` + +- `run_all_multi_gpu_0_7_gemma3_4b.sh` + - 适配器:gemma3(4B);输出:`runs/order_gemma3_4b/` + +- `run_all_multi_gpu_0_7_gemma3_12b.sh` + - 适配器:gemma3(12B);输出:`runs/order_gemma3_12b/` + +用法示例(Qwen3‑VL / 单一 k=3): +``` +bash run_all_multi_gpu_0_7_qwen3vl.sh \ + --model-path E:\\models\\Qwen3-VL-8B-Instruct \ + --dataset-root E:\\datasets\\M3IT \ + --k-shots 3 --orders image-text,text-image,text-image-text \ + --total-samples 2000 --max-new-tokens 128 +``` + +提示:以上所有脚本默认只跑 `image-text`(multi_gpu 系列例外,默认三顺序)。如需改为多顺序,请显式传 `--orders`;如需更细粒度的随机/检索 few‑shot 与单项指标,请直接使用 `core/eval/*.py` 脚本。 + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7.sh new file mode 100644 index 0000000000000000000000000000000000000000..4334af6bc9db7107c53acff789eb52024ae3adf8 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7.sh @@ -0,0 +1,186 @@ +#!/usr/bin/env bash +# Run 8 metrics (modal-order evaluation) for k-shots = 0..7 in parallel on GPUs 0..7. +# Metrics: vqa_tokenf1, vqa_bertscore, captioning_bertscore, captioning_cider, +# classification_accuracy, classification_f1, reasoning_accuracy, reasoning_ras. +# - Strict ROSCOE backend (no fallback) for reasoning_ras. +# - total-samples defaults to 4000. +# - Prints a final summary per k-shot with all 8 scores converted to 0..100 and their average. + +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +# Always use the current Python interpreter (preserves active conda env) +PYTHON_BIN="$(command -v python)" + +# Defaults (override by CLI or env) +ADAPTER="${ADAPTER:-qwen3-vl}" +MODEL_PATH="${MODEL_PATH:-}" # must be provided or via --model-path +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +# Default to a single fixed modal order to match the k-shot evaluator +# in QWEN3VL-code (eval_random_k_shot_vqa.py): image first, then text. +# You can still override with --orders if you need to sweep multiple orders. +ORDERS="${ORDERS:-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-4000}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-en}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_${ADAPTER}}" +REUSE="${REUSE:---reuse-cache}" + +# ROSCOE config (auto-detect repo-local path; strict by default) +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi +RAS_STRICT="${RAS_STRICT:-1}"; STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +# CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --adapter) ADAPTER="$2"; shift 2;; + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" runs/logs + +echo "[INFO] Adapter=$ADAPTER | Model=$MODEL_PATH | Dataset=$DATASET_ROOT | Orders=$ORDERS | total-samples=$TOTAL_SAMPLES | split=$SPLIT" + +build_cmds_for_k() { + local kshot="$1"; local outdir="$2"; + echo "[build] k-shot=$kshot -> $outdir" >&2 + cat <&2 echo "[GPU$gpu] k-shot=$kshot -> $outdir" + { + # Avoid process substitution to keep jobs as children of this shell + while IFS= read -r cmd; do + echo "[GPU$gpu][k=$kshot] $cmd" + CUDA_VISIBLE_DEVICES="$gpu" eval "$cmd" + done <<< "$(build_cmds_for_k "$kshot" "$outdir")" + } > "runs/logs/shot_allmetrics_g${gpu}.log" 2>&1 +} + +# Launch jobs +pids=() +if [[ -n "${GPU_K_MAP:-}" ]]; then + # Expected format: "0:1,1:0,2:2,3:3,4:4,5:5,6:6,7:7" (gpu:k pairs, comma-separated) + IFS=',' read -ra __pairs <<< "$GPU_K_MAP" + for kv in "${__pairs[@]}"; do + gpu="${kv%%:*}"; kshot_pair="${kv##*:}" + if [[ -n "$gpu" && -n "$kshot_pair" ]]; then + run_one_gpu "$gpu" "$kshot_pair" "$OUTPUT_BASE/shot${kshot_pair}" & + pids+=("$!") + sleep 1 + fi + done +else + # Default: GPU i runs k-shot = i (each job runs all 8 metrics sequentially) + for i in 0 1 2 3 4 5 6 7; do + run_one_gpu "$i" "$i" "$OUTPUT_BASE/shot${i}" & + pids+=("$!") + sleep 1 + done +fi + +fail=0; idx=0 +for pid in "${pids[@]}"; do + if ! wait "$pid"; then echo "[JOB $idx] failed (pid=$pid)" >&2; fail=$((fail+1)); else echo "[JOB $idx] done (pid=$pid)"; fi + idx=$((idx+1)) +done +if [[ $fail -gt 0 ]]; then echo "[WARN] Some jobs failed: $fail" >&2; fi + +# Summarize: read each metric's summary.json, scale to 0..100 if needed, then average per shot +echo "\n[SUMMARY] All 8 metrics per k-shot (0..7), 0..100 scale with AVG" +python - "$OUTPUT_BASE" <<'PY' +import json, os, sys +base = sys.argv[1] + +metrics = [ + ("vqa_tokenf1", 1.0), + ("vqa_bertscore", 1.0), + ("captioning_bertscore", 1.0), + # captioning_cider is already on a 0..100 scale in eval_order_caption_cider.py + # (pycocoevalcap returns ~0..10; we multiply by 100 there). Avoid double scaling here. + ("captioning_cider", 1.0), + ("classification_accuracy", 1.0), + ("classification_f1", 1.0), + ("reasoning_accuracy", 1.0), + # reasoning_ras is scaled to percentage in eval_order_reasoning_ras.py when needed (<=1.05) + # so we should not multiply by 100 again here. + ("reasoning_ras", 1.0), +] + +def readj(p): + try: + with open(p,'r',encoding='utf-8') as f: + return json.load(f) + except Exception: + return None + +rows = [] +for k in range(8): + shot_dir = os.path.join(base, f'shot{k}') + row = {"k": k, "metrics": {}, "avg": None} + vals = [] + for m, scale in metrics: + summ = readj(os.path.join(shot_dir, m, 'summary.json')) + if isinstance(summ, dict) and summ: + # average over orders + scores = [v for v in summ.values() if isinstance(v, (int, float))] + mv = (sum(scores)/len(scores)) if scores else None + else: + mv = None + if mv is not None: + mv = float(mv) * float(scale) + vals.append(mv) + row["metrics"][m] = mv + row["avg"] = (sum(vals)/len(vals)) if vals else None + rows.append(row) + +for r in rows: + k = r["k"] + parts = [] + for m, _ in [(m,s) for m,s in metrics]: + v = r["metrics"].get(m) + parts.append(f"{m}={('NA' if v is None else f'{v:.2f}')}") + avg = 'NA' if r['avg'] is None else f"{r['avg']:.2f}" + print(f" k={k} " + ", ".join(parts) + f", AVG={avg}") +PY + +echo "\n[INFO] Outputs under $OUTPUT_BASE/shot{0..7}/{metric}/" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_gemma3.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_gemma3.sh new file mode 100644 index 0000000000000000000000000000000000000000..445f73b266ceef41ac3dbe81b56701d0f53a471c --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_gemma3.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="gemma3" +# Default to 4B; override MODEL_PATH to target 12B or others as needed +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/gemma-3-4b-it}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_gemma3}" + +# Run on GPUs 0..3 with two waves: shots 0..3 then 4..7 +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_idefics2.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_idefics2.sh new file mode 100644 index 0000000000000000000000000000000000000000..ce040022fd5ee79fd6b6ecaa285aefa6a2fd6dce --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_idefics2.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="idefics2" +# Use BF16 by default to reduce VRAM; override with IDEFICS2_DTYPE=fp16/fp32 if needed. +export IDEFICS2_DTYPE="${IDEFICS2_DTYPE:-bf16}" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/idefics2-8b}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_idefics2}" + +# Swap GPU0 and GPU1: GPU0->k1, GPU1->k0 (others identity) +export GPU_K_MAP="0:1,1:0,2:2,3:3,4:4,5:5,6:6,7:7" +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_qwenvl.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_qwenvl.sh new file mode 100644 index 0000000000000000000000000000000000000000..1e440c6555a9ccb9e692711284a291e38ee661cf --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_qwenvl.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="qwen-vl" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/Qwen-VL}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_qwen-vl}" + +# Run on GPUs 0..3 with two waves: shots 0..3 then 4..7 +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh new file mode 100644 index 0000000000000000000000000000000000000000..3c766b749a64b7752facd9901f43a657973306b5 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh @@ -0,0 +1,185 @@ +#!/usr/bin/env bash +# Run 8 metrics (modal-order evaluation) for k-shots = 0..7 using GPUs 0..3. +# - Wave 1: shots 0..3 on GPUs 0,1,2,3 respectively (in parallel) +# - Wave 2: shots 4..7 on GPUs 0,1,2,3 respectively (after wave 1 completes) +# Metrics: vqa_tokenf1, vqa_bertscore, captioning_bertscore, captioning_cider, +# classification_accuracy, classification_f1, reasoning_accuracy, reasoning_ras. +# - Strict ROSCOE backend (no fallback) for reasoning_ras. +# - total-samples defaults to 4000. +# - Prints a final summary per k-shot with all 8 scores converted to 0..100 and their average. + +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +# Always use the current Python interpreter (preserves active conda env) +PYTHON_BIN="$(command -v python)" + +# Defaults (override by CLI or env) +ADAPTER="${ADAPTER:-qwen3-vl}" +MODEL_PATH="${MODEL_PATH:-}" # must be provided or via --model-path +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +# Default to a single fixed modal order to align with eval_random_k_shot_vqa.py +# (image first, then text). Override via --orders to sweep multiple orders if needed. +ORDERS="${ORDERS:-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-4000}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-en}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_${ADAPTER}}" +REUSE="${REUSE:---reuse-cache}" + +# ROSCOE config (auto-detect repo-local path; strict by default) +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi +RAS_STRICT="${RAS_STRICT:-1}"; STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +# CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --adapter) ADAPTER="$2"; shift 2;; + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" runs/logs + +echo "[INFO] Adapter=$ADAPTER | Model=$MODEL_PATH | Dataset=$DATASET_ROOT | Orders=$ORDERS | total-samples=$TOTAL_SAMPLES | split=$SPLIT" + +# Build the per-shot command list (identical to the 0..7 base runner) +build_cmds_for_k() { + local kshot="$1"; local outdir="$2"; + echo "[build] k-shot=$kshot -> $outdir" >&2 + cat <&2 echo "[GPU$gpu] k-shot=$kshot -> $outdir" + { + # Each command runs sequentially on the assigned GPU + while IFS= read -r cmd; do + echo "[GPU$gpu][k=$kshot] $cmd" + CUDA_VISIBLE_DEVICES="$gpu" eval "$cmd" + done <<< "$(build_cmds_for_k "$kshot" "$outdir")" + } > "runs/logs/shot_allmetrics_g${gpu}_k${kshot}.log" 2>&1 +} + +# Wave launcher for a list of (gpu,k) pairs +launch_wave() { + local -a pairs=("$@") + local pids=() + local desc="" + for pair in "${pairs[@]}"; do + local gpu="${pair%%:*}" + local k="${pair##*:}" + desc+=" (GPU${gpu}->k${k})" + run_one_gpu "$gpu" "$k" "$OUTPUT_BASE/shot${k}" & + pids+=("$!") + sleep 1 + done + echo "[INFO] Launched wave:${desc}" + local fail=0; local idx=0 + for pid in "${pids[@]}"; do + if ! wait "$pid"; then echo "[WAVE] job $idx failed (pid=$pid)" >&2; fail=$((fail+1)); else echo "[WAVE] job $idx done (pid=$pid)"; fi + idx=$((idx+1)) + done + if [[ $fail -gt 0 ]]; then echo "[WARN] Some jobs failed in this wave: $fail" >&2; fi +} + +# Two waves on GPUs 0..3 +launch_wave 0:0 1:1 2:2 3:3 +launch_wave 0:4 1:5 2:6 3:7 + +# Summarize: read each metric's summary.json, scale to 0..100 if needed, then average per shot +echo "\n[SUMMARY] All 8 metrics per k-shot (0..7), 0..100 scale with AVG" +python - "$OUTPUT_BASE" <<'PY' +import json, os, sys + +base = sys.argv[1] + +metrics = [ + ("vqa_tokenf1", 1.0), + ("vqa_bertscore", 1.0), + ("captioning_bertscore", 1.0), + # Already scaled to 0..100 in eval_order_caption_cider.py + ("captioning_cider", 1.0), + ("classification_accuracy", 1.0), + ("classification_f1", 1.0), + ("reasoning_accuracy", 1.0), + # Already scaled to 0..100 in eval_order_reasoning_ras.py when needed + ("reasoning_ras", 1.0), +] + +def readj(p): + try: + with open(p,'r',encoding='utf-8') as f: + return json.load(f) + except Exception: + return None + +rows = [] +for k in range(8): + shot_dir = os.path.join(base, f'shot{k}') + row = {"k": k, "metrics": {}, "avg": None} + vals = [] + for m, scale in metrics: + summ = readj(os.path.join(shot_dir, m, 'summary.json')) + if isinstance(summ, dict) and summ: + # average over orders when multiple orders are evaluated + scores = [v for v in summ.values() if isinstance(v, (int, float))] + mv = (sum(scores)/len(scores)) if scores else None + else: + mv = None + if mv is not None: + mv = float(mv) * float(scale) + vals.append(mv) + row["metrics"][m] = mv + row["avg"] = (sum(vals)/len(vals)) if vals else None + rows.append(row) + +for r in rows: + k = r["k"] + parts = [] + for m, _ in [(m,s) for m,s in metrics]: + v = r["metrics"].get(m) + parts.append(f"{m}={('NA' if v is None else f'{v:.2f}')}") + avg = 'NA' if r['avg'] is None else f"{r['avg']:.2f}" + print(f" k={k} " + ", ".join(parts) + f", AVG={avg}") +PY + +echo "\n[INFO] Outputs under $OUTPUT_BASE/shot{0..7}/{metric}/" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_gemma3.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_gemma3.sh new file mode 100644 index 0000000000000000000000000000000000000000..3028d3ad4da2938289c722e0fb7fa9f71f551834 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_gemma3.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="gemma3" +# Default to 4B; override MODEL_PATH to target 12B or others as needed +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/gemma-3-4b-it}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_gemma3}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_idefics2.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_idefics2.sh new file mode 100644 index 0000000000000000000000000000000000000000..66e8e9f7f0aa49e7329f9616fdaf62fb7c8656e4 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_idefics2.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="idefics2" +# Use BF16 by default to reduce VRAM; override with IDEFICS2_DTYPE=fp16/fp32 if needed. +export IDEFICS2_DTYPE="${IDEFICS2_DTYPE:-bf16}" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/idefics2-8b}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_idefics2}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwen3vl.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwen3vl.sh new file mode 100644 index 0000000000000000000000000000000000000000..4f91ea72db3e529fb95c66306cefcf4877d87666 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwen3vl.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="qwen3-vl" +MODEL_PATH="${MODEL_PATH:-/workspace/Qwen3-VL-8B-Instruct}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_qwen3-vl}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwenvl.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwenvl.sh new file mode 100644 index 0000000000000000000000000000000000000000..aeb0cfb25915054526615d88db30ed41b8eba701 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3_qwenvl.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="qwen-vl" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/Qwen-VL}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_qwen-vl}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu0_3.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3.sh new file mode 100644 index 0000000000000000000000000000000000000000..4d915ee34f2e2565a2942abca5d7a6abd282be17 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="gemma3" +# Default to 4B; override MODEL_PATH to target 12B or others as needed +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/gemma-3-4b-it}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_gemma3}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3_12b.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3_12b.sh new file mode 100644 index 0000000000000000000000000000000000000000..472358d69dc4e2802d11a1886e3b67d1ddd18ede --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_gemma3_12b.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="gemma3" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/gemma-3-12b-it}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_gemma3_12b}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_idefics2.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_idefics2.sh new file mode 100644 index 0000000000000000000000000000000000000000..06623b9d94952926cf65f9d4a67b5687dd8716fc --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_idefics2.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="idefics2" +# Use BF16 by default to reduce VRAM; override with IDEFICS2_DTYPE=fp16/fp32 if needed. +export IDEFICS2_DTYPE="${IDEFICS2_DTYPE:-bf16}" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/idefics2-8b}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_idefics2}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwen3vl.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwen3vl.sh new file mode 100644 index 0000000000000000000000000000000000000000..e57f8ca72fcb4e334566ba7eea4d2eed9909acc4 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwen3vl.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="qwen3-vl" +MODEL_PATH="${MODEL_PATH:-/workspace/Qwen3-VL-8B-Instruct}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_qwen3-vl}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwenvl.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwenvl.sh new file mode 100644 index 0000000000000000000000000000000000000000..b4ade717cb18d05b85cd342de656b7e578a58709 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7_qwenvl.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="qwen-vl" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/Qwen-VL}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_qwen-vl}" + +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" + diff --git a/ICL/LV/code/run_all_metrics_single_gpu_k3_n2000_3orders.sh b/ICL/LV/code/run_all_metrics_single_gpu_k3_n2000_3orders.sh new file mode 100644 index 0000000000000000000000000000000000000000..68ec3ebac1c93c6db15eb3fe3883940dc40dcbd4 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_single_gpu_k3_n2000_3orders.sh @@ -0,0 +1,172 @@ +#!/usr/bin/env bash +# Single-GPU modal-order evaluation (8 metrics) for a single model. +# Fixed setting: +# - k-shots = 3 +# - total-samples = 2000 (per metric script; distributed across tasks internally) +# - orders = image-text,text-image,text-image-text +# +# Usage: +# bash run_all_metrics_single_gpu_k3_n2000_3orders.sh --adapter idefics2 --model-path /path/to/model --gpu 4 +# +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +# Always use the current Python interpreter (preserves active conda env) +PYTHON_BIN="$(command -v python)" + +# Required +ADAPTER="" +MODEL_PATH="" + +# Fixed experiment setting +K_SHOTS=3 +TOTAL_SAMPLES=2000 +ORDERS="image-text,text-image,text-image-text" + +# Defaults (override via CLI flags where allowed) +GPU="${GPU:-}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/xlm-roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +RAS_STRICT="${RAS_STRICT:-1}" +OUTPUT_BASE="${OUTPUT_BASE:-}" +LOG_DIR="${LOG_DIR:-runs/logs}" +REUSE="${REUSE:---reuse-cache}" + +# ROSCOE config (auto-detect repo-local path; strict by default) +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi +fi +ROSCOE_PATH_ARG="" +if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then + ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\"" +fi +STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +usage() { + cat <<'EOF' +Usage: + bash run_all_metrics_single_gpu_k3_n2000_3orders.sh --adapter --model-path [--gpu N] [--output-base DIR] + +Fixed setting (hard-coded): + k-shots=3 + total-samples=2000 + orders=image-text,text-image,text-image-text +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --adapter) ADAPTER="$2"; shift 2;; + --model-path) MODEL_PATH="$2"; shift 2;; + --gpu) GPU="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + -h|--help) usage; exit 0;; + --orders|--k-shots|--total-samples) + echo "[ERR] $1 is fixed in this script (k=3, total=2000, orders=3). Edit the script if you really need overrides." >&2 + exit 2 + ;; + *) echo "Unknown arg: $1" >&2; usage; exit 2;; + esac +done + +if [[ -z "$ADAPTER" ]]; then echo "[ERR] Missing --adapter" >&2; usage; exit 2; fi +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] Missing --model-path" >&2; usage; exit 2; fi + +if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then + if [[ -n "${GPU:-}" ]]; then + export CUDA_VISIBLE_DEVICES="$GPU" + fi +fi + +if [[ -z "$OUTPUT_BASE" ]]; then + gpu_tag="${CUDA_VISIBLE_DEVICES:-na}" + OUTPUT_BASE="runs/order_${ADAPTER}_k3_n2000_3orders_gpu${gpu_tag}" +fi + +mkdir -p "$OUTPUT_BASE" "$LOG_DIR" runs/summaries + +echo "[INFO] Adapter=$ADAPTER | Model=$MODEL_PATH | Dataset=$DATASET_ROOT | Orders=$ORDERS | k=$K_SHOTS | total=$TOTAL_SAMPLES | split=$SPLIT | GPU=${CUDA_VISIBLE_DEVICES:-unset}" + +run() { + local name="$1"; shift + local log="$1"; shift + echo "[INFO] $(date +'%F %T') GPU=${CUDA_VISIBLE_DEVICES:-unset} $name" + eval "$*" >"$log" 2>&1 +} + +run "VQA Token-F1" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_vqa_tokenf1.log" \ + "$PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" + +run "VQA BERTScore" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_vqa_bertscore.log" \ + "$PYTHON_BIN -m core.eval.eval_order_vqa_bertscore --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --strict-bertscore --output-dir '$OUTPUT_BASE' $REUSE" + +run "Caption BERTScore" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_caption_bertscore.log" \ + "$PYTHON_BIN -m core.eval.eval_order_caption_bertscore --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --strict-bertscore --output-dir '$OUTPUT_BASE' $REUSE" + +run "Caption CIDEr" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_caption_cider.log" \ + "$PYTHON_BIN -m core.eval.eval_order_caption_cider --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" + +run "Classification Accuracy" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_cls_accuracy.log" \ + "$PYTHON_BIN -m core.eval.eval_order_classification_accuracy --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" + +run "Classification F1" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_cls_f1.log" \ + "$PYTHON_BIN -m core.eval.eval_order_classification_f1 --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" + +run "Reasoning Accuracy" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_reasoning_accuracy.log" \ + "$PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" + +run "Reasoning RAS" "$LOG_DIR/${ADAPTER}_gpu${CUDA_VISIBLE_DEVICES:-x}_reasoning_ras.log" \ + "$PYTHON_BIN -m core.eval.eval_order_reasoning_ras --adapter '$ADAPTER' --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOTAL_SAMPLES --split '$SPLIT' --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --ras-backend $RAS_BACKEND $STRICT_FLAG $ROSCOE_PATH_ARG --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --no-bertscore-fallback --output-dir '$OUTPUT_BASE' $REUSE" + +echo "\n[INFO] All metrics finished. Summaries:" +show() { local name="$1"; local path="$2"; echo "\n== $name =="; if [ -f "$path" ]; then cat "$path"; else echo "(summary not found: $path)"; fi; } +show "VQA Token-F1" "$OUTPUT_BASE/vqa_tokenf1/summary.json" +show "VQA BERTScore" "$OUTPUT_BASE/vqa_bertscore/summary.json" +show "Caption BERTScore" "$OUTPUT_BASE/captioning_bertscore/summary.json" +show "Caption CIDEr" "$OUTPUT_BASE/captioning_cider/summary.json" +show "Classification Acc" "$OUTPUT_BASE/classification_accuracy/summary.json" +show "Classification F1" "$OUTPUT_BASE/classification_f1/summary.json" +show "Reasoning Accuracy" "$OUTPUT_BASE/reasoning_accuracy/summary.json" +show "Reasoning RAS" "$OUTPUT_BASE/reasoning_ras/summary.json" +echo "\n[INFO] Logs under $LOG_DIR" + +echo "\n[INFO] Writing per-order summaries and overall.json under $OUTPUT_BASE ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/summarize_by_order.py" \ + --output-base "$OUTPUT_BASE" \ + --orders "$ORDERS" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --k-shots "$K_SHOTS" \ + --split "$SPLIT" \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] Updating aggregated top/bottom (RAS x100 with auto-scale) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/collect_all_scores.py" \ + --output-dir runs/summaries \ + --topk 5 \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] overall: $OUTPUT_BASE/overall.json" diff --git a/ICL/LV/code/run_all_multi_gpu_0_7_gemma3_12b.sh b/ICL/LV/code/run_all_multi_gpu_0_7_gemma3_12b.sh new file mode 100644 index 0000000000000000000000000000000000000000..6c4eecdd624c187310a3575a25d42f44328848f6 --- /dev/null +++ b/ICL/LV/code/run_all_multi_gpu_0_7_gemma3_12b.sh @@ -0,0 +1,140 @@ +#!/usr/bin/env bash +# Run 8 modal-order evaluation tasks on GPUs 0–7 in parallel for a single model (gemma3). +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +PYTHON_BIN="${PYTHON_BIN:-python}" +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +# Prefer repo-local roscoe path; shim can also auto-detect, so this is optional +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi + +# Allow toggling strict mode via env (0/1); default off to allow graceful fallback +RAS_STRICT="${RAS_STRICT:-1}" +STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +ADAPTER="gemma3" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/gemma-3-12b-it}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/xlm-roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-}" +ORDERS="${ORDERS:-image-text,text-image,text-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-2000}" +K_SHOTS="${K_SHOTS:-3}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/order_gemma3_12b}" +TOT_CAP="${TOT_CAP:-2000}" +TOT_VQA="${TOT_VQA:-2000}" +TOT_CLS="${TOT_CLS:-2000}" +TOT_RS="${TOT_RS:-2000}" +REUSE="--reuse-cache" + +# Allow CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --k-shots) K_SHOTS="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" logs + +build_cmd_for_gpu(){ + local gpu="$1"; + case "$gpu" in + 0) echo "$PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" ;; + 1) echo "$PYTHON_BIN -m core.eval.eval_order_vqa_bertscore --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --strict-bertscore --output-dir '$OUTPUT_BASE' $REUSE" ;; + 2) echo "$PYTHON_BIN -m core.eval.eval_order_caption_bertscore --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --strict-bertscore --output-dir '$OUTPUT_BASE' $REUSE" ;; + 3) echo "$PYTHON_BIN -m core.eval.eval_order_caption_cider --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" ;; + 4) echo "$PYTHON_BIN -m core.eval.eval_order_classification_accuracy --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" ;; + 5) echo "$PYTHON_BIN -m core.eval.eval_order_classification_f1 --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" ;; + 6) echo "$PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE' $REUSE" ;; + 7) echo "$PYTHON_BIN -m core.eval.eval_order_reasoning_ras --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --ras-backend $RAS_BACKEND $STRICT_FLAG $ROSCOE_PATH_ARG --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --no-bertscore-fallback --output-dir '$OUTPUT_BASE' $REUSE" ;; + *) echo "Invalid GPU index $gpu" >&2; return 1;; + esac +} + +pids=() +for gpu in 0 1 2 3 4 5 6 7; do + cmd=$(build_cmd_for_gpu "$gpu") + echo "[GPU${gpu}] $ADAPTER -> $MODEL_PATH"; echo "$cmd" + CUDA_VISIBLE_DEVICES=$gpu nohup bash -lc "$cmd" > "logs/${ADAPTER}_g${gpu}.log" 2>&1 & + pids+=("$!") + sleep 1 +done + +fail=0; i=0 +for pid in "${pids[@]}"; do + if ! wait "$pid"; then echo "[JOB $i] failed (pid=$pid)"; fail=$((fail+1)); else echo "[JOB $i] done (pid=$pid)"; fi + i=$((i+1)) +done + +if [[ $fail -gt 0 ]]; then echo "Some jobs failed: $fail" >&2; exit 1; fi + +echo "All 8 tasks finished. Logs in ./logs, outputs under $OUTPUT_BASE." + +# Write an overall.json for this run and update global top/bottom summary +echo "\n[INFO] Writing per-order summaries and overall.json under $OUTPUT_BASE ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/summarize_by_order.py" \ + --output-base "$OUTPUT_BASE" \ + --orders "$ORDERS" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --k-shots "$K_SHOTS" \ + --split "$SPLIT" \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] Updating aggregated top/bottom (RAS x100 with auto-scale) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/collect_all_scores.py" \ + --output-dir runs/summaries \ + --topk 5 \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] overall: $OUTPUT_BASE/overall.json" +echo "[INFO] aggregated: runs/summaries/top_bottom.json (and runs/summaries/all_scores.json)" + +# Dump per-sample scores for this run (including RAS x100 with auto-scale) +echo "[INFO] Dumping per-sample scores under $OUTPUT_BASE/per_sample ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/dump_per_sample_scores.py" \ + --output-base "$OUTPUT_BASE" \ + --bertscore-model "$BERTSCORE_MODEL" \ + --bertscore-lang "$BERTSCORE_LANG" \ + --bertscore-num-layers 24 \ + --bertscore-batch-size 16 \ + --strict-bertscore \ + --roscoe-model-path "$ROSCOE_MODEL_PATH" \ + --ras-scale-100 + +# Select extremes by TOTAL (sum over 8 metrics; BERTScore zeros ignored) +echo "[INFO] Selecting TOTAL extremes (by per-sample aggregated total) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/select_total_extremes.py" \ + --output-base "$OUTPUT_BASE" \ + --k 5 \ + --ignore-bert-zero \ + --skip-bert-fallback || true +echo "[INFO] total-extremes: $OUTPUT_BASE/per_sample/extremes_total.json" diff --git a/ICL/LV/code/run_all_multi_gpu_0_7_gemma3_4b.sh b/ICL/LV/code/run_all_multi_gpu_0_7_gemma3_4b.sh new file mode 100644 index 0000000000000000000000000000000000000000..c76e64eced71af7ff08e3ec7b9bec3c7c6e8e05d --- /dev/null +++ b/ICL/LV/code/run_all_multi_gpu_0_7_gemma3_4b.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash +# Run 8 modal-order evaluation tasks on GPUs 0–7 in parallel for a single model (gemma3). +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +PYTHON_BIN="${PYTHON_BIN:-python}" +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +# Prefer repo-local roscoe path; shim can also auto-detect, so this is optional +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi + +# Allow toggling strict mode via env (0/1); default off to allow graceful fallback +RAS_STRICT="${RAS_STRICT:-1}" +STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +ADAPTER="gemma3" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/gemma-3-4b-it}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/xlm-roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-}" +ORDERS="${ORDERS:-image-text,text-image,text-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-2000}" +K_SHOTS="${K_SHOTS:-3}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/order_gemma3_4b}" +TOT_CAP="${TOT_CAP:-2000}" +TOT_VQA="${TOT_VQA:-2000}" +TOT_CLS="${TOT_CLS:-2000}" +TOT_RS="${TOT_RS:-2000}" +REUSE="--reuse-cache" +stamp() { date +"%F %T"; } +start_msg() { local gpu="$1"; local name="$2"; local total="$3"; echo "[$(stamp)] [$gpu] Start $name | orders=$ORDERS total=$total k=$K_SHOTS split=$SPLIT"; } +end_msg() { local gpu="$1"; local name="$2"; echo "[$(stamp)] [$gpu] Done $name"; } + +# Allow CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --k-shots) K_SHOTS="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" logs + +build_cmd_for_gpu(){ + local gpu="$1"; + case "$gpu" in + 0) echo "$PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE'" ;; + 1) echo "$PYTHON_BIN -m core.eval.eval_order_vqa_bertscore --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --strict-bertscore --output-dir '$OUTPUT_BASE'" ;; + 2) echo "$PYTHON_BIN -m core.eval.eval_order_caption_bertscore --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --strict-bertscore --output-dir '$OUTPUT_BASE'" ;; + 3) echo "$PYTHON_BIN -m core.eval.eval_order_caption_cider --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE'" ;; + 4) echo "$PYTHON_BIN -m core.eval.eval_order_classification_accuracy --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE'" ;; + 5) echo "$PYTHON_BIN -m core.eval.eval_order_classification_f1 --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE'" ;; + 6) echo "$PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir '$OUTPUT_BASE'" ;; + 7) echo "$PYTHON_BIN -m core.eval.eval_order_reasoning_ras --adapter $ADAPTER --model-path '$MODEL_PATH' --dataset-root '$DATASET_ROOT' --retriever-model-path '$RETRIEVER_MODEL' --orders '$ORDERS' --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --ras-backend $RAS_BACKEND --ras-strict --roscoe-path \"$ROSCOE_PY_PATH\" --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate --bertscore-model '$BERTSCORE_MODEL' --no-bertscore-baseline --bertscore-lang '$BERTSCORE_LANG' --output-dir '$OUTPUT_BASE'" ;; + *) echo "Invalid GPU index $gpu" >&2; return 1;; + esac +} + +mkdir -p runs/logs +( + export CUDA_VISIBLE_DEVICES=0 + start_msg GPU0 "VQA Token-F1" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/vqa_tokenf1_g0.log 2>&1 + end_msg GPU0 "VQA Token-F1" +) & PID0=$! +( + export CUDA_VISIBLE_DEVICES=1 + start_msg GPU1 "VQA BERTScore" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_bertscore --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model "$BERTSCORE_MODEL" --bertscore-lang "$BERTSCORE_LANG" --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/vqa_bertscore_g1.log 2>&1 + end_msg GPU1 "VQA BERTScore" +) & PID1=$! +( + export CUDA_VISIBLE_DEVICES=2 + start_msg GPU2 "Caption BERTScore" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_bertscore --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model "$BERTSCORE_MODEL" --bertscore-lang "$BERTSCORE_LANG" --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/caption_bertscore_g2.log 2>&1 + end_msg GPU2 "Caption BERTScore" +) & PID2=$! +( + export CUDA_VISIBLE_DEVICES=3 + start_msg GPU3 "Caption CIDEr" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_cider --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/caption_cider_g3.log 2>&1 + end_msg GPU3 "Caption CIDEr" +) & PID3=$! +( + export CUDA_VISIBLE_DEVICES=4 + start_msg GPU4 "Classification Accuracy" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_accuracy --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/cls_accuracy_g4.log 2>&1 + end_msg GPU4 "Classification Accuracy" +) & PID4=$! +( + export CUDA_VISIBLE_DEVICES=5 + start_msg GPU5 "Classification F1" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_f1 --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/cls_f1_g5.log 2>&1 + end_msg GPU5 "Classification F1" +) & PID5=$! +( + export CUDA_VISIBLE_DEVICES=6 + start_msg GPU6 "Reasoning Accuracy" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/reasoning_accuracy_g6.log 2>&1 + end_msg GPU6 "Reasoning Accuracy" +) & PID6=$! +( + export CUDA_VISIBLE_DEVICES=7 + start_msg GPU7 "Reasoning RAS" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_ras \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --ras-backend $RAS_BACKEND $STRICT_FLAG \ + $ROSCOE_PATH_ARG --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate \ + --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --no-bertscore-fallback \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/reasoning_ras_g7.log 2>&1 + end_msg GPU7 "Reasoning RAS" +) & PID7=$! + +wait "$PID0" "$PID1" "$PID2" "$PID3" "$PID4" "$PID5" "$PID6" "$PID7" + +echo "\n[INFO] All jobs finished. Summaries:" +show() { local name="$1"; local path="$2"; echo "\n== $name =="; if [ -f "$path" ]; then cat "$path"; else echo "(summary not found: $path)"; fi; } +show "VQA Token-F1" "$OUTPUT_BASE/vqa_tokenf1/summary.json" +show "VQA BERTScore" "$OUTPUT_BASE/vqa_bertscore/summary.json" +show "Caption BERTScore" "$OUTPUT_BASE/captioning_bertscore/summary.json" +show "Caption CIDEr" "$OUTPUT_BASE/captioning_cider/summary.json" +show "Classification Acc" "$OUTPUT_BASE/classification_accuracy/summary.json" +show "Classification F1" "$OUTPUT_BASE/classification_f1/summary.json" +show "Reasoning Accuracy" "$OUTPUT_BASE/reasoning_accuracy/summary.json" +show "Reasoning RAS" "$OUTPUT_BASE/reasoning_ras/summary.json" +echo "\n[INFO] Logs under runs/logs/*.log" + +# Write an overall.json for this run and update global top/bottom summary +echo "\n[INFO] Writing per-order summaries and overall.json under $OUTPUT_BASE ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/summarize_by_order.py" \ + --output-base "$OUTPUT_BASE" \ + --orders "$ORDERS" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --k-shots "$K_SHOTS" \ + --split "$SPLIT" \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] Updating aggregated top/bottom (RAS x100 with auto-scale) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/collect_all_scores.py" \ + --output-dir runs/summaries \ + --topk 5 \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] overall: $OUTPUT_BASE/overall.json" +echo "[INFO] aggregated: runs/summaries/top_bottom.json (and runs/summaries/all_scores.json)" + +# Dump per-sample scores for this run (including RAS x100 with auto-scale) +echo "[INFO] Dumping per-sample scores under $OUTPUT_BASE/per_sample ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/dump_per_sample_scores.py" \ + --output-base "$OUTPUT_BASE" \ + --bertscore-model "$BERTSCORE_MODEL" \ + --bertscore-lang "$BERTSCORE_LANG" \ + --bertscore-num-layers 24 \ + --bertscore-batch-size 16 \ + --strict-bertscore \ + --roscoe-model-path "$ROSCOE_MODEL_PATH" \ + --ras-scale-100 + +# Select extremes by TOTAL (sum over 8 metrics; BERTScore zeros ignored) +echo "[INFO] Selecting TOTAL extremes (by per-sample aggregated total) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/select_total_extremes.py" \ + --output-base "$OUTPUT_BASE" \ + --k 5 \ + --ignore-bert-zero \ + --skip-bert-fallback || true +echo "[INFO] total-extremes: $OUTPUT_BASE/per_sample/extremes_total.json" diff --git a/ICL/LV/code/run_all_multi_gpu_0_7_idefics2.sh b/ICL/LV/code/run_all_multi_gpu_0_7_idefics2.sh new file mode 100644 index 0000000000000000000000000000000000000000..1f864d5b3d808c31fccea9d80ccae423b6c609f1 --- /dev/null +++ b/ICL/LV/code/run_all_multi_gpu_0_7_idefics2.sh @@ -0,0 +1,243 @@ +#!/usr/bin/env bash +# Run 8 modal-order evaluation tasks on GPUs 0–7 in parallel for a single model (idefics2). +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +PYTHON_BIN="${PYTHON_BIN:-python}" +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +# Prefer repo-local roscoe path; shim can also auto-detect, so this is optional +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi + +# Allow toggling strict mode via env (0/1); default off to allow graceful fallback +RAS_STRICT="${RAS_STRICT:-1}" +STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +ADAPTER="idefics2" +# Use BF16 by default to reduce VRAM; override with IDEFICS2_DTYPE=fp16/fp32 if needed. +export IDEFICS2_DTYPE="${IDEFICS2_DTYPE:-bf16}" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/idefics2-8b/}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/xlm-roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-}" +ORDERS="${ORDERS:-image-text,text-image,text-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-2000}" +K_SHOTS="${K_SHOTS:-3}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/order_idefics2}" +TOT_CAP="${TOT_CAP:-2000}" +TOT_VQA="${TOT_VQA:-2000}" +TOT_CLS="${TOT_CLS:-2000}" +TOT_RS="${TOT_RS:-2000}" +REUSE="--reuse-cache" + +# helper messages (mimic idefics2 style) +stamp() { date +"%F %T"; } +start_msg() { local gpu="$1"; local name="$2"; local total="$3"; echo "[$(stamp)] [$gpu] Start $name | orders=$ORDERS total=$total k=$K_SHOTS split=$SPLIT"; } +end_msg() { local gpu="$1"; local name="$2"; echo "[$(stamp)] [$gpu] Done $name"; } + +# Allow CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --k-shots) K_SHOTS="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" logs + +mkdir -p runs/logs + +( + export CUDA_VISIBLE_DEVICES=0 + start_msg GPU0 "VQA Token-F1" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/vqa_tokenf1_g0.log 2>&1 + end_msg GPU0 "VQA Token-F1" +) & PID0=$! + +( + export CUDA_VISIBLE_DEVICES=1 + start_msg GPU1 "VQA BERTScore" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_bertscore \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --strict-bertscore \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/vqa_bertscore_g1.log 2>&1 + end_msg GPU1 "VQA BERTScore" +) & PID1=$! + +( + export CUDA_VISIBLE_DEVICES=2 + start_msg GPU2 "Caption BERTScore" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_bertscore \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --strict-bertscore \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/caption_bertscore_g2.log 2>&1 + end_msg GPU2 "Caption BERTScore" +) & PID2=$! + +( + export CUDA_VISIBLE_DEVICES=3 + start_msg GPU3 "Caption CIDEr" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_cider \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/caption_cider_g3.log 2>&1 + end_msg GPU3 "Caption CIDEr" +) & PID3=$! + +( + export CUDA_VISIBLE_DEVICES=4 + start_msg GPU4 "Classification Accuracy" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_accuracy \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/cls_accuracy_g4.log 2>&1 + end_msg GPU4 "Classification Accuracy" +) & PID4=$! + +( + export CUDA_VISIBLE_DEVICES=5 + start_msg GPU5 "Classification F1" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_f1 \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/cls_f1_g5.log 2>&1 + end_msg GPU5 "Classification F1" +) & PID5=$! + +( + export CUDA_VISIBLE_DEVICES=6 + start_msg GPU6 "Reasoning Accuracy" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/reasoning_accuracy_g6.log 2>&1 + end_msg GPU6 "Reasoning Accuracy" +) & PID6=$! + +( + export CUDA_VISIBLE_DEVICES=7 + start_msg GPU7 "Reasoning RAS" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_ras \ + --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" \ + --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" \ + --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT \ + --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW \ + --ras-backend $RAS_BACKEND $STRICT_FLAG $ROSCOE_PATH_ARG --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --no-bertscore-fallback \ + --output-dir "$OUTPUT_BASE" $REUSE \ + > runs/logs/reasoning_ras_g7.log 2>&1 + end_msg GPU7 "Reasoning RAS" +) & PID7=$! + +wait "$PID0" "$PID1" "$PID2" "$PID3" "$PID4" "$PID5" "$PID6" "$PID7" + +echo "\n[INFO] All jobs finished. Summaries:" + +show() { + local name="$1"; local path="$2"; + echo "\n== $name =="; + if [ -f "$path" ]; then cat "$path"; else echo "(summary not found: $path)"; fi +} + +show "VQA Token-F1" "$OUTPUT_BASE/vqa_tokenf1/summary.json" +show "VQA BERTScore" "$OUTPUT_BASE/vqa_bertscore/summary.json" +show "Caption BERTScore" "$OUTPUT_BASE/captioning_bertscore/summary.json" +show "Caption CIDEr" "$OUTPUT_BASE/captioning_cider/summary.json" +show "Classification Acc" "$OUTPUT_BASE/classification_accuracy/summary.json" +show "Classification F1" "$OUTPUT_BASE/classification_f1/summary.json" +show "Reasoning Accuracy" "$OUTPUT_BASE/reasoning_accuracy/summary.json" +show "Reasoning RAS" "$OUTPUT_BASE/reasoning_ras/summary.json" + +echo "\n[INFO] Logs under runs/logs/*.log" + +# Write an overall.json for this run and update global top/bottom summary +echo "\n[INFO] Writing per-order summaries and overall.json under $OUTPUT_BASE ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/summarize_by_order.py" \ + --output-base "$OUTPUT_BASE" \ + --orders "$ORDERS" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --k-shots "$K_SHOTS" \ + --split "$SPLIT" \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] Updating aggregated top/bottom (RAS x100 with auto-scale) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/collect_all_scores.py" \ + --output-dir runs/summaries \ + --topk 5 \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] overall: $OUTPUT_BASE/overall.json" +echo "[INFO] aggregated: runs/summaries/top_bottom.json (and runs/summaries/all_scores.json)" + +# Dump per-sample scores for this run (including RAS x100 with auto-scale) +echo "[INFO] Dumping per-sample scores under $OUTPUT_BASE/per_sample ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/dump_per_sample_scores.py" \ + --output-base "$OUTPUT_BASE" \ + --bertscore-model "$BERTSCORE_MODEL" \ + --bertscore-lang "$BERTSCORE_LANG" \ + --bertscore-num-layers 24 \ + --bertscore-batch-size 16 \ + --strict-bertscore \ + --roscoe-model-path "$ROSCOE_MODEL_PATH" \ + --ras-scale-100 + +# Select extremes by TOTAL (sum over 8 metrics; BERTScore zeros ignored) +echo "[INFO] Selecting TOTAL extremes (by per-sample aggregated total) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/select_total_extremes.py" \ + --output-base "$OUTPUT_BASE" \ + --k 5 \ + --ignore-bert-zero \ + --skip-bert-fallback || true +echo "[INFO] total-extremes: $OUTPUT_BASE/per_sample/extremes_total.json" diff --git a/ICL/LV/code/run_all_multi_gpu_0_7_qwen3vl.sh b/ICL/LV/code/run_all_multi_gpu_0_7_qwen3vl.sh new file mode 100644 index 0000000000000000000000000000000000000000..4c47bcf608cf140d0791e91f66ec72ec9724f02d --- /dev/null +++ b/ICL/LV/code/run_all_multi_gpu_0_7_qwen3vl.sh @@ -0,0 +1,173 @@ +#!/usr/bin/env bash +# Run 8 modal-order evaluation tasks on GPUs 0–7 in parallel for a single model (qwen3-vl). +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +PYTHON_BIN="${PYTHON_BIN:-python}" +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +# Prefer repo-local roscoe path; shim can also auto-detect, so this is optional +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi + +# Allow toggling strict mode via env (0/1); default off to allow graceful fallback +RAS_STRICT="${RAS_STRICT:-1}" +STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +ADAPTER="qwen3-vl" +MODEL_PATH="${MODEL_PATH:-/workspace/Qwen3-VL-8B-Instruct}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/xlm-roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-}" +ORDERS="${ORDERS:-image-text,text-image,text-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-2000}" +K_SHOTS="${K_SHOTS:-3}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/order_qwen3-vl}" +TOT_CAP="${TOT_CAP:-2000}" +TOT_VQA="${TOT_VQA:-2000}" +TOT_CLS="${TOT_CLS:-2000}" +TOT_RS="${TOT_RS:-2000}" +REUSE="--reuse-cache" +stamp() { date +"%F %T"; } +start_msg() { local gpu="$1"; local name="$2"; local total="$3"; echo "[$(stamp)] [$gpu] Start $name | orders=$ORDERS total=$total k=$K_SHOTS split=$SPLIT"; } +end_msg() { local gpu="$1"; local name="$2"; echo "[$(stamp)] [$gpu] Done $name"; } + +# Allow CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --k-shots) K_SHOTS="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" logs + +mkdir -p runs/logs +( + export CUDA_VISIBLE_DEVICES=0 + start_msg GPU0 "VQA Token-F1" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/vqa_tokenf1_g0.log 2>&1 + end_msg GPU0 "VQA Token-F1" +) & PID0=$! +( + export CUDA_VISIBLE_DEVICES=1 + start_msg GPU1 "VQA BERTScore" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_bertscore --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --strict-bertscore --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/vqa_bertscore_g1.log 2>&1 + end_msg GPU1 "VQA BERTScore" +) & PID1=$! +( + export CUDA_VISIBLE_DEVICES=2 + start_msg GPU2 "Caption BERTScore" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_bertscore --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --strict-bertscore --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/caption_bertscore_g2.log 2>&1 + end_msg GPU2 "Caption BERTScore" +) & PID2=$! +( + export CUDA_VISIBLE_DEVICES=3 + start_msg GPU3 "Caption CIDEr" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_cider --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/caption_cider_g3.log 2>&1 + end_msg GPU3 "Caption CIDEr" +) & PID3=$! +( + export CUDA_VISIBLE_DEVICES=4 + start_msg GPU4 "Classification Accuracy" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_accuracy --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/cls_accuracy_g4.log 2>&1 + end_msg GPU4 "Classification Accuracy" +) & PID4=$! +( + export CUDA_VISIBLE_DEVICES=5 + start_msg GPU5 "Classification F1" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_f1 --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/cls_f1_g5.log 2>&1 + end_msg GPU5 "Classification F1" +) & PID5=$! +( + export CUDA_VISIBLE_DEVICES=6 + start_msg GPU6 "Reasoning Accuracy" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/reasoning_accuracy_g6.log 2>&1 + end_msg GPU6 "Reasoning Accuracy" +) & PID6=$! +( + export CUDA_VISIBLE_DEVICES=7 + start_msg GPU7 "Reasoning RAS" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_ras --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --ras-backend $RAS_BACKEND $STRICT_FLAG $ROSCOE_PATH_ARG --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --no-bertscore-fallback --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/reasoning_ras_g7.log 2>&1 + end_msg GPU7 "Reasoning RAS" +) & PID7=$! + +wait "$PID0" "$PID1" "$PID2" "$PID3" "$PID4" "$PID5" "$PID6" "$PID7" + +echo "\n[INFO] All jobs finished. Summaries:" +show() { local name="$1"; local path="$2"; echo "\n== $name =="; if [ -f "$path" ]; then cat "$path"; else echo "(summary not found: $path)"; fi; } +show "VQA Token-F1" "$OUTPUT_BASE/vqa_tokenf1/summary.json" +show "VQA BERTScore" "$OUTPUT_BASE/vqa_bertscore/summary.json" +show "Caption BERTScore" "$OUTPUT_BASE/captioning_bertscore/summary.json" +show "Caption CIDEr" "$OUTPUT_BASE/captioning_cider/summary.json" +show "Classification Acc" "$OUTPUT_BASE/classification_accuracy/summary.json" +show "Classification F1" "$OUTPUT_BASE/classification_f1/summary.json" +show "Reasoning Accuracy" "$OUTPUT_BASE/reasoning_accuracy/summary.json" +show "Reasoning RAS" "$OUTPUT_BASE/reasoning_ras/summary.json" +echo "\n[INFO] Logs under runs/logs/*.log" + +# Write an overall.json for this run and update global top/bottom summary +echo "\n[INFO] Writing per-order summaries and overall.json under $OUTPUT_BASE ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/summarize_by_order.py" \ + --output-base "$OUTPUT_BASE" \ + --orders "$ORDERS" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --k-shots "$K_SHOTS" \ + --split "$SPLIT" \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] Updating aggregated top/bottom (RAS x100 with auto-scale) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/collect_all_scores.py" \ + --output-dir runs/summaries \ + --topk 5 \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] overall: $OUTPUT_BASE/overall.json" +echo "[INFO] aggregated: runs/summaries/top_bottom.json (and runs/summaries/all_scores.json)" + +# Dump per-sample scores for this run (including RAS x100 with auto-scale) +echo "[INFO] Dumping per-sample scores under $OUTPUT_BASE/per_sample ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/dump_per_sample_scores.py" \ + --output-base "$OUTPUT_BASE" \ + --bertscore-model "$BERTSCORE_MODEL" \ + --bertscore-lang "$BERTSCORE_LANG" \ + --bertscore-num-layers 24 \ + --bertscore-batch-size 16 \ + --strict-bertscore \ + --roscoe-model-path "$ROSCOE_MODEL_PATH" \ + --ras-scale-100 + +# Select extremes by TOTAL (sum over 8 metrics; BERTScore zeros ignored) +echo "[INFO] Selecting TOTAL extremes (by per-sample aggregated total) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/select_total_extremes.py" \ + --output-base "$OUTPUT_BASE" \ + --k 5 \ + --ignore-bert-zero \ + --skip-bert-fallback || true +echo "[INFO] total-extremes: $OUTPUT_BASE/per_sample/extremes_total.json" diff --git a/ICL/LV/code/run_all_multi_gpu_0_7_qwenvl.sh b/ICL/LV/code/run_all_multi_gpu_0_7_qwenvl.sh new file mode 100644 index 0000000000000000000000000000000000000000..11599d902ee3b382eeb44005e443ab4c36923875 --- /dev/null +++ b/ICL/LV/code/run_all_multi_gpu_0_7_qwenvl.sh @@ -0,0 +1,174 @@ +#!/usr/bin/env bash +# Run 8 modal-order evaluation tasks on GPUs 0–7 in parallel for a single model (qwen-vl). +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +PYTHON_BIN="${PYTHON_BIN:-python}" +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +# Prefer repo-local roscoe path; shim can also auto-detect, so this is optional +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi + +# Allow toggling strict mode via env (0/1); default off to allow graceful fallback +RAS_STRICT="${RAS_STRICT:-1}" +STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +ADAPTER="qwen-vl" +MODEL_PATH="${MODEL_PATH:-/z_data/pretrained/syxin/Qwen-VL}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/xlm-roberta-large}" +# Mixed CN/EN answers; we keep baseline disabled, so lang is unused. Leave empty to avoid confusion. +BERTSCORE_LANG="${BERTSCORE_LANG:-}" +ORDERS="${ORDERS:-image-text,text-image,text-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-2000}" +K_SHOTS="${K_SHOTS:-3}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/order_qwen-vl}" +TOT_CAP="${TOT_CAP:-2000}" +TOT_VQA="${TOT_VQA:-2000}" +TOT_CLS="${TOT_CLS:-2000}" +TOT_RS="${TOT_RS:-2000}" +REUSE="--reuse-cache" +stamp() { date +"%F %T"; } +start_msg() { local gpu="$1"; local name="$2"; local total="$3"; echo "[$(stamp)] [$gpu] Start $name | orders=$ORDERS total=$total k=$K_SHOTS split=$SPLIT"; } +end_msg() { local gpu="$1"; local name="$2"; echo "[$(stamp)] [$gpu] Done $name"; } + +# Allow CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --k-shots) K_SHOTS="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" logs + +mkdir -p runs/logs +( + export CUDA_VISIBLE_DEVICES=0 + start_msg GPU0 "VQA Token-F1" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_tokenf1 --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/vqa_tokenf1_g0.log 2>&1 + end_msg GPU0 "VQA Token-F1" +) & PID0=$! +( + export CUDA_VISIBLE_DEVICES=1 + start_msg GPU1 "VQA BERTScore" $TOT_VQA + $PYTHON_BIN -m core.eval.eval_order_vqa_bertscore --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_VQA --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --strict-bertscore --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/vqa_bertscore_g1.log 2>&1 + end_msg GPU1 "VQA BERTScore" +) & PID1=$! +( + export CUDA_VISIBLE_DEVICES=2 + start_msg GPU2 "Caption BERTScore" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_bertscore --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --strict-bertscore --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/caption_bertscore_g2.log 2>&1 + end_msg GPU2 "Caption BERTScore" +) & PID2=$! +( + export CUDA_VISIBLE_DEVICES=3 + start_msg GPU3 "Caption CIDEr" $TOT_CAP + $PYTHON_BIN -m core.eval.eval_order_caption_cider --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CAP --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/caption_cider_g3.log 2>&1 + end_msg GPU3 "Caption CIDEr" +) & PID3=$! +( + export CUDA_VISIBLE_DEVICES=4 + start_msg GPU4 "Classification Accuracy" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_accuracy --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/cls_accuracy_g4.log 2>&1 + end_msg GPU4 "Classification Accuracy" +) & PID4=$! +( + export CUDA_VISIBLE_DEVICES=5 + start_msg GPU5 "Classification F1" $TOT_CLS + $PYTHON_BIN -m core.eval.eval_order_classification_f1 --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_CLS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/cls_f1_g5.log 2>&1 + end_msg GPU5 "Classification F1" +) & PID5=$! +( + export CUDA_VISIBLE_DEVICES=6 + start_msg GPU6 "Reasoning Accuracy" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_accuracy --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/reasoning_accuracy_g6.log 2>&1 + end_msg GPU6 "Reasoning Accuracy" +) & PID6=$! +( + export CUDA_VISIBLE_DEVICES=7 + start_msg GPU7 "Reasoning RAS" $TOT_RS + $PYTHON_BIN -m core.eval.eval_order_reasoning_ras --adapter "$ADAPTER" --model-path "$MODEL_PATH" --dataset-root "$DATASET_ROOT" --retriever-model-path "$RETRIEVER_MODEL" --orders "$ORDERS" --k-shots $K_SHOTS --total-samples $TOT_RS --split $SPLIT --temperature $TEMP --top-p $TOPP --max-new-tokens $MAX_NEW --ras-backend $RAS_BACKEND $STRICT_FLAG $ROSCOE_PATH_ARG --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate --bertscore-model "$BERTSCORE_MODEL" --no-bertscore-baseline --bertscore-lang "$BERTSCORE_LANG" --no-bertscore-fallback --output-dir "$OUTPUT_BASE" $REUSE > runs/logs/reasoning_ras_g7.log 2>&1 + end_msg GPU7 "Reasoning RAS" +) & PID7=$! + +wait "$PID0" "$PID1" "$PID2" "$PID3" "$PID4" "$PID5" "$PID6" "$PID7" + +echo "\n[INFO] All jobs finished. Summaries:" +show() { local name="$1"; local path="$2"; echo "\n== $name =="; if [ -f "$path" ]; then cat "$path"; else echo "(summary not found: $path)"; fi; } +show "VQA Token-F1" "$OUTPUT_BASE/vqa_tokenf1/summary.json" +show "VQA BERTScore" "$OUTPUT_BASE/vqa_bertscore/summary.json" +show "Caption BERTScore" "$OUTPUT_BASE/captioning_bertscore/summary.json" +show "Caption CIDEr" "$OUTPUT_BASE/captioning_cider/summary.json" +show "Classification Acc" "$OUTPUT_BASE/classification_accuracy/summary.json" +show "Classification F1" "$OUTPUT_BASE/classification_f1/summary.json" +show "Reasoning Accuracy" "$OUTPUT_BASE/reasoning_accuracy/summary.json" +show "Reasoning RAS" "$OUTPUT_BASE/reasoning_ras/summary.json" +echo "\n[INFO] Logs under runs/logs/*.log" + +# Write an overall.json for this run and update global top/bottom summary +echo "\n[INFO] Writing per-order summaries and overall.json under $OUTPUT_BASE ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/summarize_by_order.py" \ + --output-base "$OUTPUT_BASE" \ + --orders "$ORDERS" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --k-shots "$K_SHOTS" \ + --split "$SPLIT" \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] Updating aggregated top/bottom (RAS x100 with auto-scale) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/collect_all_scores.py" \ + --output-dir runs/summaries \ + --topk 5 \ + --ras-mul 100 \ + --ras-auto-scale || true + +echo "[INFO] overall: $OUTPUT_BASE/overall.json" +echo "[INFO] aggregated: runs/summaries/top_bottom.json (and runs/summaries/all_scores.json)" + +# Dump per-sample scores for this run (including RAS x100 with auto-scale) +echo "[INFO] Dumping per-sample scores under $OUTPUT_BASE/per_sample ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/dump_per_sample_scores.py" \ + --output-base "$OUTPUT_BASE" \ + --bertscore-model "$BERTSCORE_MODEL" \ + --bertscore-lang "$BERTSCORE_LANG" \ + --bertscore-num-layers 24 \ + --bertscore-batch-size 16 \ + --strict-bertscore \ + --roscoe-model-path "$ROSCOE_MODEL_PATH" \ + --ras-scale-100 + +# Select extremes by TOTAL (sum over 8 metrics; BERTScore zeros ignored) +echo "[INFO] Selecting TOTAL extremes (by per-sample aggregated total) ..." +$PYTHON_BIN "$ROOT_DIR/core/eval/select_total_extremes.py" \ + --output-base "$OUTPUT_BASE" \ + --k 5 \ + --ignore-bert-zero \ + --skip-bert-fallback || true +echo "[INFO] total-extremes: $OUTPUT_BASE/per_sample/extremes_total.json" diff --git a/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/bug-report.yml b/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..fbf352cffccef11f4690e7b41346504f16ce8778 --- /dev/null +++ b/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,67 @@ +name: "\U0001F41B Bug Report" +description: Submit a bug report to help us improve TRL +labels: [ "bug" ] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! 🤗 + + 🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug) + + - type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction + description: | + Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. + If you have code snippets, error messages, stack traces please provide them here as well. + Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting + Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. + + value: | + ```python + from trl import ... + + ``` + + outputs: + + ``` + Traceback (most recent call last): + File "example.py", line 42, in + ... + ``` + + - type: textarea + id: system-info + attributes: + label: System Info + description: | + Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ... + You can get this information by running `trl env` in your terminal. + + placeholder: Copy-paste the output of `trl env` + validations: + required: true + + - type: checkboxes + id: terms + attributes: + label: Checklist + description: | + Before submitting, please confirm that you've completed each of the following. + If an item doesn't apply to your issue, check it anyway to show you've reviewed it. + options: + - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))" + required: true + - label: "I have included my system information" + required: true + - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" + required: true + - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" + required: true + - label: "Any traceback provided is complete" + required: true diff --git a/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/feature-request.yml b/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 0000000000000000000000000000000000000000..0a593186c098ae3824ef994374686092f97ccb4a --- /dev/null +++ b/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,31 @@ +name: "\U0001F680 Feature request" +description: Submit a proposal/request for a new TRL feature +labels: [ "Feature request" ] +body: + - type: textarea + id: feature-request + validations: + required: true + attributes: + label: Feature request + description: | + A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. + + - type: textarea + id: motivation + validations: + required: true + attributes: + label: Motivation + description: | + Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. + + + - type: textarea + id: contribution + validations: + required: true + attributes: + label: Your contribution + description: | + Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) diff --git a/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/new-trainer-addition.yml b/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/new-trainer-addition.yml new file mode 100644 index 0000000000000000000000000000000000000000..ea0b5afb10ae6d7519d07ee510faf617f369048c --- /dev/null +++ b/ICL/RL/trl_source/.github/ISSUE_TEMPLATE/new-trainer-addition.yml @@ -0,0 +1,32 @@ +name: "\U0001F31F New trainer addition" +description: Submit a proposal/request to implement a new trainer for a post-training method +labels: [ "New trainer" ] + +body: + - type: textarea + id: description-request + validations: + required: true + attributes: + label: Method description + description: | + Put any and all important information relative to the method + + - type: checkboxes + id: information-tasks + attributes: + label: Open source status + description: | + Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`. + options: + - label: "The method implementation is available" + - label: "The model weights are available" + - label: "The training datasets are available" + + - type: textarea + id: additional-info + attributes: + label: Provide useful links for the implementation + description: | + Please provide information regarding the implementation, the weights, and the authors. + Please mention the authors by @gh-username if you're aware of their usernames. diff --git a/ICL/RL/trl_source/.github/codeql/custom-queries.qls b/ICL/RL/trl_source/.github/codeql/custom-queries.qls new file mode 100644 index 0000000000000000000000000000000000000000..81deab4a871ed3b8114eeec45a4e2edbf9204b70 --- /dev/null +++ b/ICL/RL/trl_source/.github/codeql/custom-queries.qls @@ -0,0 +1,19 @@ +import codeql + +from WorkflowString interpolation, Workflow workflow +where + interpolation.getStringValue().matches("${{ github.event.issue.title }}") or + interpolation.getStringValue().matches("${{ github.event.issue.body }}") or + interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or + interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or + interpolation.getStringValue().matches("${{ github.event.review.body }}") or + interpolation.getStringValue().matches("${{ github.event.comment.body }}") or + interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or + interpolation.getStringValue().matches("${{ github.event.head_commit.message }}") + interpolation.getStringValue().matches("${{ github.event.* }}") and + ( + step.getKey() = "run" or // Injection in run + step.getKey() = "env" or // Injection via env + step.getKey() = "with" // Injection via with + ) +select workflow, "🚨 Do not use directly as input of action" diff --git a/ICL/RL/trl_source/.github/workflows/build_documentation.yml b/ICL/RL/trl_source/.github/workflows/build_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..b4c6a9a94bc967e4dfb0c7885617dceeeee9bbab --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/build_documentation.yml @@ -0,0 +1,21 @@ +name: Build documentation + +on: + push: + branches: + - main + - doc-builder* + - v*-release + +env: + TRL_EXPERIMENTAL_SILENCE: 1 + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + with: + commit_sha: ${{ github.sha }} + package: trl + version_tag_suffix: "" + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/ICL/RL/trl_source/.github/workflows/build_pr_documentation.yml b/ICL/RL/trl_source/.github/workflows/build_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..65e51056ba855391465fdc8c027513c7014c3e6d --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/build_pr_documentation.yml @@ -0,0 +1,21 @@ +name: Build PR Documentation + +on: + pull_request: + +env: + TRL_EXPERIMENTAL_SILENCE: 1 + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + if: github.event.pull_request.draft == false + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + with: + commit_sha: ${{ github.event.pull_request.head.sha }} + pr_number: ${{ github.event.number }} + package: trl + version_tag_suffix: "" diff --git a/ICL/RL/trl_source/.github/workflows/clear_cache.yml b/ICL/RL/trl_source/.github/workflows/clear_cache.yml new file mode 100644 index 0000000000000000000000000000000000000000..d93e7661806e7a5181e22771b995540fe43fab6b --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/clear_cache.yml @@ -0,0 +1,33 @@ +name: "Cleanup Cache" + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Cleanup + run: | + gh extension install actions/gh-actions-cache + + REPO=${{ github.repository }} + + echo "Fetching list of cache key" + cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 ) + + ## Setting this to not fail the workflow while deleting cache keys. + set +e + echo "Deleting caches..." + for cacheKey in $cacheKeysForPR + do + gh actions-cache delete $cacheKey -R $REPO --confirm + done + echo "Done" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/ICL/RL/trl_source/.github/workflows/codeQL.yml b/ICL/RL/trl_source/.github/workflows/codeQL.yml new file mode 100644 index 0000000000000000000000000000000000000000..f197e0c3e2dbfabd6ffcca05e1854e4f87e640a8 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/codeQL.yml @@ -0,0 +1,26 @@ +name: "CodeQL Analysis - Workflows" + +on: + workflow_dispatch: + +jobs: + analyze: + name: "Analyze GitHub Workflows" + runs-on: ubuntu-latest + permissions: + security-events: write + actions: read + contents: read + + steps: + - name: "Checkout repository" + uses: actions/checkout@v6 + + - name: "Initialize CodeQL" + uses: github/codeql-action/init@v2 + with: + languages: "yaml" + queries: +security-and-quality, ./.github/codeql/custom-queries.qls + + - name: "Perform CodeQL Analysis" + uses: github/codeql-action/analyze@v2 diff --git a/ICL/RL/trl_source/.github/workflows/docker-build.yml b/ICL/RL/trl_source/.github/workflows/docker-build.yml new file mode 100644 index 0000000000000000000000000000000000000000..d172155016da2d0593f3fff50e61407da949b487 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/docker-build.yml @@ -0,0 +1,86 @@ +name: Build TRL Docker image + +on: + push: + branches: + - main + workflow_dispatch: + +concurrency: + group: docker-image-builds + cancel-in-progress: false + +jobs: + trl: + name: "Build and push TRL Docker image" + runs-on: + group: aws-general-8-plus + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Get TRL version from PyPI + run: | + VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version) + echo "VERSION=$VERSION" >> $GITHUB_ENV + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to DockerHub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + + - name: Build and Push + uses: docker/build-push-action@v6 + with: + context: docker/trl + push: true + tags: | + huggingface/trl:${{ env.VERSION }} + huggingface/trl + + - name: Post to Slack + if: always() + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }} + title: 🤗 Results of the TRL Dev Docker Image build + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + trl-dev: + name: "Build and push TRL Dev Docker image" + runs-on: + group: aws-general-8-plus + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to DockerHub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + + - name: Build and Push + uses: docker/build-push-action@v6 + with: + context: docker/trl-dev + push: true + tags: | + huggingface/trl:dev + + - name: Post to Slack + if: always() + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }} + title: 🤗 Results of the TRL Dev Docker Image build + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/ICL/RL/trl_source/.github/workflows/issue_auto_labeller.yml b/ICL/RL/trl_source/.github/workflows/issue_auto_labeller.yml new file mode 100644 index 0000000000000000000000000000000000000000..8f04275ab84308c393e6b0bc665d4568836b0d88 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/issue_auto_labeller.yml @@ -0,0 +1,15 @@ +name: "Hugging Face Issue Labeler" +on: + issues: + types: opened + +jobs: + triage: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - uses: actions/checkout@v6 + - uses: August-murr/auto-labeler@0.0.1 + with: + hf-api-key: ${{ secrets.CI_HF_API_TOKEN }} diff --git a/ICL/RL/trl_source/.github/workflows/pr_style_bot.yml b/ICL/RL/trl_source/.github/workflows/pr_style_bot.yml new file mode 100644 index 0000000000000000000000000000000000000000..b032f25857488a6f88ec53e6393be20312083e26 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/pr_style_bot.yml @@ -0,0 +1,127 @@ +name: PR Style Bot + +on: + workflow_dispatch: + + +permissions: + contents: write + pull-requests: write + +jobs: + run-style-bot: + if: > + contains(github.event.comment.body, '@bot /style') && + github.event.issue.pull_request != null + runs-on: ubuntu-latest + + steps: + - name: Extract PR details + id: pr_info + uses: actions/github-script@v8 + with: + script: | + const prNumber = context.payload.issue.number; + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: prNumber + }); + + // We capture both the branch ref and the "full_name" of the head repo + // so that we can check out the correct repository & branch (including forks). + core.setOutput("prNumber", prNumber); + core.setOutput("headRef", pr.head.ref); + core.setOutput("headRepoFullName", pr.head.repo.full_name); + + - name: Check out PR branch + uses: actions/checkout@v6 + env: + HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }} + HEADREF: ${{ steps.pr_info.outputs.headRef }} + with: + # Instead of checking out the base repo, use the contributor's repo name + repository: ${{ env.HEADREPOFULLNAME }} + ref: ${{ env.HEADREF }} + # You may need fetch-depth: 0 for being able to push + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Debug + env: + HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }} + HEADREF: ${{ steps.pr_info.outputs.headRef }} + PRNUMBER: ${{ steps.pr_info.outputs.prNumber }} + run: | + echo "PR number: ${{ env.PRNUMBER }}" + echo "Head Ref: ${{ env.HEADREF }}" + echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}" + + - name: Set up Python + uses: actions/setup-python@v6 + + - name: Install dependencies + run: | + pip install ruff pre-commit + + - name: Download Makefile from main branch + run: | + curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile + + - name: Compare Makefiles + run: | + if ! diff -q main_Makefile Makefile; then + echo "Error: The Makefile has changed. Please ensure it matches the main branch." + exit 1 + fi + echo "No changes in Makefile. Proceeding..." + rm -rf main_Makefile + + - name: Run make style and make quality + run: | + make precommit || true + + - name: Commit and push changes + id: commit_and_push + env: + HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }} + HEADREF: ${{ steps.pr_info.outputs.headRef }} + PRNUMBER: ${{ steps.pr_info.outputs.prNumber }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}" + # Configure git with the Actions bot user + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + # Make sure your 'origin' remote is set to the contributor's fork + git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git" + + # If there are changes after running style/quality, commit them + if [ -n "$(git status --porcelain)" ]; then + git add . + git commit -m "Apply style fixes" + # Push to the original contributor's forked branch + git push origin HEAD:${{ env.HEADREF }} + echo "changes_pushed=true" >> $GITHUB_OUTPUT + else + echo "No changes to commit." + echo "changes_pushed=false" >> $GITHUB_OUTPUT + fi + + - name: Comment on PR with workflow run link + if: steps.commit_and_push.outputs.changes_pushed == 'true' + uses: actions/github-script@v8 + with: + script: | + const prNumber = parseInt(process.env.prNumber, 10); + const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}` + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: `Style fixes have been applied. [View the workflow run here](${runUrl}).` + }); + env: + prNumber: ${{ steps.pr_info.outputs.prNumber }} diff --git a/ICL/RL/trl_source/.github/workflows/publish.yml b/ICL/RL/trl_source/.github/workflows/publish.yml new file mode 100644 index 0000000000000000000000000000000000000000..4913c8b06ffb6f765c77fffc37a762f47381fbc0 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/publish.yml @@ -0,0 +1,43 @@ +name: Publish to PyPI + +on: + push: + branches: + - main + - v*-release + paths: + - "VERSION" + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - name: Read version + id: get_version + run: echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT + + - name: Debug - Show version.txt content + run: echo "Version is ${{ steps.get_version.outputs.version }}" + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.x" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build twine + + - name: Build package + run: python -m build + + - name: Publish to PyPI + if: ${{ !contains(steps.get_version.outputs.version, 'dev') }} + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: | + python -m twine upload dist/* diff --git a/ICL/RL/trl_source/.github/workflows/slow-tests.yml b/ICL/RL/trl_source/.github/workflows/slow-tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..c7062d787d3a0785bc040dc1cd8f7e893e31d69b --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/slow-tests.yml @@ -0,0 +1,112 @@ +name: Slow tests (on push) + +on: + push: + branches: [main] + paths: + # Run only when python files are modified + - "trl/**.py" + - "examples/**.py" +env: + RUN_SLOW: "yes" + IS_GITHUB_CI: "1" + SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + TRL_EXPERIMENTAL_SILENCE: 1 + +jobs: + run_all_tests_single_gpu: + runs-on: + group: aws-g4dn-2xlarge + env: + CUDA_VISIBLE_DEVICES: "0" + TEST_TYPE: "single_gpu" + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all --shm-size "16gb" + defaults: + run: + shell: bash + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Install system dependencies + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install pytest-reportlog + + - name: Run slow SFT tests on single GPU + if: always() + run: | + source .venv/bin/activate + make slow_tests + + - name: Generate Report + if: always() + run: | + source .venv/bin/activate + uv pip install slack_sdk tabulate + python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY + + run_all_tests_multi_gpu: + runs-on: + group: aws-g4dn-2xlarge + env: + CUDA_VISIBLE_DEVICES: "0,1" + TEST_TYPE: "multi_gpu" + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all --shm-size "16gb" + defaults: + run: + shell: bash + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Install system dependencies + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install pytest-reportlog + + - name: Run slow SFT tests on Multi GPU + if: always() + run: | + source .venv/bin/activate + make slow_tests + + - name: Generate Reports + if: always() + run: | + source .venv/bin/activate + uv pip install slack_sdk tabulate + python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY + rm *.txt \ No newline at end of file diff --git a/ICL/RL/trl_source/.github/workflows/tests.yml b/ICL/RL/trl_source/.github/workflows/tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..e169856997adc6a6f9d40b8287cb6cf761fc4cc2 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/tests.yml @@ -0,0 +1,312 @@ +name: Tests + +on: + push: + branches: + - main + - ci-* + pull_request: + paths: + # Run only when relevant files are modified + - ".github/**.yml" + - "examples/**.py" + - "scripts/**.py" + - "tests/**.py" + - "trl/**.py" + - "pyproject.toml" + # Exclude if only experimental code/tests + - "!trl/experimental/**" + - "!tests/experimental/**" + +env: + TQDM_DISABLE: 1 + CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} + PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" + +jobs: + check_code_quality: + name: Check code quality + runs-on: ubuntu-latest + if: github.event.pull_request.draft == false + steps: + - uses: actions/checkout@v6 + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: 3.12 + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --all-files + + tests: + name: Tests + strategy: + matrix: + python-version: ['3.10', '3.11', '3.12', '3.13'] + fail-fast: false + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python ${{ matrix.python-version }} and latest dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + tests_dev: + name: Tests with dev dependencies + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install -U git+https://github.com/huggingface/accelerate.git + uv pip install -U git+https://github.com/huggingface/datasets.git + uv pip install -U git+https://github.com/huggingface/transformers.git + uv pip install -U git+https://github.com/huggingface/peft.git + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python 3.12 and dev dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + tests_wo_optional_deps: + name: Tests without optional dependencies + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[test]" + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python 3.12 without optional dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + tests_min_versions: + name: Tests with minimum versions + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install accelerate==1.4.0 + uv pip install datasets==3.0.0 + uv pip install transformers==4.56.2 + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python 3.12 and minimum dependencies versions + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + distributed_smoke: + name: Distributed smoke tests + runs-on: + group: aws-g5-12xlarge-cache + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + env: + CUDA_VISIBLE_DEVICES: "0,1" + steps: + - name: Git checkout + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + + - name: Run distributed smoke tests + run: | + source .venv/bin/activate + pytest -v tests/distributed/test_distributed.py + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results of distributed smoke tests + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/ICL/RL/trl_source/.github/workflows/tests_latest.yml b/ICL/RL/trl_source/.github/workflows/tests_latest.yml new file mode 100644 index 0000000000000000000000000000000000000000..e576c330130cae9c5203a2d0a87f0c052fefbb33 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/tests_latest.yml @@ -0,0 +1,67 @@ +name: Tests latest TRL release with dev dependencies + +on: + schedule: + - cron: '0 0 * * *' # Runs daily at midnight UTC + + workflow_dispatch: + +env: + TQDM_DISABLE: 1 + CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} + TRL_EXPERIMENTAL_SILENCE: 1 + +jobs: + tests: + name: Tests latest TRL release with dev dependencies + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + steps: + - name: Git checkout + uses: actions/checkout@v6 + with: { ref: v0.28-release } + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install -U git+https://github.com/huggingface/accelerate.git + uv pip install -U git+https://github.com/huggingface/datasets.git + uv pip install -U git+https://github.com/huggingface/transformers.git + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results of latest TRL with Python 3.12 and dev dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/ICL/RL/trl_source/.github/workflows/trufflehog.yml b/ICL/RL/trl_source/.github/workflows/trufflehog.yml new file mode 100644 index 0000000000000000000000000000000000000000..6f797daf259e8ce039af9c28efeba749a60d7bcb --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/trufflehog.yml @@ -0,0 +1,18 @@ +on: + push: + +name: Secret Leaks + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@v3.93.1 + with: + # exclude buggy postgres detector that is causing false positives and not relevant to our codebase + extra_args: --results=verified,unknown --exclude-detectors=postgres diff --git a/ICL/RL/trl_source/.github/workflows/upload_pr_documentation.yml b/ICL/RL/trl_source/.github/workflows/upload_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..2ad2ba0e8de52699f60c2da7792dab742dd6f200 --- /dev/null +++ b/ICL/RL/trl_source/.github/workflows/upload_pr_documentation.yml @@ -0,0 +1,16 @@ +name: Upload PR Documentation + +on: + workflow_run: + workflows: ["Build PR Documentation"] + types: + - completed + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: trl + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} \ No newline at end of file diff --git a/ICL/RL/trl_source/docker/trl-dev/Dockerfile b/ICL/RL/trl_source/docker/trl-dev/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..9a756a8821d0c225dbdd012c76c9c6a540397df1 --- /dev/null +++ b/ICL/RL/trl_source/docker/trl-dev/Dockerfile @@ -0,0 +1,5 @@ +FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* +RUN pip install --upgrade pip uv +RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]" +RUN uv pip install --system kernels liger_kernel peft trackio \ No newline at end of file diff --git a/ICL/RL/trl_source/docker/trl/Dockerfile b/ICL/RL/trl_source/docker/trl/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..8b6e2842a3859ab8a45e1cf9983a39d6f160988b --- /dev/null +++ b/ICL/RL/trl_source/docker/trl/Dockerfile @@ -0,0 +1,4 @@ +FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel +RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* +RUN pip install --upgrade pip uv +RUN uv pip install --system trl[liger,peft,vlm] kernels trackio \ No newline at end of file diff --git a/ICL/RL/trl_source/docs/source/_toctree.yml b/ICL/RL/trl_source/docs/source/_toctree.yml new file mode 100644 index 0000000000000000000000000000000000000000..d726c45f284c41a71669fccc7fdbaf0579dd295f --- /dev/null +++ b/ICL/RL/trl_source/docs/source/_toctree.yml @@ -0,0 +1,136 @@ +- sections: + - local: index + title: TRL + - local: installation + title: Installation + - local: quickstart + title: Quickstart + title: Getting started +- sections: + - local: dataset_formats + title: Dataset Formats + - local: paper_index + title: Paper Index + title: Conceptual Guides +- sections: # Sorted alphabetically + - local: dpo_trainer + title: DPO + - local: grpo_trainer + title: GRPO + - local: reward_trainer + title: Reward + - local: rloo_trainer + title: RLOO + - local: sft_trainer + title: SFT + title: Trainers +- sections: + - local: clis + title: Command Line Interface (CLI) + - local: jobs_training + title: Training using Jobs + - local: customization + title: Customizing the Training + - local: reducing_memory_usage + title: Reducing Memory Usage + - local: speeding_up_training + title: Speeding Up Training + - local: distributing_training + title: Distributing Training + - local: use_model + title: Using Trained Models + title: How-to guides +- sections: + - local: deepspeed_integration + title: DeepSpeed + - local: kernels_hub + title: Kernels Hub + - local: liger_kernel_integration + title: Liger Kernel + - local: peft_integration + title: PEFT + - local: ptt_integration + title: Post Training Toolkit + - local: rapidfire_integration + title: RapidFire AI + - local: trackio_integration + title: Trackio + - local: unsloth_integration + title: Unsloth + - local: vllm_integration + title: vLLM + title: Integrations +- sections: + - local: example_overview + title: Example Overview + - local: community_tutorials + title: Community Tutorials + - local: lora_without_regret + title: LoRA Without Regret + title: Examples +- sections: + - sections: + - local: chat_template_utils + title: Chat Template Utilities + - local: data_utils + title: Data Utilities + - local: model_utils + title: Model Utilities + - local: script_utils + title: Script Utilities + title: Utilities + - local: callbacks + title: Callbacks + - local: rewards + title: Reward Functions + - local: others + title: Others + title: API +- sections: + - local: experimental_overview + title: Experimental Overview + - local: openenv + title: OpenEnv Integration + - local: bema_for_reference_model # Sorted alphabetically + title: BEMA for Reference Model + - local: bco_trainer + title: BCO + - local: cpo_trainer + title: CPO + - local: gfpo + title: GFPO + - local: gkd_trainer + title: GKD + - local: gold_trainer + title: GOLD + - local: grpo_with_replay_buffer + title: GRPO With Replay Buffer + - local: gspo_token + title: GSPO-token + - local: judges + title: Judges + - local: kto_trainer + title: KTO + - local: merge_model_callback + title: MergeModelCallback + - local: minillm_trainer + title: MiniLLM + - local: nash_md_trainer + title: Nash-MD + - local: nemo_gym + title: NeMo Gym + - local: online_dpo_trainer + title: Online DPO + - local: orpo_trainer + title: ORPO + - local: papo_trainer + title: PAPO + - local: ppo_trainer + title: PPO + - local: prm_trainer + title: PRM + - local: winrate_callback + title: WinRateCallback + - local: xpo_trainer + title: XPO + title: Experimental \ No newline at end of file diff --git a/ICL/RL/trl_source/docs/source/bco_trainer.md b/ICL/RL/trl_source/docs/source/bco_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..5a6b5abb7e44664c629c2878f2f2383ed4c97136 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/bco_trainer.md @@ -0,0 +1,105 @@ +# BCO Trainer + +[![model badge](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl) + +TRL supports the Binary Classifier Optimization (BCO). +The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. +For a full example have a look at [`examples/scripts/bco.py`]. + +## Expected dataset type + +The [`experimental.bco.BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference). +The [`experimental.bco.BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Expected model format + +The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. + +## Using the `BCOTrainer` + +For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response. + +The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder). + +```python +from trl.experimental.bco import BCOConfig, BCOTrainer + +training_args = BCOConfig( + beta=0.1, +) + +bco_trainer = BCOTrainer( + model, + model_ref, + args=training_args, + train_dataset=train_dataset, + processing_class=tokenizer, +) +``` + +After this one can then call: + +```python +bco_trainer.train() +``` + +## Underlying Distribution matching (UDM) + +In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts. +Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts. +If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM. + +Choose an embedding model and tokenizer: + +```python +embedding_model = AutoModel.from_pretrained(your_model_id) +embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id) + +# customize this function depending on your embedding model +def embed_prompt(input_ids, attention_mask, model): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + return outputs.last_hidden_state.mean(dim=1) + +embedding_model = Accelerator().prepare_model(self.embedding_model) +embedding_func = partial(embed_prompt, model=embedding_model) +``` + +Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: + +```python +training_args = BCOConfig( + beta=0.1, + prompt_sample_size=512, +) + +bco_trainer = BCOTrainer( + model, + model_ref, + args=training_args, + train_dataset=train_dataset, + processing_class=tokenizer, + embedding_func=embedding_func, + embedding_tokenizer=self.embedding_tokenizer, +) + +bco_trainer.train() +``` + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001). + +## BCOTrainer + +[[autodoc]] experimental.bco.BCOTrainer + - train + - save_model + - push_to_hub + +## BCOConfig + +[[autodoc]] experimental.bco.BCOConfig diff --git a/ICL/RL/trl_source/docs/source/bema_for_reference_model.md b/ICL/RL/trl_source/docs/source/bema_for_reference_model.md new file mode 100644 index 0000000000000000000000000000000000000000..832acfc932c06efc4f28798fc1fa2558c52311a8 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/bema_for_reference_model.md @@ -0,0 +1,42 @@ +# BEMA for Reference Model + +This feature implements the BEMA algorithm to update the reference model during DPO training. + +## Usage + +```python +from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + + +pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") +ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + +bema_callback = BEMACallback(update_ref_model=True) + +model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") +tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") +tokenizer.pad_token = tokenizer.eos_token + +trainer = DPOTrainer( + model=model, + ref_model=ref_model, + train_dataset=pref_dataset, + processing_class=tokenizer, + callbacks=[bema_callback], +) + +trainer.train() +``` + +## DPOTrainer + +[[autodoc]] experimental.bema_for_ref_model.DPOTrainer + - train + - save_model + - push_to_hub + +## BEMACallback + +[[autodoc]] experimental.bema_for_ref_model.BEMACallback diff --git a/ICL/RL/trl_source/docs/source/callbacks.md b/ICL/RL/trl_source/docs/source/callbacks.md new file mode 100644 index 0000000000000000000000000000000000000000..a4bc54a77cd954e1810d22ca99d0a4433abd4339 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/callbacks.md @@ -0,0 +1,21 @@ +# Callbacks + +## SyncRefModelCallback + +[[autodoc]] SyncRefModelCallback + +## RichProgressCallback + +[[autodoc]] RichProgressCallback + +## LogCompletionsCallback + +[[autodoc]] LogCompletionsCallback + +## BEMACallback + +[[autodoc]] BEMACallback + +## WeaveCallback + +[[autodoc]] WeaveCallback diff --git a/ICL/RL/trl_source/docs/source/chat_template_utils.md b/ICL/RL/trl_source/docs/source/chat_template_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..bc46ed7250d7901b14c66f14098b7ca5f25c32af --- /dev/null +++ b/ICL/RL/trl_source/docs/source/chat_template_utils.md @@ -0,0 +1,21 @@ +# Chat template utilities + +## clone_chat_template + +[[autodoc]] clone_chat_template + +## add_response_schema + +[[autodoc]] chat_template_utils.add_response_schema + +## is_chat_template_prefix_preserving + +[[autodoc]] chat_template_utils.is_chat_template_prefix_preserving + +## get_training_chat_template + +[[autodoc]] chat_template_utils.get_training_chat_template + +## parse_response + +[[autodoc]] chat_template_utils.parse_response diff --git a/ICL/RL/trl_source/docs/source/clis.md b/ICL/RL/trl_source/docs/source/clis.md new file mode 100644 index 0000000000000000000000000000000000000000..54c8c1055f99522af7de6bcc989599487a26e29b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/clis.md @@ -0,0 +1,703 @@ +# Command Line Interfaces (CLIs) + +TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly. + +## Commands + +Currently supported commands are: + +### Training Commands + +- `trl dpo`: fine-tune a LLM with DPO +- `trl grpo`: fine-tune a LLM with GRPO +- `trl kto`: fine-tune a LLM with KTO +- `trl reward`: train a Reward Model +- `trl rloo`: fine-tune a LLM with RLOO +- `trl sft`: fine-tune a LLM with SFT + +### Other Commands + +- `trl env`: get the system information +- `trl vllm-serve`: serve a model with vLLM + +## Fine-Tuning with the TRL CLI + +### Basic Usage + +You can launch training directly from the CLI by specifying required arguments like the model and dataset: + + + + +```bash +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb +``` + + + + +```bash +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf +``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized +``` + + + + +```bash +trl grpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward +``` + + + + +```bash +trl rloo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward +``` + + + + +```bash +trl kto \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/kto-mix-14k +``` + + + + +### Using Configuration Files + +To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file: + + + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + + + + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/kto-mix-14k +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + + + + +### Scaling Up with Accelerate + +TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI. + +You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch). + + + + +```bash +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +num_processes: 4 +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```bash +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +num_processes: 4 +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +num_processes: 4 +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + + + + +```bash +trl grpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +num_processes: 4 +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```bash +trl rloo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +num_processes: 4 +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```bash +trl kto \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/kto-mix-14k \ + --num_processes 4 +``` + +or, with a config file: + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/kto-mix-14k +num_processes: 4 +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + + + + +### Using `--accelerate_config` for Accelerate Configuration + +The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either: + +- the name of a predefined config profile (built into TRL), or +- a path to a custom Accelerate YAML config file. + +#### Predefined Config Profiles + +TRL provides several ready-to-use Accelerate configs to simplify common training setups: + +| Name | Description | +| --- | --- | +| `fsdp1` | Fully Sharded Data Parallel Stage 1 | +| `fsdp2` | Fully Sharded Data Parallel Stage 2 | +| `zero1` | DeepSpeed ZeRO Stage 1 | +| `zero2` | DeepSpeed ZeRO Stage 2 | +| `zero3` | DeepSpeed ZeRO Stage 3 | +| `multi_gpu` | Multi-GPU training | +| `single_gpu` | Single-GPU training | + +To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`. + +#### Example Usage + + + + +```bash +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```bash +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + + +```bash +trl reward \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/ultrafeedback_binarized +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + + + + +```bash +trl grpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```bash +trl rloo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name HuggingFaceH4/Polaris-Dataset-53K \ + --reward_funcs accuracy_reward \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: HuggingFaceH4/Polaris-Dataset-53K +reward_funcs: + - accuracy_reward +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```bash +trl kto \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/kto-mix-14k \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + +or, with a config file: + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: trl-lib/kto-mix-14k +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + + + + +### Using dataset mixtures + +You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data. + + + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: stanfordnlp/imdb + - path: roneneldan/TinyStories +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: BAAI/Infinity-Preference + - path: argilla/Capybara-Preferences +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + + +```yaml +# reward_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: trl-lib/tldr-preference + - path: trl-lib/lm-human-preferences-sentiment +``` + +Launch with: + +```bash +trl reward --config reward_config.yaml +``` + + + + +```yaml +# grpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: HuggingFaceH4/Polaris-Dataset-53K + - path: trl-lib/DeepMath-103K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl grpo --config grpo_config.yaml +``` + + + + +```yaml +# rloo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: HuggingFaceH4/Polaris-Dataset-53K + - path: trl-lib/DeepMath-103K +reward_funcs: + - accuracy_reward +``` + +Launch with: + +```bash +trl rloo --config rloo_config.yaml +``` + + + + +```yaml +# kto_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +datasets: + - path: trl-lib/kto-mix-14k + - path: argilla/ultrafeedback-binarized-preferences-cleaned +``` + +Launch with: + +```bash +trl kto --config kto_config.yaml +``` + + + + +To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes. + +## Getting the System Information + +You can get the system information by running the following command: + +```bash +trl env +``` + +This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed. + +```txt +Copy-paste the following information when reporting an issue: + +- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31 +- Python version: 3.11.9 +- PyTorch version: 2.4.1 +- accelerator(s): NVIDIA H100 80GB HBM3 +- Transformers version: 4.45.0.dev0 +- Accelerate version: 0.34.2 +- Accelerate config: + - compute_environment: LOCAL_MACHINE + - distributed_type: DEEPSPEED + - mixed_precision: no + - use_cpu: False + - debug: False + - num_processes: 4 + - machine_rank: 0 + - num_machines: 1 + - rdzv_backend: static + - same_network: True + - main_training_function: main + - enable_cpu_affinity: False + - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2} + - downcast_bf16: no + - tpu_use_cluster: False + - tpu_use_sudo: False + - tpu_env: [] +- Datasets version: 3.0.0 +- HF Hub version: 0.24.7 +- TRL version: 0.12.0.dev0+acb4d70 +- bitsandbytes version: 0.41.1 +- DeepSpeed version: 0.15.1 +- Diffusers version: 0.30.3 +- Liger-Kernel version: 0.3.0 +- LLM-Blender version: 0.0.2 +- OpenAI version: 1.46.0 +- PEFT version: 0.12.0 +- vLLM version: not installed +``` + +This information is required when reporting an issue. diff --git a/ICL/RL/trl_source/docs/source/community_tutorials.md b/ICL/RL/trl_source/docs/source/community_tutorials.md new file mode 100644 index 0000000000000000000000000000000000000000..81eda22088afb593407f1b06e747c47b2404e85b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/community_tutorials.md @@ -0,0 +1,66 @@ +# Community Tutorials + +Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities. + +## Language Models + +### Tutorials + +| Task | Class | Description | Author | Tutorial | Colab | +| --- | --- | --- | --- | --- | --- | +| Reinforcement Learning | [`GRPOTrainer`] | Efficient Online Training with GRPO and vLLM in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/grpo_vllm_online_training.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | +| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) | +| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) | +| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) | +| Preference Optimization | [`experimental.orpo.ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | +| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) | +| Step-Level Reasoning | [`GRPOTrainer`] | Supervised Reinforcement Learning (SRL) for step-by-step reasoning with vLLM | [Deepak Swaminathan](https://huggingface.co/s23deepak) | [Link](https://github.com/s23deepak/Supervised-Reinforcement-Learning) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/s23deepak/Supervised-Reinforcement-Learning/blob/main/notebooks/srl_grpo_tutorial.ipynb) | + +### Videos + +| Task | Title | Author | Video | +| --- | --- | --- | --- | +| Instruction tuning | Fine-tuning open AI models using Hugging Face TRL | [Wietse Venema](https://huggingface.co/wietsevenema) | [](https://youtu.be/cnGyyM0vOes) | +| Instruction tuning | How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset | [Mayurji](https://huggingface.co/iammayur) | [](https://youtu.be/jKdXv3BiLu0) | + + +
+⚠️ Deprecated features notice for "How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset" (click to expand) + +> [!WARNING] +> The tutorial uses two deprecated features: +> +> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model). +> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy. + +
+ +## Vision Language Models + +### Tutorials + +| Task | Class | Description | Author | Tutorial | Colab | +| --- | --- | --- | --- | --- | --- | +| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) | +| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) | +| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) | +| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | +| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) | +| Object Detection Grounding | [`SFTTrainer`] | Fine tuning a VLM for Object Detection Grounding using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_object_detection_grounding) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_object_detection_grounding.ipynb) | +| Visual QA | [`DPOTrainer`] | Fine-Tuning a Vision Language Model with TRL using MPO | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_mpo.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | Post training a VLM for reasoning with GRPO using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_grpo_trl.ipynb) | + +## Speech Language Models + +### Tutorials + +| Task | Class | Description | Author | Tutorial | +| --- | --- | --- | --- | --- | +| Text-to-Speech | [`GRPOTrainer`] | Post training a Speech Language Model with GRPO using TRL | [Steven Zheng](https://huggingface.co/Steveeeeeeen) | [Link](https://huggingface.co/blog/Steveeeeeeen/llasa-grpo) | + +## Contributing + +If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community. diff --git a/ICL/RL/trl_source/docs/source/cpo_trainer.md b/ICL/RL/trl_source/docs/source/cpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..e1ff2a198a4212272cd5d3ce2b5c5c57edaaea34 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/cpo_trainer.md @@ -0,0 +1,126 @@ +# CPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl) + +## Overview + +Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat. + +CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective. + +## Quick start + +This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_cpo.py +from datasets import load_dataset +from trl.experimental.cpo import CPOConfig, CPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO") +trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_cpo.py +``` + +## Expected dataset type + +CPO requires a [preference dataset](dataset_formats#preference). The [`experimental.cpo.CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Example script + +We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) + +To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: + +```bash +accelerate launch examples/scripts/cpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_train_epochs 1 \ + --output_dir Qwen2-0.5B-CPO +``` + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta +* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta +* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards +* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses + +## CPO variants + +### Simple Preference Optimization (SimPO) + +[Simple Preference Optimization](https://huggingface.co/papers/2405.14734) (SimPO) by [Yu Meng](https://huggingface.co/yumeng5), [Mengzhou Xia](https://huggingface.co/mengzhouxia), and [Danqi Chen](https://huggingface.co/cdq10131) proposes a simpler and more effective preference optimization algorithm than DPO without using a reference model. The key designs in SimPO are (1) using length-normalized log likelihood as the implicit reward, and (2) incorporating a target reward margin in the Bradley-Terry ranking objective. The official code can be found at [princeton-nlp/SimPO](https://github.com/princeton-nlp/SimPO). + +The abstract from the paper is the following: + +> Direct Preference Optimization (DPO) is a widely used offline preference optimization algorithm that reparameterizes reward functions in reinforcement learning from human feedback (RLHF) to enhance simplicity and training stability. In this work, we propose SimPO, a simpler yet more effective approach. The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. This reward formulation better aligns with model generation and eliminates the need for a reference model, making it more compute and memory efficient. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. We compare SimPO to DPO and its latest variants across various state-of-the-art training setups, including both base and instruction-tuned models like Mistral and Llama3. We evaluated on extensive instruction-following benchmarks, including AlpacaEval 2, MT-Bench, and the recent challenging Arena-Hard benchmark. Our results demonstrate that SimPO consistently and significantly outperforms existing approaches without substantially increasing response length. Specifically, SimPO outperforms DPO by up to 6.4 points on AlpacaEval 2 and by up to 7.5 points on Arena-Hard. Our top-performing model, built on Llama3-8B-Instruct, achieves a remarkable 44.7 length-controlled win rate on AlpacaEval 2 -- surpassing Claude 3 Opus on the leaderboard, and a 33.8 win rate on Arena-Hard -- making it the strongest 8B open-source model. + +The SimPO loss is integrated in the [`experimental.cpo.CPOTrainer`], as it's an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, just turn on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`experimental.cpo.CPOConfig`] and set the `simpo_gamma` to a recommended value. + +### CPO-SimPO + +We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`experimental.cpo.CPOConfig`]. + +### AlphaPO + +The [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) (AlphaPO) method by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi is also implemented in the [`experimental.cpo.CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of SimPO loss. The abstract from the paper is the following: + +> Reinforcement Learning with Human Feedback (RLHF) and its variants have made huge strides toward the effective alignment of large language models (LLMs) to follow instructions and reflect human values. More recently, Direct Alignment Algorithms (DAAs) have emerged in which the reward modeling stage of RLHF is skipped by characterizing the reward directly as a function of the policy being learned. Some popular examples of DAAs include Direct Preference Optimization (DPO) and Simple Preference Optimization (SimPO). These methods often suffer from likelihood displacement, a phenomenon by which the probabilities of preferred responses are often reduced undesirably. In this paper, we argue that, for DAAs the reward (function) shape matters. We introduce AlphaPO, a new DAA method that leverages an α-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and overoptimization. Compared to SimPO, one of the best performing DAAs, AlphaPO leads to about 7% to 10% relative improvement in alignment performance for the instruct versions of Mistral-7B and Llama3-8B while achieving 15% to 50% relative improvement over DPO on the same models. The analysis and results presented highlight the importance of the reward shape and how one can systematically change it to affect training dynamics, as well as improve alignment performance. + +To use this loss as described in the paper, we can set the `loss_type="alphapo"` which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values in the [`experimental.cpo.CPOConfig`]. Alternatively, you can manually set `loss_type="simpo"`, `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value. + +## Loss functions + +The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`experimental.cpo.CPOConfig`]. The following loss functions are supported: + +| `loss_type=` | Description | +| --- | --- | +| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | +| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. | +| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). | +| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`experimental.cpo.CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`experimental.cpo.CPOConfig`] and `simpo_gamma` to a recommended value. | +| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`experimental.cpo.CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. | + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g., [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +## CPOTrainer + +[[autodoc]] experimental.cpo.CPOTrainer + - train + - save_model + - push_to_hub + +## CPOConfig + +[[autodoc]] experimental.cpo.CPOConfig diff --git a/ICL/RL/trl_source/docs/source/customization.md b/ICL/RL/trl_source/docs/source/customization.md new file mode 100644 index 0000000000000000000000000000000000000000..19ba1088fd11f2ebd8e5c4801dc4117297b22d6f --- /dev/null +++ b/ICL/RL/trl_source/docs/source/customization.md @@ -0,0 +1,142 @@ +# Training customization + +TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are examples on how you can apply and test different techniques. + +> [!NOTE] +> Although these examples use the [`DPOTrainer`], these customization methods apply to most (if not all) trainers in TRL. + +## Use different optimizers and schedulers + +By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows: + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import optim +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + optimizers=(optimizer, None), +) +trainer.train() +``` + +### Add a learning rate scheduler + +You can also add learning rate schedulers by passing both optimizer and scheduler: + +```python +from torch import optim + +optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate) +lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler)) +``` + +## Memory efficient fine-tuning by sharing layers + +Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train. + +```python +from trl import create_reference_model + +ref_model = create_reference_model(model, num_shared_layers=6) + +trainer = DPOTrainer(..., ref_model=ref_model) +``` + +## Pass 8-bit reference models + +Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. + +Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft). + +```python +from transformers import AutoModelForCausalLM, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) +ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config) + +trainer = DPOTrainer(..., ref_model=ref_model) +``` + +## Add custom callbacks + +You can customize the training loop by adding callbacks for logging, monitoring, or early stopping. Callbacks allow you to execute custom code at specific points during training. + +```python +from transformers import TrainerCallback + + +class CustomLoggingCallback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + if logs is not None: + print(f"Step {state.global_step}: {logs}") + + +trainer = DPOTrainer(..., callbacks=[CustomLoggingCallback()]) +``` + +## Add custom evaluation metrics + +You can define custom evaluation metrics to track during training. This is useful for monitoring model performance on specific tasks. + +```python +def compute_metrics(eval_preds): + logits, labels = eval_preds + # Add your metric computation here + return {"custom_metric": 0.0} + + +training_args = DPOConfig(..., eval_strategy="steps", eval_steps=100) + +trainer = DPOTrainer(..., eval_dataset=eval_dataset, compute_metrics=compute_metrics) +``` + +## Use mixed precision training + +Mixed precision training can significantly speed up training and reduce memory usage. You can enable it by setting `bf16=True` or `fp16=True` in the training config. + +```python +# Use bfloat16 precision (recommended for modern GPUs) +training_args = DPOConfig(..., bf16=True) +``` + +Note: Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. + +## Use gradient accumulation + +When training with limited GPU memory, gradient accumulation allows you to simulate larger batch sizes by accumulating gradients over multiple steps before updating weights. + +```python +# Simulate a batch size of 32 with per_device_train_batch_size=4 and gradient_accumulation_steps=8 +training_args = DPOConfig( + ..., + per_device_train_batch_size=4, + gradient_accumulation_steps=8, +) +``` + +## Use a custom data collator + +You can provide a custom data collator to handle special data preprocessing or padding strategies. + +```python +from trl.trainer.dpo_trainer import DataCollatorForPreference + +data_collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id) + +trainer = DPOTrainer(..., data_collator=data_collator) +``` diff --git a/ICL/RL/trl_source/docs/source/data_utils.md b/ICL/RL/trl_source/docs/source/data_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..b6a18efb343e53e67e8f3e6f99bc8d38e79d113d --- /dev/null +++ b/ICL/RL/trl_source/docs/source/data_utils.md @@ -0,0 +1,53 @@ +# Data Utilities + +## prepare_multimodal_messages + +[[autodoc]] prepare_multimodal_messages + +## prepare_multimodal_messages_vllm + +[[autodoc]] prepare_multimodal_messages_vllm + +## is_conversational + +[[autodoc]] is_conversational + +## is_conversational_from_value + +[[autodoc]] is_conversational_from_value + +## apply_chat_template + +[[autodoc]] apply_chat_template + +## maybe_apply_chat_template + +[[autodoc]] maybe_apply_chat_template + +## maybe_convert_to_chatml + +[[autodoc]] maybe_convert_to_chatml + +## extract_prompt + +[[autodoc]] extract_prompt + +## maybe_extract_prompt + +[[autodoc]] maybe_extract_prompt + +## unpair_preference_dataset + +[[autodoc]] unpair_preference_dataset + +## maybe_unpair_preference_dataset + +[[autodoc]] maybe_unpair_preference_dataset + +## pack_dataset + +[[autodoc]] pack_dataset + +## truncate_dataset + +[[autodoc]] truncate_dataset diff --git a/ICL/RL/trl_source/docs/source/dataset_formats.md b/ICL/RL/trl_source/docs/source/dataset_formats.md new file mode 100644 index 0000000000000000000000000000000000000000..1ecad2038c2efcaa34624be4edafee13a9178198 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/dataset_formats.md @@ -0,0 +1,993 @@ +# Dataset formats and types + +This guide provides an overview of the dataset formats and types supported by each trainer in TRL. + +## Overview of the dataset formats and types + +- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*. +- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Type \ FormatStandardConversational
Language modeling +
{"text": "The sky is blue."}
+
+
{"messages": [{"role": "user", "content": "What color is the sky?"},
+              {"role": "assistant", "content": "It is blue."}]}
+
Prompt-only +
{"prompt": "The sky is"}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}]}
+
Prompt-completion +
{"prompt": "The sky is",
+ "completion": " blue."}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "completion": [{"role": "assistant", "content": "It is blue."}]}
+
Preference +
{"prompt": "The sky is",
+ "chosen": " blue.",
+ "rejected": " green."}
+ or, with implicit prompt: +
{"chosen": "The sky is blue.",
+ "rejected": "The sky is green."}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "assistant", "content": "It is green."}]}
+ or, with implicit prompt: +
{"chosen": [{"role": "user", "content": "What color is the sky?"},
+              {"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is green."}]}
+
Unpaired preference +
{"prompt": "The sky is",
+ "completion": " blue.",
+ "label": True}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "completion": [{"role": "assistant", "content": "It is green."}],
+ "label": False}
+
Stepwise supervision +
{"prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["The fractional part of 9.8 is 0.8.",
+                 "The fractional part of 9.11 is 0.11.",
+                 "0.11 is greater than 0.8.",
+                 "Hence, 9.11 > 9.8."],
+ "labels": [True, True, False, False]}
+
+ +### Formats + +#### Standard + +The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks: + +```python +# Language modeling +language_modeling_example = {"text": "The sky is blue."} +# Preference +preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} +# Unpaired preference +unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True} +``` + +#### Conversational + +Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text). + +```python +messages = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, + {"role": "user", "content": "I'd like to show off how chat templating works!"}, +] +``` + +Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks: + +```python +# Prompt-completion +prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}]} +# Preference +preference_example = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], +} +``` + +#### Tool Calling + +Some chat templates support *tool calling*, which allows the model to interact with external functions—referred to as **tools**—during generation. This extends the conversational capabilities of the model by enabling it to output a `"tool_calls"` field instead of a standard `"content"` message whenever it decides to invoke a tool. + +After the assistant initiates a tool call, the tool executes and returns its output. The assistant can then process this output and continue the conversation accordingly. + +Here’s a simple example of a tool-calling interaction: + +```python +messages = [ + {"role": "user", "content": "Turn on the living room lights."}, + {"role": "assistant", "tool_calls": [ + {"type": "function", "function": { + "name": "control_light", + "arguments": {"room": "living room", "state": "on"} + }}] + }, + {"role": "tool", "name": "control_light", "content": "The lights in the living room are now on."}, + {"role": "assistant", "content": "Done!"} +] +``` + +When preparing datasets for Supervised Fine-Tuning (SFT) with tool calling, it is important that your dataset includes an additional column named `tools`. This column contains the list of available tools for the model, which is usually used by the chat template to construct the system prompt. + +The tools must be specified in a codified JSON schema format. You can automatically generate this schema from Python function signatures using the [`~transformers.utils.get_json_schema`] utility: + +```python +from transformers.utils import get_json_schema + +def control_light(room: str, state: str) -> str: + """ + Controls the lights in a room. + + Args: + room: The name of the room. + state: The desired state of the light ("on" or "off"). + + Returns: + str: A message indicating the new state of the lights. + """ + return f"The lights in {room} are now {state}." + +# Generate JSON schema +json_schema = get_json_schema(control_light) +``` + +The generated schema would look like: + +```python +{ + "type": "function", + "function": { + "name": "control_light", + "description": "Controls the lights in a room.", + "parameters": { + "type": "object", + "properties": { + "room": {"type": "string", "description": "The name of the room."}, + "state": {"type": "string", "description": 'The desired state of the light ("on" or "off").'}, + }, + "required": ["room", "state"], + }, + "return": {"type": "string", "description": "str: A message indicating the new state of the lights."}, + }, +} +``` + +A complete dataset entry for SFT might look like: + +```python +{"messages": messages, "tools": [json_schema]} +``` + +For more detailed information on tool calling, refer to the [Tool Calling section in the `transformers` documentation](https://huggingface.co/docs/transformers/chat_extras#tools-and-rag) and the blog post [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use). + +### Harmony + +The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the model’s behavior. Key features include: + +- **Developer role** – Provides high level instructions (similar to a system prompt) and lists available tools. +- **Channels** – Separate types of assistant output into distinct streams: + + - `analysis` – for internal reasoning, from the key `"thinking"` + - `final` – for the user-facing answer, from the key `"content"` + - `commentary` – for tool calls or meta notes + +- **Reasoning effort** – Signals how much thinking the model should show (e.g., `"low"`, `"medium"`, `"high"`). +- **Model identity** – Explicitly defines the assistant’s persona. + +```python +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b") + +messages = [ + {"role": "developer", "content": "Use a friendly tone."}, + {"role": "user", "content": "What is the meaning of life?"}, + {"role": "assistant", "thinking": "Deep reflection...", "content": "The final answer is..."}, +] + +print( + tokenizer.apply_chat_template( + messages, + tokenize=False, + reasoning_effort="low", + model_identity="You are HuggingGPT, a large language model trained by Hugging Face.", + ) +) +``` + +This produces: + +```txt +<|start|>system<|message|>You are HuggingGPT, a large language model trained by Hugging Face. +Knowledge cutoff: 2024-06 +Current date: 2025-08-03 + +Reasoning: low + +# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions + +Use a friendly tone.<|end|><|start|>user<|message|>What is the meaning of life?<|end|><|start|>assistant<|channel|>analysis<|message|>Deep reflection...<|end|><|start|>assistant<|channel|>final<|message|>The final answer is...<|return|> +``` + +For full details on message structure, supported fields, and advanced usage, see the [Harmony documentation](https://cookbook.openai.com/articles/openai-harmony). + +### Types + +#### Language modeling + +A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text. + +```python +# Standard format +language_modeling_example = {"text": "The sky is blue."} +# Conversational format +language_modeling_example = {"messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."} +]} +``` + +#### Prompt-only + +In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating completion based on this prompt, where the model learns to continue or complete the given input. + +```python +# Standard format +prompt_only_example = {"prompt": "The sky is"} +# Conversational format +prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]} +``` + +For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368). + +> [!TIP] +> While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type: +> +> ```python +> from transformers import AutoTokenizer +> from trl import apply_chat_template +> +> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct") +> +> # Example for prompt-only type +> prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]} +> apply_chat_template(prompt_only_example, tokenizer) +> # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'} +> +> # Example for language modeling type +> lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]} +> apply_chat_template(lm_example, tokenizer) +> # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'} +> ``` +> +> - The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion. +> - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content. + +#### Prompt-completion + +A prompt-completion dataset includes a `"prompt"` and a `"completion"`. + +```python +# Standard format +prompt_completion_example = {"prompt": "The sky is", "completion": " blue."} +# Conversational format +prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}]} +``` + +For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216). + +#### Preference + +A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response. +Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible. + +```python +# Standard format +## Explicit prompt (recommended) +preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} +# Implicit prompt +preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."} + +# Conversational format +## Explicit prompt (recommended) +preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}]} +## Implicit prompt +preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}]} +``` + +For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c). + +Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets. + +#### Unpaired preference + +An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not. + +```python +# Standard format +unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True} +# Conversational format +unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + "label": True} +``` + +For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf). + +#### Stepwise supervision + +A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process. + +```python +stepwise_example = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."], + "labels": [True, False] +} +``` + +For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e). + +## Which dataset type to use? + +Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer. + +| Trainer | Expected dataset type | +| --- | --- | +| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`GRPOTrainer`] | [Prompt-only](#prompt-only) | +| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | +| [`RLOOTrainer`] | [Prompt-only](#prompt-only) | +| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) | +| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`experimental.kto.KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`experimental.online_dpo.OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | +| [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling | +| [`experimental.prm.PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | +| [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) | + +## Using any dataset with TRL: preprocessing and conversion + +Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format. + +To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions. + +### Example: UltraFeedback dataset + +Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset: + + + +As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty"). + +By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub: + +```sh +python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness +``` + +Once converted, the dataset will look like this: + + + +Now, you can use this dataset with TRL! + +By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL. + +## Utilities for converting dataset types + +This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently. + +For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification. + +| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision | +| --- | --- | --- | --- | --- | --- | --- | --- | +| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A | +| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A | +| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A | +| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A | +| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A | +| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A | +| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A | + +### From prompt-completion to language modeling dataset + +To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "completion": [" blue.", " in the sky."], +}) + +def concat_prompt_completion(example): + return {"text": example["prompt"] + example["completion"]} + +dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"]) +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From prompt-completion to prompt-only dataset + +To convert a prompt-completion dataset into a prompt-only dataset, remove the completion. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "completion": [" blue.", " in the sky."], +}) + +dataset = dataset.remove_columns("completion") +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is'} +``` + +### From preference with implicit prompt to language modeling dataset + +To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "chosen": ["The sky is blue.", "The sun is in the sky."], + "rejected": ["The sky is green.", "The sun is in the sea."], +}) + +dataset = dataset.rename_column("chosen", "text").remove_columns("rejected") +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From preference with implicit prompt to prompt-completion dataset + +To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`. + +```python +from datasets import Dataset +from trl import extract_prompt + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) +dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion") +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]} +``` + +### From preference with implicit prompt to prompt-only dataset + +To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen. + +```python +from datasets import Dataset +from trl import extract_prompt + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) +dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"]) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]} +``` + +### From implicit to explicit prompt preference dataset + +To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`]. + +```python +from datasets import Dataset +from trl import extract_prompt + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) + +dataset = dataset.map(extract_prompt) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}], + 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]} +``` + +### From preference with implicit prompt to unpaired preference dataset + +To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`]. + +```python +from datasets import Dataset +from trl import extract_prompt, unpair_preference_dataset + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) + +dataset = dataset.map(extract_prompt) +dataset = unpair_preference_dataset(dataset) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'completion': [{'role': 'assistant', 'content': 'It is blue.'}], + 'label': True} +``` + +> [!WARNING] +> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad. +> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad. +> This can be ensured by checking absolute rating of each completion, e.g. from a reward model. + +### From preference to language modeling dataset + +To convert a preference dataset into a language modeling dataset, remove the rejected, concatenate the prompt and the chosen into the `"text"` column. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], +}) + +def concat_prompt_chosen(example): + return {"text": example["prompt"] + example["chosen"]} + +dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"]) +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From preference to prompt-completion dataset + +To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], +}) + +dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion") +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is', 'completion': ' blue.'} +``` + +### From preference to prompt-only dataset + +To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], +}) + +dataset = dataset.remove_columns(["chosen", "rejected"]) +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is'} +``` + +### From explicit to implicit prompt preference dataset + +To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What color is the sky?"}], + [{"role": "user", "content": "Where is the sun?"}], + ], + "chosen": [ + [{"role": "assistant", "content": "It is blue."}], + [{"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "assistant", "content": "It is green."}], + [{"role": "assistant", "content": "In the sea."}], + ], +}) + +def concat_prompt_to_completions(example): + return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]} + +dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt") +``` + +```python +>>> dataset[0] +{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}], + 'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]} +``` + +### From preference to unpaired preference dataset + +To convert dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`]. + +```python +from datasets import Dataset +from trl import unpair_preference_dataset + +dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What color is the sky?"}], + [{"role": "user", "content": "Where is the sun?"}], + ], + "chosen": [ + [{"role": "assistant", "content": "It is blue."}], + [{"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "assistant", "content": "It is green."}], + [{"role": "assistant", "content": "In the sea."}], + ], +}) + +dataset = unpair_preference_dataset(dataset) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'completion': [{'role': 'assistant', 'content': 'It is blue.'}], + 'label': True} +``` + +> [!WARNING] +> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad. +> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad. +> This can be ensured by checking absolute rating of each completion, e.g. from a reward model. + +### From unpaired preference to language modeling dataset + +To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], +}) + +def concatenate_prompt_completion(example): + return {"text": example["prompt"] + example["completion"]} + +dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"]) +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From unpaired preference to prompt-completion dataset + +To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label columns. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], +}) + +dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"]) +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is', 'completion': ' blue.'} +``` + +### From unpaired preference to prompt-only dataset + +To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], +}) + +dataset = dataset.remove_columns(["completion", "label"]) +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is'} +``` + +### From stepwise supervision to language modeling dataset + +To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +def concatenate_prompt_completions(example): + completion = "".join(example["completions"]) + return {"text": example["prompt"] + completion} + +dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"]) +``` + +```python +>>> dataset[0] +{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'} +``` + +### From stepwise supervision to prompt-completion dataset + +To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +def join_completions(example): + completion = "".join(example["completions"]) + return {"completion": completion} + +dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"]) +``` + +```python +>>> dataset[0] +{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'} +``` + +### From stepwise supervision to prompt-only dataset + +To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +dataset = dataset.remove_columns(["completions", "labels"]) +``` + +```python +>>> dataset[0] +{'prompt': 'Blue light'} +``` + +### From stepwise supervision to unpaired preference dataset + +To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels. + +The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +def merge_completions_and_labels(example): + return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])} + +dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"]) +``` + +```python +>>> dataset[0] +{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False} +``` + +## Vision datasets + +Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently. + +A conversational vision dataset differs from a standard conversational dataset in two key ways: + +1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image. +2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`. + +Example: + +```python +# Textual dataset: +"content": "What color is the sky?" + +# Vision dataset: +"content": [ + {"type": "image"}, + {"type": "text", "text": "What color is the sky in the image?"} +] +``` + +An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly: + + + +> [!NOTE] +> Mixing text-only and vision-language data in the dataset is possible, but it requires `transformers` version 4.57.0 or later. Example: +> +> ```python +> dataset = Dataset.from_dict({ +> "prompt": [ +> [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky in the image?"}]}], +> [{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}], +> ], +> "completion": [ +> [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}], +> [{"role": "assistant", "content": [{"type": "text", "text": "Paris."}]}], +> ], +> "images": [ +> [PIL.Image.open("path/to/sky_image1.png")], +> [], +> ], +> }) +> ``` diff --git a/ICL/RL/trl_source/docs/source/deepspeed_integration.md b/ICL/RL/trl_source/docs/source/deepspeed_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..a605787972e3578949f820364be8e3d951657637 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/deepspeed_integration.md @@ -0,0 +1,36 @@ +# DeepSpeed Integration + +> [!WARNING] +> Section under construction. Feel free to contribute! + +TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more. + +DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency. + +![ZeRO Stages](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/zero_stages.png) + +## Installation + +To use DeepSpeed with TRL, install it using the following command: + +```bash +pip install deepspeed +``` + +## Running Training Scripts with DeepSpeed + +No modifications to your training script are required. Simply run it with the DeepSpeed configuration file: + +```bash +accelerate launch --config_file train.py +``` + +We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command: + +```bash +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py +``` + +## Additional Resources + +Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin. diff --git a/ICL/RL/trl_source/docs/source/distributing_training.md b/ICL/RL/trl_source/docs/source/distributing_training.md new file mode 100644 index 0000000000000000000000000000000000000000..34bf7d165c76a0134ebafd2d7b8c092643fbbafd --- /dev/null +++ b/ICL/RL/trl_source/docs/source/distributing_training.md @@ -0,0 +1,445 @@ +# Distributing Training + +> [!WARNING] +> Section under construction. Feel free to contribute! + +## Multi-GPU Training with TRL + +The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running + +```bash +accelerate config +``` + +and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running: + +```bash +accelerate launch train.py +``` + +We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.: + +```shell +accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py +``` + +This automatically distributes the workload across all available GPUs. + +Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process: + +- Processes its own batch of data +- Computes the loss and gradients for that batch +- Shares gradient updates across all GPUs + +![multi gpu](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/multi_gpu.png) + +The effective batch size is calculated as: + +$$ +\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps} +$$ + +To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly. + +Example, these configurations are equivalent, and should yield the same results: + +| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments | +| --- | --- | --- | --- | +| 1 | 32 | 1 | Possibly high memory usage, but faster training | +| 1 | 4 | 8 | Lower memory usage, slower training | +| 8 | 4 | 1 | Multi-GPU to get the best of both worlds | + +> [!TIP] +> Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details. + +## Sequence Parallelism for Long Context Training + +Sequence Parallelism (also called Context Parallelism) is a parallelization technique that enables training with longer sequences by splitting the sequence dimension across multiple GPUs. Each GPU processes a portion of the sequence, allowing you to train with sequences longer than what would fit on a single GPU's memory. + +> [!NOTE] +> **Terminology clarification:** This section describes parallelism techniques for splitting sequences to enable longer context training: +> - **Context Parallelism (CP)**: Splits sequences across GPUs (implemented as Ring Attention with FSDP2) +> - **Sequence Parallelism (SP)**: Another form of sequence splitting (implemented as ALST/Ulysses with DeepSpeed) +> +> Both CP and SP are different from traditional Sequence Parallelism used with Tensor Parallelism (TP+SP) to reduce activation memory. With the techniques here, parallelism dimensions multiply: `TP=2` and `CP=2` would require 4 GPUs (2×2), whereas traditional `TP+SP=2` only needs 2 GPUs as they share the same ranks. +> +> In Accelerate's `ParallelismConfig`: +> - Use `cp_size` with `cp_backend="torch"` for Ring Attention (FSDP2) +> - Use `sp_size` with `sp_backend="deepspeed"` for ALST/Ulysses (DeepSpeed) + +Sequence parallelism is particularly useful when: + +- You want to train with very long sequences (>32k tokens) +- Single GPU memory is insufficient for your desired sequence length +- You need to maintain sequence coherence across the full context + +### Available Implementations + +TRL supports two sequence parallelism implementations, each with different characteristics: + +1. **Ring Attention (FSDP2)** - Uses ring-based communication for memory-efficient processing of extremely long sequences +2. **ALST/Ulysses (DeepSpeed)** - Uses attention head parallelism for faster training with high-bandwidth interconnects + +> [!IMPORTANT] +> **Sequence Length Terminology:** When using Context Parallelism, the sequence is split across GPUs, introducing two concepts: +> - **Global sequence length**: The full sequence length before splitting across GPUs +> - **Micro sequence length**: The sequence length per GPU after splitting +> +> In TRL, `max_seq_length` (or `max_length`) refers to the **global sequence length**. The framework automatically handles splitting into micro sequences: +> - **Ring Attention (FSDP2)**: Uses `cp_size` to split sequences. With `max_seq_length=8192` and `cp_size=4`, each GPU processes 2048 tokens. +> - **ALST/Ulysses (DeepSpeed)**: Uses `sp_size` (with `sp_backend="deepspeed"`) to split sequences. With `max_seq_length=8192` and `sp_size=2`, each GPU processes 4096 tokens. +> +> The Trainer automatically accounts for context parallelism when calculating batch sizes and training metrics. + +### Choosing Between Ring Attention and Ulysses + +The comparison table below highlights the key differences between the two approaches: + +| Feature | Ring Attention (FSDP2) | ALST/Ulysses (DeepSpeed) | +|---------|----------|-------------------------| +| **Method** | Ring Self-Attention | Attention Head Parallelism | +| **Backend** | PyTorch FSDP2 | DeepSpeed ZeRO | +| **Attention** | SDPA only | Flash Attention 2 or SDPA | +| **Minimum Accelerate** | 1.11.0+ | 1.12.0+ | +| **Minimum DeepSpeed** | N/A | 0.18.1+ | +| **Sequence Divisibility** | `cp_size * 2` | `sp_size` | +| **Zero Stage** | N/A | ZeRO Stage 1/2/3 | + +**Ring Attention is better when:** +- You need to handle extremely long sequences (1M+ tokens) +- The model has limited attention heads (Ring Attention is not constrained by head count) +- You want flexibility in scaling to any sequence length +- Network topology is limited (Ring Attention works with simple P2P ring communication) + +**Ulysses is better when:** +- You have high-bandwidth, low-latency interconnects (NVLink, InfiniBand) +- The model has many attention heads that can be split across GPUs +- You want lower communication volume +- You want faster training speed for moderate sequence lengths (up to ~500k tokens) + +**Key Trade-offs:** +- **Communication Volume:** Ulysses has lower communication volume, making it more efficient with good interconnects. Ring Attention has higher communication volume but is more flexible with different network topologies. +- **Attention Head Constraints:** Ulysses is limited by the number of attention heads (requires `num_heads >= sp_size`). Ring Attention scales with sequence length regardless of model architecture. +- **Network Sensitivity:** Ulysses all-to-all communication is sensitive to network latency. Ring Attention uses P2P ring communication which is more tolerant of varying network conditions. + +For a detailed comparison, see the [Ulysses and Ring Attention blog post](https://huggingface.co/blog/exploding-gradients/ulysses-ring-attention). + +### Ring Attention Implementation (FSDP2) + +Ring Attention uses a ring-like communication pattern where each GPU processes a portion of the sequence and passes information to the next GPU in the ring. + +#### Requirements and Limitations + +1. **Accelerate 1.11.0 or higher** is required for Ring Attention / Context Parallelism support +2. **FSDP2 (PyTorch FSDP v2)** is required as the distributed training backend +3. **SDPA attention** - Flash Attention is currently not supported +4. **Sequence length divisibility** - sequences must be divisible by `cp_size * 2`. This is automatically handled using the `pad_to_multiple_of` parameter in the data collator. + +#### Configuration + +##### Accelerate Configuration + +Use one of the provided accelerate config files (e.g. [`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml) for 2 GPUs): + +```yaml +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 2 # Number of GPUs +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +parallelism_config: + parallelism_config_dp_replicate_size: 1 + parallelism_config_dp_shard_size: 1 + parallelism_config_tp_size: 1 + parallelism_config_cp_size: 2 # Context parallel size +``` + +##### Training Configuration + +```python +from trl import SFTConfig + +training_args = SFTConfig( + # required + pad_to_multiple_of=4, # ensures divisibility by cp_size * 2 + # to get the most out of CP + max_length=16384, # long sequence length + packing=True, # use packing to reduce padding + use_liger_kernel=True, # compatible with CP + gradient_checkpointing=False, # The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg can't be set to True simultaneously + per_device_train_batch_size=1, + ... +) +``` + +Then, launch your training script with the appropriate accelerate config file: + +```bash +accelerate launch --config_file context_parallel_2gpu.yaml train.py +``` + +#### Best Practices + +1. **Use the `pad_to_multiple_of` parameter** - This is now the recommended way to ensure sequence length divisibility: + - For `cp_size=2`: use `pad_to_multiple_of=4` (since `cp_size * 2 = 4`) + - For `cp_size=4`: use `pad_to_multiple_of=8` (since `cp_size * 2 = 8`) + - The data collator automatically pads sequences to the required multiple, ensuring compatibility with CP + +2. **Use packing with padding** - The default BFD (Best Fit Decreasing) strategy works perfectly: + - Preserves sequence boundaries and maintains training quality + - Works seamlessly with both `padding_free=True` and standard padding modes + +3. **Combine with other memory optimizations** like Liger kernels, bfloat16, and gradient checkpointing + +4. **Start with smaller context parallel sizes** (2-4 GPUs) before scaling up + +5. **Monitor memory usage** across all GPUs to ensure balanced workload + +#### Benchmarking Ring Attention + +We benchmarked Ring Attention to highlight its potential improvements in training efficiency. +Our experiments were conducted using **1, 2, 4, and 8 H100 GPUs**, though the results can be extended to larger clusters with more nodes and GPUs. + +For the setup, we fine-tuned an **8B model** ([Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B)) using the provided accelerate configuration +([`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml)). +We adjusted `num_processes` and `parallelism_config_cp_size` based on the number of GPUs for each run. +Training was performed with the [sft.py](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) example script, combined with the parameters described above. + +The results below summarize the **maximum trainable sequence length** and **iterations per second** for different numbers of GPUs. A value marked as `OOM` indicates that the configuration ran out of memory and could not be trained. + +These results show that **Context Parallelism (CP) scales effectively with more GPUs**, enabling training on much longer sequences. With **8 GPUs**, context lengths of over **300k tokens** become feasible, unlocking training with extremely long contexts while maintaining reasonable throughput. + +
+ CP Max content length + CP seconds/iteration +
+ +> [!TIP] +> Accelerate also supports **N-Dimensional Parallelism (ND-parallelism)**, which enables you to combine different parallelization strategies to efficiently distribute model training across multiple GPUs. +> +> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism). + +### ALST/Ulysses Implementation (DeepSpeed) + +ALST (Arctic Long Sequence Training) / Ulysses uses attention head parallelism to split long sequences across GPUs, working with DeepSpeed's ZeRO optimizer. + +> [!NOTE] +> **Technical Note on Parallelism Configuration:** +> - **DeepSpeed ALST/Ulysses** uses `sp_size` with `sp_backend="deepspeed"` in both YAML and Python API +> - **Ring Attention (FSDP2)** uses `cp_size` with `cp_backend="torch"` +> +> The Trainer automatically accounts for both CP and SP when calculating effective batch sizes and training metrics. + +#### Requirements and Limitations + +1. **DeepSpeed 0.18.1 or higher** is required +2. **Accelerate 1.12.0 or higher** is required for ALST/Ulysses sequence parallelism support +3. **Attention implementation** - Flash Attention 2 recommended (clean output), SDPA works as fallback +4. **Sequence length divisibility** - sequences must be divisible by `sp_size`. Use `pad_to_multiple_of` in your training config. +5. **Parallelism configuration** - You must ensure `dp_replicate_size × dp_shard_size × sp_size = num_processes` + +#### Configuration + +##### Accelerate Configuration + +Use the provided accelerate config file ([`alst_ulysses_4gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/alst_ulysses_4gpu.yaml)): + +```yaml +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + zero_stage: 3 + seq_parallel_communication_data_type: bf16 +distributed_type: DEEPSPEED +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 # Number of GPUs +parallelism_config: + parallelism_config_dp_replicate_size: 1 + parallelism_config_dp_shard_size: 2 # Enables 2D parallelism with SP + parallelism_config_tp_size: 1 + parallelism_config_sp_size: 2 # Sequence parallel size + parallelism_config_sp_backend: deepspeed + parallelism_config_sp_seq_length_is_variable: true + parallelism_config_sp_attn_implementation: flash_attention_2 +``` + +##### Training Configuration + +```python +from trl import SFTConfig + +training_args = SFTConfig( + # required + pad_to_multiple_of=2, # Must equal sp_size + # to get the most out of SP + max_seq_length=4096, + packing=True, + attn_implementation="flash_attention_2", + per_device_train_batch_size=1, + ... +) +``` + +Then, launch your training script with the appropriate accelerate config file: + +```bash +accelerate launch --config_file examples/accelerate_configs/alst_ulysses_4gpu.yaml train.py +``` + +#### 2D Parallelism + +The 4 GPU configuration above automatically enables 2D parallelism by combining Data Parallelism (DP) with Sequence Parallelism (SP). With `sp_size=2` and `dp_shard_size=2`, the 4 GPUs are organized as: +- 2 sequence parallel groups (processing the same data split across sequences) +- 2 data parallel groups (processing different data) + +To adjust the parallelism for different GPU counts, modify the YAML config: + +| GPUs | sp_size | dp_shard_size | Use Case | YAML Changes | +|------|---------|---------------|----------|--------------| +| 4 | 2 | 2 | Balanced - longer sequences + more data | `num_processes: 4`, `sp_size: 2`, `dp_shard_size: 2` | +| 4 | 4 | 1 | Pure SP for maximum sequence length | `num_processes: 4`, `sp_size: 4`, `dp_shard_size: 1` | +| 8 | 2 | 4 | Large-scale training | `num_processes: 8`, `sp_size: 2`, `dp_shard_size: 4` | + +#### Best Practices + +1. **Use `pad_to_multiple_of`** to ensure sequences are divisible by `sp_size` +2. **Use Flash Attention 2** for clean output (SDPA works but shows packing warnings) +3. **Start with `sp_size=2`** before scaling to larger values +4. **Use DeepSpeed ZeRO Stage 3** for large models +5. **Combine with memory optimizations** like Liger kernels and gradient checkpointing +6. **Validate parallelism config**: Ensure `dp_replicate_size × dp_shard_size × sp_size = num_processes` + +#### Complete Example + +Here's how to run ALST/Ulysses training using the built-in [`sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) script with 4 GPUs: + +```bash +accelerate launch --config_file examples/accelerate_configs/alst_ulysses_4gpu.yaml \ + trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2e-4 \ + --max_steps 100 \ + --max_seq_length 4096 \ + --packing \ + --packing_strategy wrapped \ + --torch_dtype bfloat16 \ + --attn_implementation flash_attention_2 \ + --output_dir output-alst-4gpu \ + --logging_steps 10 \ + --report_to trackio +``` + +This command automatically: +- Configures 2D parallelism (SP=2, DP=2) across 4 GPUs +- Uses Flash Attention 2 for clean training +- Enables packing with automatic padding to ensure sequence divisibility +- Leverages DeepSpeed ZeRO Stage 3 for memory efficiency + +### Further Reading + +#### General Resources +- [Hugging Face Blog: Understanding Ulysses and Ring Attention](https://huggingface.co/blog/exploding-gradients/ulysses-ring-attention) - Detailed comparison of Ring Attention vs Ulysses approaches +- [Accelerate: Context Parallelism Guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallelism) +- [Hugging Face Blog: Enabling Long-Context Training with Sequence Parallelism in Axolotl](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) + +#### Ring Attention (FSDP2) +- [Ultrascale Playbook - Context Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) +- [Accelerate Example: 128k Sequence Length](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#context-parallelism-128k-sequence-length) +- [Accelerate ND-parallelism Guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism) + +#### ALST/Ulysses (DeepSpeed) +- [DeepSpeed Sequence Parallelism Documentation](https://www.deepspeed.ai/tutorials/ds-sequence/) +- [Snowflake Engineering Blog: Arctic Long Sequence Training (ALST)](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/) + +## Multi-Node Training + +When a single machine doesn't have enough GPUs, TRL can scale training across multiple machines (nodes) using [🤗 Accelerate](https://huggingface.co/docs/accelerate/basic_tutorials/launch#multi-node-training). + +### Accelerate Configuration +Create an `accelerate` config file (e.g., `multi_node.yaml`) for multi-node training. Key fields: + +```yaml +compute_environment: LOCAL_MACHINE +distributed_type: MULTI_GPU +num_machines: 2 +machine_rank: 0 # 0 for main node, 1 for second node +main_process_ip: 10.0.0.1 # IP of rank 0 node +main_process_port: 29500 +num_processes: 16 # total processes across nodes +mixed_precision: bf16 +use_cpu: false +same_network: true +``` + +Adjust `num_processes` to match the total number of GPUs across all nodes. + +> [!NOTE] +> Replace `10.0.0.1` with the actual IP address of the rank 0 (main) node. + +### Launching + +#### Option 1: Manual Launch (Non-HPC) + +Run the following on each node manually: +```bash +# Node 0 (main node) +accelerate launch --config_file multi_node.yaml --machine_rank 0 train.py + +# Node 1 +accelerate launch --config_file multi_node.yaml --machine_rank 1 train.py +``` +#### Option 2: SLURM Launch (HPC Clusters) + +For clusters using SLURM job scheduler, create a job script (e.g., `slurm_job.sh`): +```bash +#!/bin/bash +#SBATCH --nodes=2 +#SBATCH --gpus-per-node=8 +#SBATCH --job-name=trl_multi + +srun accelerate launch --config_file multi_node.yaml train.py +``` + +Then submit the job: +```bash +sbatch slurm_job.sh +``` + +SLURM automatically distributes the training across all requested nodes and GPUs, and `srun` configures the necessary environment variables for multi-node communication. + +**Key SLURM directives:** +- `--nodes=2`: Request 2 compute nodes +- `--gpus-per-node=8`: Allocate 8 GPUs per node (16 total) +- `--job-name`: Label for tracking in the job queue + +You can combine multi-node with DeepSpeed by setting `distributed_type: DEEPSPEED` and adding a `deepspeed_config` block. See the [DeepSpeed integration guide](https://huggingface.co/docs/trl/en/deepspeed_integration). + +### Further Reading + +- [Accelerate: Launching Scripts](https://huggingface.co/docs/accelerate/basic_tutorials/launch) +- [Accelerate: Example Zoo](https://huggingface.co/docs/accelerate/usage_guides/training_zoo) +- [SLURM Workload Manager Documentation](https://slurm.schedmd.com/) - For cluster job scheduling + + + diff --git a/ICL/RL/trl_source/docs/source/dpo_trainer.md b/ICL/RL/trl_source/docs/source/dpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..2d618c7a96b2bf71f7f4888e32e6640acee81d9b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/dpo_trainer.md @@ -0,0 +1,304 @@ +# DPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-DPO-blue)](https://huggingface.co/models?other=dpo,trl) [![model badge](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment) + +## Overview + +TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn). + +The abstract from the paper is the following: + +> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train. + +The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm. + +Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer): + +1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt. +2. **Optimization**: Maximize the log-likelihood of the DPO loss directly. + +This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)): + +![Figure 1 DPO](https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d) + +Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290). + +## Quick start + +This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_dpo.py +from datasets import load_dataset +from trl import DPOConfig, DPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO") +trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_dpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/dpo-qwen2-reward-margin.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-DPO
+<shirin_yamani>:
+What is Huggingface?
+
+<trl-lib/Qwen2-0.5B-DPO>:
+Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in  Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000  pre-trained models in a variety of languages, enabling users to explore and utilize the latest techniques and technologies in the field of machine learning.
+
+ +## Expected dataset type + +DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. + +### Special considerations for vision-language models + +The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section. + +Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`. + +```diff +- model = AutoModelForCausalLM.from_pretrained(model_id) ++ model = AutoModelForImageTextToText.from_pretrained(model_id) + +- tokenizer = AutoTokenizer.from_pretrained(model_id) ++ processor = AutoProcessor.from_pretrained(model_id) + + trainer = DPOTrainer( + model, + args=training_args, + train_dataset=train_dataset, +- processing_class=tokenizer, ++ processing_class=processor, +) +``` + +For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py). + +## Example script + +We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) + +To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: + +```bash +accelerate launch trl/scripts/dpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_train_epochs 1 \ + --output_dir Qwen2-0.5B-DPO +``` + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta +- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta +- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards + +## Loss functions + +The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported: + +| `loss_type=` | Description | +| --- | --- | +| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | +| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. | +| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). | +| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. | +| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. | +| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) | +| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`experimental.bco.BCOTrainer`]. | +| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. | +| `"aot"` or `loss_type="aot_unpaired"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_unpaired"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_unpaired"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. | +| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. | +| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). | +| `"sft"` | SFT (Supervised Fine-Tuning) loss is the negative log likelihood loss, used to train the model to generate preferred responses. | + +### Multi-loss combinations + +The DPO trainer supports combining multiple loss functions with different weights, enabling more sophisticated optimization strategies. This is particularly useful for implementing algorithms like MPO (Mixed Preference Optimization). MPO is a training approach that combines multiple optimization objectives, as described in the paper [Enhancing the Reasoning Ability of Multimodal Large Language Models via Mixed Preference Optimization](https://huggingface.co/papers/2411.10442). + +To combine multiple losses, specify the loss types and corresponding weights as lists: + +```python +# MPO: Combines DPO (sigmoid) for preference and BCO (bco_pair) for quality +training_args = DPOConfig( + loss_type=["sigmoid", "bco_pair", "sft"], # Loss types to combine + loss_weights=[0.8, 0.2, 1.0] # Corresponding weights, as used in the MPO paper +) +``` + +If `loss_weights` is not provided, all loss types will have equal weights (1.0 by default). + +### Label smoothing + +The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0). + +### Syncing the reference model + +The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`]. + +### RPO loss + +The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, include `"sft"` in the `loss_type` list in the [`DPOConfig`] and set its weight in `loss_weights`. + +> [!WARNING] +> The old implementation of RPO loss in TRL used the `rpo_alpha` parameter. This parameter is deprecated and will be removed in 0.29.0; instead. + +### WPO loss + +The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. + +### LD-DPO loss + +The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +### Rapid Experimentation for DPO + +RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple DPO configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you **see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight** without restarting. For more information, see [RapidFire AI Integration](rapidfire_integration). + +## Accelerate DPO fine-tuning using `unsloth` + +You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below: + +| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved | +| --- | --- | --- | --- | --- | --- | --- | +| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% | +| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% | + +First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows: + +```diff + from datasets import load_dataset + from trl import DPOConfig, DPOTrainer +- from transformers import AutoModelForCausalLM, AutoTokenizer ++ from unsloth import FastLanguageModel + +- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") ++ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct") ++ model = FastLanguageModel.get_peft_model(model) + train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO") ++ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", bf16=True) + trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) + trainer.train() + +``` + +The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). + +## Reference model considerations with PEFT + +You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA. + +1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient. +2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below. +3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls. + +### Downsides to merging QLoRA before DPO (approach 2) + +As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py). + +However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand). + +### Using option 3 - load the adapter twice + +To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`]. + +For example: + +```python +# Load the base model. +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", +) +model = AutoModelForCausalLM.from_pretrained( + "mistralai/mixtral-8x7b-v0.1", + load_in_4bit=True, + quantization_config=bnb_config, + attn_implementation="kernels-community/flash-attn2", + dtype=torch.bfloat16, + device_map="auto", +) + +# Load the adapter. +model = PeftModel.from_pretrained( + model, + "/path/to/peft", + is_trainable=True, + adapter_name="train", +) +# Load the adapter a second time, with a different name, which will be our reference model. +model.load_adapter("/path/to/peft", adapter_name="reference") + +# Initialize the trainer, without a ref_model param. +training_args = DPOConfig( + model_adapter_name="train", + ref_adapter_name="reference", +) +dpo_trainer = DPOTrainer( + model, + args=training_args, + ... +) +``` + +## DPOTrainer + +[[autodoc]] DPOTrainer + - train + - save_model + - push_to_hub + +## DPOConfig + +[[autodoc]] DPOConfig + +## DataCollatorForPreference + +[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference + diff --git a/ICL/RL/trl_source/docs/source/example_overview.md b/ICL/RL/trl_source/docs/source/example_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..b78db67020e6968902fa05a4a27413d8c8aae724 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/example_overview.md @@ -0,0 +1,101 @@ +# Examples + +This directory contains a collection of examples that demonstrate how to use the TRL library for various applications. We provide both **scripts** for advanced use cases and **notebooks** for an easy start and interactive experimentation. + +The notebooks are self-contained and can run on **free Colab**, while the scripts can run on **single GPU, multi-GPU, or DeepSpeed** setups. + +**Getting Started** + +Install TRL and additional dependencies as follows: + +```bash +pip install --upgrade trl[quantization] +``` + +Check for additional optional dependencies [here](https://github.com/huggingface/trl/blob/main/pyproject.toml). + +For scripts, you will also need an 🤗 Accelerate config (recommended for multi-gpu settings): + +```bash +accelerate config # will prompt you to define the training configuration +``` + +This allows you to run scripts with `accelerate launch` in single or multi-GPU settings. + +## Notebooks + +These notebooks are easier to run and are designed for quick experimentation with TRL. The list of notebooks can be found in the [`trl/examples/notebooks/`](https://github.com/huggingface/trl/tree/main/examples/notebooks/) directory. + + +| Notebook | Description | Open in Colab | +|----------|-------------|---------------| +| [`grpo_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_trl_lora_qlora.ipynb) | GRPO using QLoRA on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_trl_lora_qlora.ipynb) | +| [`grpo_functiongemma_browsergym_openenv.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb) | GRPO on FunctionGemma in the BrowserGym environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb) | +| [`grpo_agent.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_agent.ipynb) | GRPO for agent training | Not available due to OOM with Colab GPUs | +| [`grpo_rnj_1_instruct.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) | GRPO rnj-1-instruct with QLoRA using TRL on Colab to add reasoning capabilities | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) | +| [`sft_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_ministral3_vl.ipynb) | Supervised Fine-Tuning (SFT) Ministral 3 with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_ministral3_vl.ipynb) | +| [`grpo_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_ministral3_vl.ipynb) | GRPO Ministral 3 with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_ministral3_vl.ipynb) | +| [`openenv_sudoku_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_sudoku_grpo.ipynb) | GRPO to play Sudoku on an OpenEnv environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_sudoku_grpo.ipynb) | +| [`openenv_wordle_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_wordle_grpo.ipynb) | GRPO to play Worldle on an OpenEnv environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb) | +| [`sft_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | Supervised Fine-Tuning (SFT) using QLoRA on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | +| [`sft_qwen_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_qwen_vl.ipynb) | Supervised Fine-Tuning (SFT) Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_qwen_vl.ipynb) | +| [`grpo_qwen3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_qwen3_vl.ipynb) | GRPO Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb) | + +## Scripts + +Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more. + +| File | Description | +| --- | --- | +| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`experimental.kto.KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty, and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. | +| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. | +| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. | +| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`experimental.judges.HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. | +| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`experimental.gkd.GKDTrainer`] to fine-tune a model. | +| [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. | +| [`trl/scripts/grpo_agent.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo_agent.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model to enable agentic usage. | +| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | +| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. | +| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | +| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`experimental.kto.KTOTrainer`] to fine-tune a model. | +| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. | +| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. | +| [`examples/scripts/nemo_gym/train_multi_environment.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) | This script shows how to use the [`GRPOTrainer`] to train language models in NVIDIA NeMo-Gym environments. Supports multi-turn and tool calling environments, and multi-environment training. See the [NeMo-Gym Integration](nemo_gym) guide for setup and usage. | +| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. | +| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a a Vision Language Model. | +| [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for VLMs | +| [`examples/scripts/openenv/browsergym_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym_llm.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM for LLMs | +| [`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM | +| [`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. | +| [`examples/scripts/openenv/sudoku.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/sudoku.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Sudoku environment and vLLM. | +| [`examples/scripts/openenv/wordle.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/wordle.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Wordle environment and vLLM. | +| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`experimental.orpo.ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | +| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | +| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`experimental.prm.PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). | +| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. | +| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. | +| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model. | +| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. | +| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. | +| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models, so users may see unexpected behaviour in other model architectures. | +| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. | +| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. | +| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`experimental.xpo.XPOTrainer`] to fine-tune a model. | + +## Distributed Training (for scripts) + +You can run scripts on multiple GPUs with 🤗 Accelerate: + +```shell +accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +For DeepSpeed ZeRO-{1,2,3}: + +```shell +accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +Adjust `NUM_GPUS` and `--all_arguments_of_the_script` as needed. diff --git a/ICL/RL/trl_source/docs/source/experimental_overview.md b/ICL/RL/trl_source/docs/source/experimental_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..af0bf3e3cf23e14f1042629ad54c875a57175344 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/experimental_overview.md @@ -0,0 +1,31 @@ +# Experimental + +This directory contains a minimal, clearly separated space for fast iteration on new ideas. + +> [!WARNING] +> **Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads. + +## Promotion Path (Simple) + +1. **Prototype outside the main repo:** Start development in your own fork or a separate repository to iterate quickly. +2. **Experimental inclusion:** Once it’s ready for early users, move the idea into `trl.experimental.`. +3. **Improve:** Add tests, a short doc/example, and demonstrate the usage. +4. **Promote:** Once the API proves stable and there is clear interest or adoption from the community, move it into `trl.` (stable module). + +## FAQ + +**Why not just use branches?** +Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback. + +**Can these APIs change or vanish without warning?** +Yes. Anything inside `trl.experimental` can change or disappear in *any* release. + +**Should I use this in production?** +Only if you are fine with updating your code quickly when things change. + +**Will maintainers promptly fix issues in `trl.experimental`?** +Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API. + +**How to silence the runtime notice?** + +Use: `export TRL_EXPERIMENTAL_SILENCE=1`. diff --git a/ICL/RL/trl_source/docs/source/gfpo.md b/ICL/RL/trl_source/docs/source/gfpo.md new file mode 100644 index 0000000000000000000000000000000000000000..fac77c9d1f092d6ef17af7262951d5d29799446b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/gfpo.md @@ -0,0 +1,50 @@ +# GFPO + +This feature implements the GFPO algorithm to enforce concise reasoning in the model's output generation, as proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning](https://huggingface.co/papers/2508.09726). + +## Usage + +To activate GFPO in [`GFPOTrainer`]: + +- set `num_remains_in_group` in [`GFPOConfig`] +- define a group filter function and set it to `group_filter_func` in [`GFPOTrainer`]. `group_filter_func` will score the `num_generations` completions and The GFPOTrainer filters groups according to their scores to get top `num_remains_in_group` completions as a new group. Model will be trained on the filtered group. + +```python +# train_gfpo.py +from trl.experimental.gfpo import GFPOConfig, GFPOTrainer + +# dummy group filter to scores the completions based on its indice in group +class GroupFilter: + def __call__(self, group_completions, group_rewards, **kwargs): + group_scores = [] + for completions, rewards in zip(group_completions, group_rewards): + scores = [float(i) for i in range(len(completions))] + group_scores.append(scores) + return group_scores + +training_args = GFPOConfig( + output_dir="Qwen3-0.6B-GFPO", + per_device_train_batch_size=4, + num_remains_in_group=2, + bf16=True, +) +trainer = GFPOTrainer( + model="Qwen/Qwen3-0.6B", + reward_funcs=..., + train_dataset=..., + args=training_args, + group_filter_func=GroupFilter(), +) +trainer.train() +``` + +## GFPOTrainer + +[[autodoc]] experimental.gfpo.GFPOTrainer + - train + - save_model + - push_to_hub + +## GFPOConfig + +[[autodoc]] experimental.gfpo.GFPOConfig diff --git a/ICL/RL/trl_source/docs/source/gkd_trainer.md b/ICL/RL/trl_source/docs/source/gkd_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..b703a1712b900cb3ba1003bf3863261b3526523b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/gkd_trainer.md @@ -0,0 +1,99 @@ +# Generalized Knowledge Distillation Trainer + +[![model badge](https://img.shields.io/badge/All_models-GKD-blue)](https://huggingface.co/models?other=gkd,trl) + +## Overview + +Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem. + +The abstract from the paper is the following: + +> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning. + +The key aspects of GKD are: + +1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences. +2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun). + +## Usage tips + +The [`experimental.gkd.GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`experimental.gkd.GKDConfig`] namely: + +* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch. +* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher. +* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two. + +The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. + +> [!WARNING] +> Make sure that `attn_implementation="kernels-community/flash-attn2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. + +The basic API is as follows: + +```python +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl.experimental.gkd import GKDConfig, GKDTrainer + +NUM_DUMMY_SAMPLES = 100 + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +# The model to optimise +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +# The teacher model to calculate the KL divergence against +teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct") + +train_dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "user", "content": "Hi, how are you?"}, + {"role": "assistant", "content": "I'm great thanks"}, + ] + ] + * NUM_DUMMY_SAMPLES + } +) +eval_dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "user", "content": "What colour is the sky?"}, + {"role": "assistant", "content": "The sky is blue"}, + ] + ] + * NUM_DUMMY_SAMPLES + } +) + +training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1) +trainer = GKDTrainer( + model=model, + teacher_model=teacher_model, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, +) +trainer.train() +``` + +### Expected dataset type + +The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys: + +* `role`: either `system`, `assistant` or `user` +* `content`: the message content + +## GKDTrainer + +[[autodoc]] experimental.gkd.GKDTrainer + - train + - save_model + - push_to_hub + +## GKDConfig + +[[autodoc]] experimental.gkd.GKDConfig diff --git a/ICL/RL/trl_source/docs/source/gold_trainer.md b/ICL/RL/trl_source/docs/source/gold_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..dbe7e7b01e64611bec34840ad6528c87c6d66db8 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/gold_trainer.md @@ -0,0 +1,165 @@ +# General Online Logit Distillation (GOLD) Trainer + +[![All_models-GOLD-blue](https://img.shields.io/badge/All_models-GOLD-blue)](https://huggingface.co/models?other=sft,gold) + +## Overview + +General Online Logit Distillation (GOLD) is an extension of Universal Logit Distillation (ULD) that supports +student/teacher pairs with different tokenizers. It aligns the textual spans produced by both tokenizers and merges the +associated logits so no completion tokens are dropped. This enables cross-tokenizer knowledge distillation, including +mixed model families (for example, LLaMA students with Qwen teachers). + +Key capabilities: + +1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ. +2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher. +3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`experimental.gkd.GKDTrainer`], so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run. + +> [!NOTE] +> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on. + +## Usage tips + +The [`GOLDTrainer`] subclasses [`SFTTrainer`] and accepts the same datasets as other TRL trainers (lists of ChatML style +messages). Important configuration flags on [`GOLDConfig`] include: + +* `use_uld_loss` – toggles Universal Logit Distillation. Set this to `True` for cross-tokenizer setups. +* `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens. +* `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid + matched/unmatched loss. +* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy + sampling ratio. + +A minimal end-to-end example: + +```python +from datasets import load_dataset +from trl.experimental.gold import GOLDConfig, GOLDTrainer + +train_dataset = load_dataset( + "HuggingFaceTB/OpenR1-Math-220k-default-verified", + "all", + split="train[:1024]", +) + +trainer = GOLDTrainer( + model="meta-llama/Llama-3.2-1B-Instruct", + teacher_model="Qwen/Qwen2.5-0.5B-Instruct", + args=GOLDConfig(output_dir="gold-model", use_uld_loss=True, teacher_tokenizer_name_or_path="Qwen/Qwen2.5-0.5B-Instruct"), + train_dataset=train_dataset, +) +trainer.train() +``` + +For quick-start workflows you can rely on string identifiers as shown above—the trainer will load the model and tokenizer for you. Explicitly instantiating `AutoModelForCausalLM`, `AutoTokenizer`, or populating `GOLDConfig` is recommended only for advanced use cases where you need fine-grained control over initialization. + +A more explicit setup might look like this when you need to customise model loading, tokenizer settings, or training arguments: + +```python +from datasets import load_dataset +from trl import GOLDConfig, GOLDTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +student_name = "meta-llama/Llama-3.2-1B-Instruct" +teacher_name = "Qwen/Qwen2.5-0.5B-Instruct" + +tokenizer = AutoTokenizer.from_pretrained(student_name) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +model = AutoModelForCausalLM.from_pretrained(student_name) +teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name) + +train_dataset = load_dataset( + "HuggingFaceTB/Countdown-Task-GOLD", + "verified_Qwen2.5-0.5B-Instruct", + split="train", +) + +training_args = GOLDConfig( + output_dir="gold-model", + per_device_train_batch_size=1, + teacher_model=teacher_name, + teacher_tokenizer_name_or_path=teacher_name, + use_uld_loss=True, + uld_use_hybrid_loss=True, +) + +trainer = GOLDTrainer( + model=model, + teacher_model=teacher_model, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, +) +trainer.train() +``` + +### Expected dataset type + +GOLD requires a [conversational](dataset_formats#conversational) [language modeling](dataset_formats#language_modeling) dataset, e.g.: + +```python +{"messages": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}]} +``` + +`GOLDTrainer` keeps the raw messages so the ChatML collator can construct prompts and completions with the correct +boundaries. + +## How Token Merging Works + +When student and teacher use different tokenizers, the same text may be split differently: + +- **Student**: `"Hugging Face"` → 1 token +- **Teacher**: `"Hugging"`, `" Face"` → 2 tokens + +GOLD aligns these sequences and merges the teacher's multi-token probabilities into a single distribution that can be compared with the student's single-token distribution. + +### Probability Merging + +For a teacher sequence of tokens `[token₀, token₁, ..., tokenₖ]` that maps to a single student token, GOLD computes: + +``` +P_merged(y) = P(y | context) × P(token₁ | token₀, context) × ... × P(tokenₖ | ..., context) +``` + +where: +- `P(y | context)` is the marginal probability distribution over all vocabulary tokens at the first position +- `P(tokenᵢ | ..., context)` are **scalar** conditional probabilities of the actual tokens that were generated + +**Key insight**: Only the conditional probabilities of the **actual continuation tokens** are extracted as scalars. The full marginal distribution at the first position is then scaled by multiplying these scalar probabilities. + +This ensures: +1. **Correct joint probability** for the actual generated sequence (by the chain rule) +2. **Reasonable approximation** for counterfactual tokens (scaled by the same continuation likelihood) +3. **Unnormalized distributions** that preserve the correct relative probabilities for ULD loss computation + +### Example + +Given: +``` +P(x₀): ["HF": 0.6, "is": 0.3, "cool": 0.1] +P(x₁ | "HF"): ["HF": 0.05, "is": 0.9, "cool": 0.05] +``` + +If tokens 0 and 1 are merged, and the actual sequence was `["HF", "is"]`: +``` +P_merged("HF") = 0.6 × 0.9 = 0.54 ✓ (correct joint probability) +P_merged("is") = 0.3 × 0.9 = 0.27 +P_merged("cool") = 0.1 × 0.9 = 0.09 +``` + +The merged distribution is unnormalized (sums to 0.81), but this is intentional and correct for ULD loss computation, which uses sorting and L1 distance. + +## GOLDTrainer + +[[autodoc]] experimental.gold.GOLDTrainer + - train + - generate_on_policy_outputs + - save_model + - push_to_hub + +## GOLDConfig + +[[autodoc]] experimental.gold.GOLDConfig diff --git a/ICL/RL/trl_source/docs/source/grpo_trainer.md b/ICL/RL/trl_source/docs/source/grpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..8216b34f2ef81ab04bde49e2bf1fcaf9feab948a --- /dev/null +++ b/ICL/RL/trl_source/docs/source/grpo_trainer.md @@ -0,0 +1,730 @@ +# GRPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl) + +## Overview + +TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday). + +The abstract from the paper is the following: + +> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO. + +This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Quick start + +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K). You can view the data in the dataset here: + + + +Below is the script to train the model. + +```python +# train_grpo.py +from datasets import load_dataset +from trl import GRPOTrainer +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=accuracy_reward, + train_dataset=dataset, +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_grpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 day. + +![GRPO curves](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png) + +## Looking deeper into the GRPO method + +GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**. + +![GRPO visual](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png) + +### Generating completions + +At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)). + +### Computing the advantage + +For each of the \\( G \\) sequences, we compute the reward using a reward model or reward function. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows: + +$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$ + +This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**. + +> [!TIP] +> It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`]. +> Note that turning off std-based scaling also removes variance normalization, so update magnitudes depend directly on the raw reward scale and batch composition. + +> [!TIP] +> As shown in [Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221), calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`]. + +### Estimating the KL divergence + +KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows: + +$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i, [!TIP] +> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types). + +> [!TIP] +> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value. + +In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**: + +$$ +\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right], +$$ + +where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\). +When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective. + +#### Loss Types + +Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows: + +$$ +\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t}, +$$ + +where + +$$ +l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]. +$$ + +The [DAPO paper](https://huggingface.co/papers/2503.14476) highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length: + +$$ +\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t}, +$$ + +To use this formulation, set `loss_type="dapo"` in [`GRPOConfig`]. + +Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation: + +$$ +\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t}, +$$ + +This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`]. + +Alternatively, in the [SAPO paper](https://huggingface.co/papers/2511.20347), the Qwen team proposes replacing the "hard" clipping mechanism of GRPO with a smooth, temperature-controlled soft gating mechanism. While GRPO zeroes out gradients when the policy deviates too far from the reference, SAPO uses a soft trust region that smoothly decays the gradient weight. This allows the model to retain useful learning signals from "near-on-policy" tokens while suppressing noise from extreme deviations. + +The loss function is defined as: + +$$ +\mathcal{L}_{\text{SAPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} f_{i,t} \left( \frac{\pi_\theta(o_{i,t} | q, o_{i, 0 \\ +\tau_{\text{neg}}, & \text{otherwise} +\end{cases} +$$ + +They recommend using asymmetric temperatures, \\( \tau_{\text{neg}} > \tau_{\text{pos}} \\) (defaults are \\( \tau_{\text{pos}}=1.0, \tau_{\text{neg}}=1.05 \\) ). This ensures that the model is penalized more strictly for "bad" actions to prevent instability, while being more permissive with "good" actions. + +To use this formulation, set `loss_type="sapo"` in the [`GRPOConfig`]. + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +- `num_tokens`: The total number of tokens processed so far, including both prompts and completions. When using tools, only non-tool tokens are counted. +- `step_time`: The average time (in seconds) taken per training step (including generation). +- `completions/mean_length`: The average length of generated completions. When using tools, only non-tool tokens are counted. +- `completions/min_length`: The minimum length of generated completions. When using tools, only non-tool tokens are counted. +- `completions/max_length`: The maximum length of generated completions. When using tools, only non-tool tokens are counted. +- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted. +- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted. +- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted. +- `completions/clipped_ratio`: The ratio of truncated (clipped) completions. +- `reward/{reward_func_name}/mean`: The average reward from a specific reward function. +- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function. +- `reward`: The overall average reward after summing rewards across functions (unweighted). +- `reward_std`: The standard deviation of summed rewards across functions (unweighted), computed over the full batch. +- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect). +- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.) +- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero. +- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region: \\( \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \quad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \\). A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change. +- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\). +- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\). +- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\). +- `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\). + +## Customization + +### Speed up training with vLLM-powered generation + +Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with + +```shell +pip install trl[vllm] +``` + +We support two ways of using vLLM during training: **server mode** and **colocate mode**. + +> [!TIP] +> By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling) + +#### 🔌 Option 1: Server mode + +In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference. + +1. **Start the vLLM server**: + + ```bash + trl vllm-serve --model + ``` + +2. **Enable server mode in your training script**: + + ```python + from trl import GRPOConfig + + training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted + ) + ``` + +> [!WARNING] +> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable. + +#### 🧩 Option 2: Colocate mode + +In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + +> [!TIP] +> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors. +> +> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation: +> +> +> +> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability. +> +> If you still find you are getting out-of-memory errors set `vllm_enable_sleep_mode` to True and the vllm parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode). + +> [!TIP] +> By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly. + +For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods). + + +#### Dealing with the Training-Inference Mismatch +While vLLM greatly accelerates inference, it also decouples the inference engine from the training engine. In theory these engines are mathematically identical, in practice however they can produce different outputs due to precision effects and hardware specific optimizations. This divergence reflects the different optimization objectives of the two systems. This divergence reflects the distinct optimization goals of the two systems. Inference engines aim to maximize sampling throughput, typically measured in tokens per second, while maintaining acceptable sampling fidelity. Training frameworks instead focus on numerical stability and precision for gradient computation, often using higher precision formats like FP32 for master weights and optimizer states. These differing priorities and constraints introduce an inevitable, albeit subtle, mismatch between training and inference. + +This mismatch leads to a biased gradient update which has been observed to destabilize training ([[1]](https://fengyao.notion.site/off-policy-rl)[[2]](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)[[3]](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/#true-on-policy-rl)[[4]](https://huggingface.co/papers/2510.26788)[[5]](https://huggingface.co/papers/2510.18855)). For simplicity, consider the REINFORCE policy gradient: + +$$ +\nabla_\theta \mathcal{J}(x,\theta) += \mathbb{E}_{y \sim \pi^\text{train}(\cdot \mid x,\theta)} +\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right] +$$ + +Here \\( x \\) denotes prompts sampled from some data distribution, and \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy \\( \pi^\text{inference} \\), so the effective policy gradient becomes + +$$ +\nabla_\theta \mathcal{J}_{\text{biased}}(x,\theta) += \mathbb{E}_{y \sim \pi^\text{inference}(\cdot \mid x,\theta)} +\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right]. +$$ + +This turns an otherwise on policy RL problem into an off policy one. + +The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level.Let \\( \rho \\) denote the importance weight, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, + +$$ +\rho \leftarrow \min(\rho, C). +$$ + +Under MIS, ratios larger than `vllm_importance_sampling_cap` are set to zero, so those samples do not contribute to the gradient. In other words, large ratio samples are downweighted under TIS and discarded under MIS. The configuration flag `vllm_importance_sampling_mode` chooses both the IS variant (masking or truncation) and the granularity (token level or sequence level). + +Importance sampling is the principled algorithmic response to the training–inference mismatch. However, there are also more direct approaches that attempt to reduce the mismatch between the two engines themselves. Most of these are engineering solutions. For example, [MiniMax M1 uses an FP32 language model head](https://huggingface.co/papers/2506.13585) in the inference engine. Thinking Machines has explored [deterministic inference kernels](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/), although this comes with a significant efficiency cost. vLLM has shown [bitwise consistent policies](https://blog.vllm.ai/2025/11/10/bitwise-consistent-train-inference.html) by building on the batch invariant deterministic kernels from Thinking Machines, but as of November 2025 there remains a substantial throughput penalty relative to standard vLLM inference. + +### GRPO at scale: train a 70B+ Model on multiple nodes + +When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include: + +- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration). +- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training). +- **vLLM**: See the previous section on how to use vLLM to speed up generation. + +Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation. + +```sh +#!/bin/bash +#SBATCH --nodes=5 +#SBATCH --gres=gpu:8 + +# Get the list of allocated nodes +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) + +# Assign the first 4 nodes for training and the 5th node for vLLM +TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training +VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM + +# Run training on the first 4 nodes (Group 1) +srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + --num_processes 32 \ + --num_machines 4 \ + --main_process_ip ${NODELIST[0]} \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + train_grpo.py \ + --server_ip $VLLM_NODE & + +# Run vLLM server on the 5th node (Group 2) +srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 & + +wait +``` + +```python +import argparse + +from datasets import load_dataset +from trl import GRPOTrainer, GRPOConfig +from trl.rewards import accuracy_reward + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP") + args = parser.parse_args() + + dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + + training_args = GRPOConfig( + per_device_train_batch_size=4, + use_vllm=True, + vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X + ) + + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-72B", + args=training_args, + reward_funcs=accuracy_reward, + train_dataset=dataset + ) + trainer.train() + +if __name__=="__main__": + main() +``` + +### Using a custom reward function + +The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements: + +Reward functions can be either synchronous Python callables or asynchronous `async def` coroutines. When you provide multiple asynchronous reward functions, they are awaited concurrently (run in parallel via `asyncio.gather`) so their latency overlaps. + +1. **Input arguments**: + - The function must accept the following as keyword arguments: + - `prompts` (contains the prompts), + - `completions` (contains the generated completions), + - `completion_ids` (contains the tokenized completions), + - `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress. + - All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument. + + The easiest way to comply with this requirement is to use `**kwargs` in the function signature. + - Depending on the dataset format, the input will vary: + - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings. + - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries. + +2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion. + +#### Example 1: Reward longer completions + +Below is an example of a reward function for a standard format that rewards longer completions: + +```python +def reward_func(completion_ids, **kwargs): + """Reward function that assigns higher scores to longer completions (in terms of token count).""" + return [float(len(ids)) for ids in completion_ids] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it +>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it +>>> completion_ids = [[6303, 13], [304, 279, 12884, 13]] +>>> reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids) +[2.0, 4.0] +``` + +#### Example 1.1: Reward longer completions (based on the number of characters) + +Same as the previous example, but this time the reward function is based on the number of characters instead of tokens. + +```python +def reward_func(completions, **kwargs): + """Reward function that assigns higher scores to longer completions (in terms of character count).""" + return [float(len(completion)) for completion in completions] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] +>>> completions = [" blue.", " in the sky."] +>>> completion_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it +>>> reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids) +[6.0, 12.0] +``` + +#### Example 2: Reward completions with a specific format + +Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +It is designed for a conversational format, where prompts and completions consist of structured messages. + +```python +import re + +def format_reward_func(completions, **kwargs): + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?.*?$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] +``` + +You can test this function as follows: + +```python +>>> prompts = [ +... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}], +... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}], +... ] +>>> completions = [ +... [{"role": "assistant", "content": "The sum of 1 and 2 is 3, which we multiply by 4 to get 12.(1 + 2) * 4 = 12"}], +... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}], +... ] +>>> format_reward_func(prompts=prompts, completions=completions) +[1.0, 0.0] +``` + +#### Example 3: Reward completions based on a reference + +Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`. + +```python +import re + +def reward_func(completions, ground_truth, **kwargs): + # Regular expression to capture content inside \boxed{} + matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions] + contents = [match.group(1) if match else "" for match in matches] + # Reward 1 if the content is the same as the ground truth, 0 otherwise + return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)] +``` + +You can test this function as follows: + +```python +>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."] +>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."] +>>> ground_truth = ["2", "5"] +>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth) +[1.0, 0.0] +``` + +#### Example 4: Multi-task reward functions + +Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works. + +```python +from datasets import Dataset +from trl import GRPOTrainer + +# Define a dataset that contains both math and coding problems +dataset = Dataset.from_list( + [ + {"prompt": "What is 2+2?", "task": "math"}, + {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"}, + {"prompt": "What is 3*4?", "task": "math"}, + {"prompt": "Write a function that returns the product of two numbers.", "task": "code"}, + ] +) + +# Math-specific reward function +def math_reward_func(prompts, completions, task, **kwargs): + rewards = [] + for prompt, completion, t in zip(prompts, completions, task): + if t == "math": + # Calculate math-specific reward + correct = check_math_solution(prompt, completion) + reward = 1.0 if correct else -1.0 + rewards.append(reward) + else: + # Return None for non-math tasks + rewards.append(None) + return rewards + +# Coding-specific reward function +def coding_reward_func(prompts, completions, task, **kwargs): + rewards = [] + for prompt, completion, t in zip(prompts, completions, task): + if t == "coding": + # Calculate coding-specific reward + works = test_code_solution(prompt, completion) + reward = 1.0 if works else -1.0 + rewards.append(reward) + else: + # Return None for non-coding tasks + rewards.append(None) + return rewards + +# Use both task-specific reward functions +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=[math_reward_func, coding_reward_func], + train_dataset=dataset, +) + +trainer.train() +``` + +In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability. + +Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function. + +#### Example 5: Asynchronous reward functions + +Custom reward functions can also be defined as `async def` coroutines. This is useful if your reward depends on slow I/O (for example, calling a remote service). When you pass multiple async reward functions, [`GRPOTrainer`] executes them concurrently so their latency overlaps. + +Below is a minimal example of an async reward function that simulates an I/O-bound operation: + +```python +import asyncio + +async def async_reward_func(prompts, completions, **kwargs): + # Simulate an I/O-bound call (e.g., HTTP request, database lookup) + await asyncio.sleep(0.01) + # Simple toy reward: 1.0 if the completion is non-empty, else 0.0 + return [1.0 if completion else 0.0 for completion in completions] +``` + +#### Passing the reward function to the trainer + +To use your custom reward function, pass it to the [`GRPOTrainer`] as follows: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=reward_func, + ..., +) +``` + +You can pass several reward functions as a list; this list may include both synchronous and asynchronous functions: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=[reward_func, async_reward_func1, async_reward_func2], + ..., +) +``` + +and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config. + +Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. + +### Rapid Experimentation for GRPO + +RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple GRPO configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you **see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight** without restarting. For more information, see [RapidFire AI Integration](rapidfire_integration). + +## Agent Training + +GRPO supports **agent training** through the `tools` argument in [`GRPOTrainer`]. +This parameter expects a list of Python functions (sync or async) that define the tools available to the agent: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + tools=[tool1, tool2], + ..., +) +``` + +Each tool must be a standard Python function with **type-hinted arguments and return types**, along with a **Google-style docstring** describing its purpose, arguments, and return value. +For more details, see the [Passing tools guide](https://huggingface.co/docs/transformers/en/chat_extras#passing-tools). + +Example: + +```python +from trl import GRPOTrainer + +def multiply(a: int, b: int) -> int: + """ + Multiplies two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The product of the two integers. + """ + return a * b + +async def async_add(a: int, b: int) -> int: + """ + Asynchronously adds two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The sum of the two integers. + """ + return a + b + +trainer = GRPOTrainer( + tools=[multiply, async_add], + ..., +) +``` + +### Supported Models + +Tested with: + +- **Qwen3** — e.g., `Qwen/Qwen3-0.6B` + +> [!TIP] +> Compatibility with all LLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + +### Quick Start + +Use [grpo\_agent.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_agent.py) to fine-tune a LLM for agentic workflows. + +```bash +accelerate launch \ + --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/grpo_agent.py \ + --model_name_or_path Qwen/Qwen3-0.6B + ... +``` + +## Vision-Language Model (VLM) Training + +GRPO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images. + +### Supported Models + +Tested with: + +- **Gemma3** — e.g., `google/gemma-3-4b-it` +- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf` +- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct` +- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct` +- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + +> [!TIP] +> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + +### Quick Start + +Use [grpo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified): + +```bash +accelerate launch \ + --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/grpo_vlm.py \ + --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \ + --output_dir grpo-Qwen2.5-VL-3B-Instruct \ + --learning_rate 1e-5 \ + --dtype bfloat16 \ + --max_completion_length 1024 \ + --use_vllm \ + --vllm_mode colocate \ + --use_peft \ + --lora_target_modules "q_proj", "v_proj" \ + --log_completions +``` + +### Configuration Tips + +- Use LoRA on vision-language projection layers +- Enable 4-bit quantization to reduce memory usage +- VLMs are memory-intensive — start with smaller batch sizes +- Most models are compatible with vLLM (`server` and `colocate` modes) + +### Dataset Format + +Each training sample should include: + +- `prompt`: Text formatted via the processor's chat template +- `image`/`images`: PIL Image or list of PIL Images + +The trainer automatically handles image-to-tensor conversion via the model’s image processor. + +## GRPOTrainer + +[[autodoc]] GRPOTrainer + - train + - save_model + - push_to_hub + +## GRPOConfig + +[[autodoc]] GRPOConfig diff --git a/ICL/RL/trl_source/docs/source/grpo_with_replay_buffer.md b/ICL/RL/trl_source/docs/source/grpo_with_replay_buffer.md new file mode 100644 index 0000000000000000000000000000000000000000..e68cce94458f7b1769630847b1a201f8fb99024a --- /dev/null +++ b/ICL/RL/trl_source/docs/source/grpo_with_replay_buffer.md @@ -0,0 +1,56 @@ +# GRPO With Replay Buffer + +This experimental trainer, trains a model with GRPO but replaces groups (and corresponding completions) that have 0 standard deviation with groups with high rewards and standard deviation that've been used to train a model in prior batches. + +## Usage + +```python +import torch +from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + +# Guarantee that some rewards have 0 std +def custom_reward_func(completions, **kwargs): + if torch.rand(1).item() < 0.25: + return [0] * len(completions) # simulate some None rewards + else: + return torch.rand(len(completions)).tolist() + +training_args = GRPOWithReplayBufferConfig( + output_dir="./tmp", + learning_rate=1e-4, + per_device_train_batch_size=4, + num_generations=4, + max_completion_length=8, + replay_buffer_size=8, + report_to="none", +) + +trainer = GRPOWithReplayBufferTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[custom_reward_func], + args=training_args, + train_dataset=dataset, +) + +previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + +trainer.train() +``` + +## GRPOWithReplayBufferTrainer + +[[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer + - train + - save_model + - push_to_hub + +## GRPOWithReplayBufferConfig + +[[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig + +## ReplayBuffer + +[[autodoc]] experimental.grpo_with_replay_buffer.ReplayBuffer diff --git a/ICL/RL/trl_source/docs/source/gspo_token.md b/ICL/RL/trl_source/docs/source/gspo_token.md new file mode 100644 index 0000000000000000000000000000000000000000..394fb555f87f4e5e1aeb25ed963a0a3643e3dc47 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/gspo_token.md @@ -0,0 +1,25 @@ +# GSPO-token + +In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`. + +## Usage + +```python +from trl.experimental.gspo_token import GRPOTrainer +from trl import GRPOConfig + +training_args = GRPOConfig( + importance_sampling_level="sequence_token", + ... +) +``` + +> [!WARNING] +> To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) varies with \\( t \\)—which isn't the case here, \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, GSPO-Token gradient is just equivalent to the original GSPO implementation. + +## GRPOTrainer + +[[autodoc]] experimental.gspo_token.GRPOTrainer + - train + - save_model + - push_to_hub diff --git a/ICL/RL/trl_source/docs/source/index.md b/ICL/RL/trl_source/docs/source/index.md new file mode 100644 index 0000000000000000000000000000000000000000..b1eaa0c289735232bbb0a6c60f04f17d1821d5fd --- /dev/null +++ b/ICL/RL/trl_source/docs/source/index.md @@ -0,0 +1,151 @@ +
+ +
+ +# TRL - Transformer Reinforcement Learning + +TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more. +The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers). + +## 🎉 What's New + +**OpenEnv Integration:** TRL now supports **[OpenEnv](https://huggingface.co/blog/openenv)**, the open-source framework from Meta for defining, deploying, and interacting with environments in reinforcement learning and agentic workflows. + +Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documentation](openenv). + +## Taxonomy + +Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support; 🧪 = experimental). + +
+
+ +### Online methods + +- [`GRPOTrainer`](grpo_trainer) ⚡️ +- [`RLOOTrainer`](rloo_trainer) ⚡️ +- [`OnlineDPOTrainer`](online_dpo_trainer) 🧪 ⚡️ +- [`NashMDTrainer`](nash_md_trainer) 🧪 ⚡️ +- [`PPOTrainer`](ppo_trainer) 🧪 +- [`XPOTrainer`](xpo_trainer) 🧪 ⚡️ + +### Reward modeling + +- [`RewardTrainer`](reward_trainer) +- [`PRMTrainer`](prm_trainer) 🧪 + +
+
+ +### Offline methods + +- [`SFTTrainer`](sft_trainer) +- [`DPOTrainer`](dpo_trainer) +- [`BCOTrainer`](bco_trainer) 🧪 +- [`CPOTrainer`](cpo_trainer) 🧪 +- [`KTOTrainer`](kto_trainer) 🧪 +- [`ORPOTrainer`](orpo_trainer) 🧪 + +### Knowledge distillation + +- [`GKDTrainer`](gkd_trainer) 🧪 +- [`MiniLLMTrainer`](minillm_trainer) 🧪 + +
+
+ +You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib). + +## Learn + +Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course). + +## Contents + +The documentation is organized into the following sections: + +- **Getting Started**: installation and quickstart guide. +- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs. +- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc. +- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc. +- **Examples**: example overview, community tutorials, etc. +- **API**: trainers, utils, etc. + +## Blog posts + + + +## Talks + + diff --git a/ICL/RL/trl_source/docs/source/installation.md b/ICL/RL/trl_source/docs/source/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..6a4a9117d5168d8869b4cc3647c4a1da18fe93f3 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/installation.md @@ -0,0 +1,42 @@ +# Installation + +You can install TRL either from PyPI or from source: + +## PyPI + +Install the library with pip or [uv](https://docs.astral.sh/uv/): + + + + +uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions. + +```bash +uv pip install trl +``` + + + + +```bash +pip install trl +``` + + + + +## Source + +You can also install the latest version from source. First clone the repo and then run the installation with `pip`: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e . +``` + +If you want the development install you can replace the pip install with the following: + +```bash +pip install -e ".[dev]" +``` diff --git a/ICL/RL/trl_source/docs/source/jobs_training.md b/ICL/RL/trl_source/docs/source/jobs_training.md new file mode 100644 index 0000000000000000000000000000000000000000..31ead93b61f8a9dfd5a61b6f0a79c15196f8904b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/jobs_training.md @@ -0,0 +1,274 @@ +# Training with Jobs + +[![model badge](https://img.shields.io/badge/All_models-HF_Jobs-blue)](https://huggingface.co/models?other=hf_jobs,trl) + +[Hugging Face Jobs](https://huggingface.co/docs/huggingface_hub/guides/jobs) lets you run training scripts on fully managed infrastructure—no need to manage GPUs or local environment setup. + +In this guide, you'll learn how to: + +* Use [TRL Jobs](https://github.com/huggingface/trl-jobs) to easily run pre-optimized TRL training +* Run any TRL training script with uv scripts + +For general details about Hugging Face Jobs (hardware selection, job monitoring, etc.), see the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/guides/jobs). + +## Requirements + +* A [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan +* Logged in to the Hugging Face Hub (`hf auth login`) + +## Using TRL Jobs + +[TRL Jobs](https://github.com/huggingface/trl-jobs) is a high-level wrapper around Hugging Face Jobs and TRL that streamlines training. It provides optimized default configurations so you can start quickly without manually tuning parameters. + +Example: + +```bash +pip install trl-jobs +trl-jobs sft --model_name Qwen/Qwen3-0.6B --dataset_name trl-lib/Capybara +``` + +TRL Jobs supports everything covered in this guide, with additional optimizations to simplify workflows. + +## Using uv Scripts + +For more control, you can run Hugging Face Jobs directly with your own scripts, using [uv scripts](https://docs.astral.sh/uv/guides/scripts/). + +Create a Python script (e.g., `train.py`) containing your training code: + +```python +from datasets import load_dataset +from trl import SFTTrainer + +dataset = load_dataset("trl-lib/Capybara", split="train") +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=dataset, +) +trainer.train() +trainer.push_to_hub("Qwen2.5-0.5B-SFT") +``` + +Launch the job using either the [`hf jobs` CLI](https://huggingface.co/docs/huggingface_hub/guides/cli#hf-jobs) or the Python API: + + + + +```bash +hf jobs uv run \ + --flavor a100-large \ + --with trl \ + --secrets HF_TOKEN \ + train.py +``` + + + + +```python +from huggingface_hub import run_uv_job + +run_uv_job( + "train.py", + dependencies=["trl"], + flavor="a100-large", + secrets={"HF_TOKEN": "hf_..."}, +) +``` + + + + +To run successfully, the script needs: + +* **TRL installed**: Use the `--with trl` flag or the `dependencies` argument. uv installs these dependencies automatically before running the script. +* **An authentication token**: Required to push the trained model (or perform other authenticated operations). Provide it with the `--secrets HF_TOKEN` flag or the `secrets` argument. + +> [!WARNING] +> When training with Jobs, be sure to: +> +> * **Set a sufficient timeout**. Jobs time out after 30 minutes by default. If your job exceeds the timeout, it will fail and all progress will be lost. See [Setting a custom timeout](https://huggingface.co/docs/huggingface_hub/guides/jobs#setting-a-custom-timeout). +> * **Push the model to the Hub**. The Jobs environment is ephemeral—files are deleted when the job ends. If you don’t push the model, it will be lost. + +You can also run a script directly from a URL: + + + + +```bash +hf jobs uv run \ + --flavor a100-large \ + --with trl \ + --secrets HF_TOKEN \ + "https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py" +``` + + + + +```python +from huggingface_hub import run_uv_job + +run_uv_job( + "https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py", + flavor="a100-large", + dependencies=["trl"], + secrets={"HF_TOKEN": "hf_..."}, +) +``` + + + + +To make a script self-contained, declare dependencies at the top: + +```python +# /// script +# dependencies = [ +# "trl", +# "peft", +# ] +# /// + +from datasets import load_dataset +from peft import LoraConfig +from trl import SFTTrainer + +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=dataset, + peft_config=LoraConfig(), +) +trainer.train() +trainer.push_to_hub("Qwen2.5-0.5B-SFT") +``` + +You can then run the script without specifying dependencies: + + + + +```bash +hf jobs uv run \ + --flavor a100-large \ + --secrets HF_TOKEN \ + train.py +``` + + + + +```python +from huggingface_hub import run_uv_job + +run_uv_job( + "train.py", + flavor="a100-large", + secrets={"HF_TOKEN": "hf_..."}, +) +``` + + + + +TRL example scripts are fully uv-compatible, so you can run a complete training workflow directly on Jobs. You can customize training with standard script arguments plus hardware and secrets: + + + + +```bash +hf jobs uv run \ + --flavor a100-large \ + --secrets HF_TOKEN \ + https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/prm800k \ + --output_dir Qwen2-0.5B-Reward \ + --push_to_hub +``` + + + + +```python +from huggingface_hub import run_uv_job +run_uv_job( + "https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py", + flavor="a100-large", + secrets={"HF_TOKEN": "hf_..."}, + script_args=[ + "--model_name_or_path", "Qwen/Qwen2-0.5B-Instruct", + "--dataset_name", "trl-lib/prm800k", + "--output_dir", "Qwen2-0.5B-Reward", + "--push_to_hub" + ] +) +``` + + + +See the full list of examples in [Maintained examples](example_overview#maintained-examples). + +### Docker Images + +An up-to-date Docker image with all TRL dependencies is available at [huggingface/trl](https://hub.docker.com/r/huggingface/trl) and can be used directly with Hugging Face Jobs: + + + + +```bash +hf jobs uv run \ + --flavor a100-large \ + --secrets HF_TOKEN \ + --image huggingface/trl \ + train.py +``` + + + + +```python +from huggingface_hub import run_uv_job + +run_uv_job( + "train.py", + flavor="a100-large", + secrets={"HF_TOKEN": "hf_..."}, + image="huggingface/trl", +) +``` + + + + +Jobs runs on a Docker image from Hugging Face Spaces or Docker Hub, so you can also specify any custom image: + + + + +```bash +hf jobs uv run \ + --flavor a100-large \ + --secrets HF_TOKEN \ + --image \ + --secrets HF_TOKEN \ + train.py +``` + + + + +```python +from huggingface_hub import run_uv_job + +run_uv_job( + "train.py", + flavor="a100-large", + secrets={"HF_TOKEN": "hf_..."}, + image="", +) +``` + + + diff --git a/ICL/RL/trl_source/docs/source/judges.md b/ICL/RL/trl_source/docs/source/judges.md new file mode 100644 index 0000000000000000000000000000000000000000..954bf10af1a9a697ffa14350cc4c6fff78194efe --- /dev/null +++ b/ICL/RL/trl_source/docs/source/judges.md @@ -0,0 +1,86 @@ +# Judges + +> [!WARNING] +> TRL Judges is an experimental API which is subject to change at any time. As of TRL v1.0, judges have been moved to the `trl.experimental.judges` module. + +TRL provides judges to easily compare two completions. + +Make sure to have installed the required dependencies by running: + +```bash +pip install trl[judges] +``` + +## Using the provided judges + +TRL provides several judges out of the box. For example, you can use the [`experimental.judges.HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub: + +```python +from trl.experimental.judges import HfPairwiseJudge + +judge = HfPairwiseJudge() +judge.judge( + prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], + completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]], +) # Outputs: [0, 1] +``` + +## Define your own judge + +To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`experimental.judges.BaseRankJudge`] and implement the [`experimental.judges.BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`experimental.judges.BasePairJudge`] and implement the [`experimental.judges.BasePairJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`experimental.judges.BaseJudge`] and implement the [`experimental.judges.BaseJudge.judge`] method. + +As an example, let's define a pairwise judge that prefers shorter completions: + +```python +from trl.experimental.judges import BasePairwiseJudge + +class PrefersShorterJudge(BasePairwiseJudge): + def judge(self, prompts, completions, shuffle_order=False): + return [0 if len(completion[0]) > len(completion[1]) else 1 for completion in completions] +``` + +You can then use this judge as follows: + +```python +judge = PrefersShorterJudge() +judge.judge( + prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], + completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]], +) # Outputs: [0, 1] +``` + +## Provided judges + +### PairRMJudge + +[[autodoc]] experimental.judges.PairRMJudge + +### HfPairwiseJudge + +[[autodoc]] experimental.judges.HfPairwiseJudge + +### OpenAIPairwiseJudge + +[[autodoc]] experimental.judges.OpenAIPairwiseJudge + +### AllTrueJudge + +[[autodoc]] experimental.judges.AllTrueJudge + +## Base classes + +### BaseJudge + +[[autodoc]] experimental.judges.BaseJudge + +### BaseBinaryJudge + +[[autodoc]] experimental.judges.BaseBinaryJudge + +### BaseRankJudge + +[[autodoc]] experimental.judges.BaseRankJudge + +### BasePairwiseJudge + +[[autodoc]] experimental.judges.BasePairwiseJudge diff --git a/ICL/RL/trl_source/docs/source/kernels_hub.md b/ICL/RL/trl_source/docs/source/kernels_hub.md new file mode 100644 index 0000000000000000000000000000000000000000..f3d7ee124ba58be7f9bfd503dded52d4dc90a2b5 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/kernels_hub.md @@ -0,0 +1,96 @@ +# Kernels Hub Integration and Usage + +kernel-builder logo + +The [`kernels`](https://huggingface.co/blog/hello-hf-kernels#get-started-and-next-steps) library allows optimized compute kernels to be loaded directly from the Hub. +You can find `kernels` in [dedicated orgs](https://huggingface.co/kernels-community) or by searching for the [`kernel` tag](https://huggingface.co/models?other=kernel) within the Hub. + +Kernels are **optimized code pieces** that help in model development, training, and inference. Here, we’ll focus on their **integration with TRL**, but check out the above resources to learn more about them. + +## Installation + +To use kernels with TRL, you'd need to install the library in your Python environment: + +```bash +pip install kernels +``` + +## Using Kernels from the Hub in TRL + +Kernels can directly replace attention implementations, removing the need to manually compile attention backends like Flash Attention and boosting training speed just by pulling the respective attention kernel from the Hub. + +You can specify a kernel when loading a model: + + +```python +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained( + "your-model-name", + attn_implementation="kernels-community/flash-attn2" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention +) +``` + +Or when running a TRL training script: + +```bash +python sft.py ... --attn_implementation kernels-community/flash-attn2 +``` + +Or using the TRL CLI: + +```bash +trl sft ... --attn_implementation kernels-community/flash-attn2 +``` + +> [!TIP] +> Now you can leverage faster attention backends with a pre-optimized kernel for your hardware configuration from the Hub, speeding up both development and training. + +## Comparing Attention Implementations + +We evaluated various attention implementations available in transformers, along with different kernel backends, using **TRL** and **SFT**. +The experiments were run on a single **H100 GPU** with **CUDA 12.9**, leveraging **Qwen3-8B** with a **batch size of 8**, **gradient accumulation of 1**, and **bfloat16** precision. +Keep in mind that the results shown here are specific to this setup and may vary with different training configurations. + +The following figure illustrates both **latency** (time per training step) and **peak allocated memory** for the different attention implementations and kernel backends. +Kernel-based implementations perform on par with custom-installed attention, and increasing the model’s `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#flash-attention-vs-hub-kernels). + +
+ Latency and Memory Usage + Latency and Memory Usage +
+ +## Flash Attention vs. Hub Kernels + +Building Flash Attention from source can be time-consuming, often taking anywhere from several minutes to hours, depending on your hardware, CUDA/PyTorch configuration, and whether precompiled wheels are available. + +In contrast, **Hugging Face Kernels** provide a much faster and more reliable workflow. Developers don’t need to worry about complex setups—everything is handled automatically. In our benchmarks, kernels were ready to use in about **2.5 seconds**, with no compilation required. This allows you to start training almost instantly, significantly accelerating development. Simply specify the desired version, and `kernels` takes care of the rest. + +## Combining FlashAttention Kernels with Liger Kernels + +You can combine **FlashAttention kernels** with **Liger kernels** for additional TRL performance improvements. + +First, install the Liger kernel dependency: + +```bash +pip install liger-kernel +``` + +Then, combine both in your code: + +```python +from transformers import AutoModelForCausalLM +from trl import SFTConfig + +model = AutoModelForCausalLM.from_pretrained( + "your-model-name", + attn_implementation="kernels-community/flash-attn2" # choose the desired FlashAttention variant +) + +training_args = SFTConfig( + use_liger_kernel=True, + # ... other TRL training args +) +``` + +Learn more about the [Liger Kernel Integration](./liger_kernel_integration). diff --git a/ICL/RL/trl_source/docs/source/kto_trainer.md b/ICL/RL/trl_source/docs/source/kto_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..708448620b00a4bd61ea121cdd40b4f85f228f19 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/kto_trainer.md @@ -0,0 +1,144 @@ +# KTO Trainer + +[![model badge](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl) + +> [!WARNING] +> As of TRL v1.0, `KTOTrainer` and `KTOConfig` have been moved to the `trl.experimental.kto` module. +> KTO API is experimental and may change at any time. +> Promoting KTO back into the stable API is a high-priority task: KTO is slated for refactoring to align with the standard core trainer architecture. + +## Overview + +Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela). + +The abstract from the paper is the following: + +> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive. + +The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs). + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente. + +## Quick start + +This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_kto.py +from datasets import load_dataset +from trl.experimental.kto import KTOConfig, KTOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train") + +training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO") +trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_kto.py +``` + +Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. + +![kto qwen2 reward margin](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kto-qwen2-reward-margin.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-KTO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-KTO>:
+The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
+
+Here are some other factors to consider when choosing a programming language for a project:
+
+ 1 JavaScript: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
+ 2 Java: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
+ 3 C++: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
+ 4 Python: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
+
+ +## Expected dataset format + +KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones. + +The [`experimental.kto.KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate. + +## Example script + +We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py) + +To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command: + +```bash +accelerate launch trl/scripts/kto.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/kto-mix-14k \ + --num_train_epochs 1 \ + --output_dir Qwen2-0.5B-KTO +``` + +## Usage tips + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +### Batch size recommendations + +Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor. + +### Learning rate recommendations + +Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results. + +### Imbalanced data + +The `desirable_weight` and `undesirable_weight` of the [`experimental.kto.KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples. +By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta +- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta +- `logps/chosen_sum`: the sum of log probabilities of the chosen completions +- `logps/rejected_sum`: the sum of log probabilities of the rejected completions +- `logits/chosen_sum`: the sum of logits of the chosen completions +- `logits/rejected_sum`: the sum of logits of the rejected completions +- `count/chosen`: the count of chosen samples in a batch +- `count/rejected`: the count of rejected samples in a batch + +## KTOTrainer + +[[autodoc]] experimental.kto.KTOTrainer + - train + - save_model + - push_to_hub + +## KTOConfig + +[[autodoc]] experimental.kto.KTOConfig diff --git a/ICL/RL/trl_source/docs/source/liger_kernel_integration.md b/ICL/RL/trl_source/docs/source/liger_kernel_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..7a387c813fd3dd16b3daeca082e08e704eddaed4 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/liger_kernel_integration.md @@ -0,0 +1,78 @@ +# Liger Kernel Integration + +[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). + +With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance. + +| Speed Up | Memory Reduction | +| --- | --- | +| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) | + +## Supported Trainers + +Liger Kernel is supported in the following TRL trainers: +- **SFT** (Supervised Fine-Tuning) +- **DPO** (Direct Preference Optimization) +- **GRPO** (Group Relative Policy Optimization) +- **KTO** (Kahneman-Tversky Optimization) +- **GKD** (Generalized Knowledge Distillation) + +## Usage + +1. First, install Liger Kernel: + + ```bash + pip install liger-kernel + ``` + +2. Once installed, set `use_liger_kernel=True` in your trainer config. No other changes are needed! + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import KTOConfig + +training_args = KTOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl.experimental.gkd import GKDConfig + +training_args = GKDConfig(..., use_liger_kernel=True) +``` + + + + +To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/). diff --git a/ICL/RL/trl_source/docs/source/lora_without_regret.md b/ICL/RL/trl_source/docs/source/lora_without_regret.md new file mode 100644 index 0000000000000000000000000000000000000000..2875d66228801826a1cc731f52f93f038eb8ed09 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/lora_without_regret.md @@ -0,0 +1,347 @@ +# LoRA Without Regret + +Recent research from the team at [Thinking Machines Lab](https://thinkingmachines.ai/blog/lora/) (Schulman et al., 2025) shows that **LoRA can match full fine-tuning performance** when configured correctly, while using only ~67% of the compute. These findings are exciting to TRL users because they're straightforward to implement and can improve model performance on smaller budgets. + +This guide provides simple instructions to reproduce the results of the blog post in TRL. + +> [!TIP] +> It is recommended to read the blog post before following this guide, or to consult both resources in parallel for best results. + +## Benefits of LoRA over full fine-tuning + +First of all, let's remind ourselves of the benefits of [LoRA over full fine-tuning](https://huggingface.co/docs/trl/en/peft_integration). + +LoRA adds adapter layers on top of the base model, which contains significantly fewer parameters than the base model itself. This design reduces GPU memory requirements and enables more efficient training. As described in the [blog](https://thinkingmachines.ai/blog/lora/), this approach was originally thought to involve a performance trade-off, although careful configuration can overcome this trade-off and match full fine-tuning performance. + +## Examples with TRL + +Let's implement and train LoRA adapters in TRL scripts based on the core findings of the blog post. Afterwards, we'll revisit each finding in light of the TRL results. + +### Supervised Fine-Tuning (SFT) + +The blog post performs SFT on a range of models and datasets from the Hub, which we can reproduce in TRL. + +| Model | Dataset | +| --- | --- | +| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) | +| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) | +| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) | +| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) | + + + + +We can integrate these findings with the TRL Python API like so: + +```python + +from datasets import load_dataset +from peft import LoraConfig +from trl import SFTTrainer, SFTConfig + +dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train") + +peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear") + +training_args = SFTConfig( + learning_rate=2e-4, + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + num_train_epochs=1, + report_to=["trackio"], +) + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-3B-Instruct", + train_dataset=dataset, + peft_config=peft_config, + args=training_args, +) + +trainer.train() + +``` + + + + +```bash + +hf jobs uv run \ + --flavor a100-large \ + --timeout 8h \ + --secrets HF_TOKEN \ + "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \ + --model_name_or_path Qwen/Qwen2.5-3B-Instruct \ + --dataset_name open-thoughts/OpenThoughts-114k \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --use_peft \ + --lora_r 256 \ + --lora_alpha 16 \ + --lora_target_modules all-linear \ + --output_dir Qwen2.5-3B-OpenThoughts-LoRA \ + --report_to trackio \ + --push_to_hub + +``` + +To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details. + + + + +```bash + +uv run "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \ + --model_name_or_path Qwen/Qwen2.5-3B-Instruct \ + --dataset_name open-thoughts/OpenThoughts-114k \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --eval_strategy no \ + --use_peft \ + --lora_r 256 \ + --lora_alpha 16 \ + --lora_target_modules all-linear \ + --output_dir Qwen2.5-3B-OpenThoughts-LoRA \ + --report_to trackio \ + --push_to_hub + +``` + +To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details. + + + + +Once training starts, you can monitor the progress in [Trackio](https://huggingface.co/trackio), which will log the URL. + +### Reinforcement Learning (GRPO) + +The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL. + +| Model | Dataset | +| --- | --- | +| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [GSM8k](https://huggingface.co/datasets/openai/gsm8k) | +| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) | +| [Qwen3-8b-base](https://huggingface.co/Qwen/Qwen3-8b-base) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) | + +For reinforcement learning, the blog uses a math reasoning task that we can reproduce as a Python function. + + + + +We can implement these recommendations with the TRL Python API like so: + +```python + +from datasets import load_dataset +from peft import LoraConfig +from trl import GRPOConfig, GRPOTrainer +from trl.rewards import reasoning_accuracy_reward + +dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train") + +peft_config = LoraConfig( + r=1, + lora_alpha=32, + target_modules="all-linear" +) + +training_args = GRPOConfig( + learning_rate=5e-5, + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + num_train_epochs=1, + num_generations=8, + generation_batch_size=8, + report_to=["trackio"], +) + +trainer = GRPOTrainer( + model="Qwen/Qwen3-0.6B", + reward_funcs=reasoning_accuracy_reward, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, +) + +trainer.train() + +``` + +> [!WARNING] +> This snippet skips the reward function which is defined above to keep the example concise. + + + + +```bash + +hf jobs uv run \ + --flavor a100-large \ + --timeout 4h \ + --secrets HF_TOKEN \ + --env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \ + --output_dir grpo-full-qwen3-0.6b \ + --learning_rate 1.0e-6 \ + --lr_scheduler_type cosine \ + --warmup_steps 0.0 \ + --max_grad_norm 1.0 \ + --beta 0.0 \ + --max_completion_length 4096 \ + --num_generations 16 \ + --generation_batch_size 16 \ + --gradient_accumulation_steps 8 \ + --per_device_train_batch_size 1 \ + --num_train_epochs 1 \ + --lora_r 1 \ + --lora_alpha 32 \ + --lora_dropout 0.0 \ + --lora_target_modules all-linear \ + --vllm_mode colocate \ + --save_strategy steps \ + --save_steps 50 \ + --save_total_limit 1 \ + --logging_steps 1 \ + --max_steps 200 \ + --report_to trackio +``` + +To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details. + + + + +```bash +uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \ + --output_dir grpo-full-qwen3-0.6b \ + --learning_rate 1.0e-6 \ + --lr_scheduler_type cosine \ + --warmup_steps 0.0 \ + --max_grad_norm 1.0 \ + --beta 0.0 \ + --max_completion_length 4096 \ + --num_generations 16 \ + --generation_batch_size 16 \ + --gradient_accumulation_steps 8 \ + --per_device_train_batch_size 1 \ + --num_train_epochs 1 \ + --lora_r 1 \ + --lora_alpha 32 \ + --lora_dropout 0.0 \ + --lora_target_modules all-linear \ + --vllm_mode colocate \ + --save_strategy steps \ + --save_steps 50 \ + --save_total_limit 1 \ + --logging_steps 1 \ + --max_steps 200 \ + --report_to trackio +``` + +To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details. + + + + +The reinforcement learning script with GRPO is implemented as a custom script in TRL, which uses the reward function shown above. You can review it at [`grpo.py`](https://huggingface.co/datasets/burtenshaw/lora-without-regrets/blob/main/grpo.py) - Reinforcement learning with LoRA best practices + +## Key findings in optimizing LoRA + +The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. + +We were able to reproduce the results of the blog post using TRL and the SmolLM3 model. We trained the model for 500 steps on the [Math 220k dataset](https://huggingface.co/datasets/HuggingFaceH4/OpenR1-Math-220k-default-verified) with the reward function and configuration above. As you can see in the figure below, the LoRA model's average train reward curve matches the full fine-tuning curve. + +![train reward](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/5.png) + +And most importantly, the LoRA model uses significantly less memory than the full fine-tuning model, as we can see in the figure below. + +![memory usage](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/6.png) + +Here are the parameters we used to train the above models + +| Parameter | LoRA | Full FT | +| --- | --- | --- | +| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B | +| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified | +| `--learning_rate` | 1.0e-5 | 1.0e-6 | +| `--max_prompt_length` | 1024 | 1024 | +| `--max_completion_length` | 4096 | 4096 | +| `--lora_r` | 1 | - | +| `--lora_alpha` | 32 | - | +| `--lora_dropout` | 0.0 | - | +| `--lora_target_modules` | all-linear | - | + +Let's break down the key findings of the blog post and how we were able to reproduce them. + +### 1. *LoRA performs better when applied to all weight matrices* + +The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. + +![all layers](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/1.png) + +Attention-only LoRA underperforms even when using a higher rank to match parameter count. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. In Python, we can do this like so: + +```python +from peft import LoraConfig + +peft_config = LoraConfig(target_modules="all-linear") +``` + +### 2. *The adapter needs sufficient capacity to learn from the dataset* + +The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT". + +![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/3.png) + +In the TRL script, we could use `--lora_r` to set the rank and adapt it based on the task and dataset we're training on. The blog post recommends the following ranks based on the task and dataset size: + +Reinforcement learning tasks typically require lower capacity, so smaller LoRA ranks can be used. This is because policy gradient algorithms extract roughly ~1 bit of information per episode, demanding minimal parameter capacity. + +The blog post defines the ideal dataset size for LoRA to match full fine-tuning as "Post-training scale". Which we can use to determine the recommended rank for SFT and RL LoRAs as: + +| Task Type | Dataset Size | Recommended Rank | +| --- | --- | --- | +| **SFT** | Post-training scale | 256 | +| **RL** | Any size | 1-32 | + +### 3. *"FullFT and high-rank LoRAs have similar learning curves"* + +Counterintuitively, the blog post recommends using a higher learning rate than for full fine-tuning. In the table above, we used 1.0e-5 for LoRA and 1.0e-6 for full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent. + +![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/2.png) + +### 4. *"In some scenarios, LoRA is less tolerant of large batch sizes than full fine-tuning."* + +The blog post recommends using an effective batch size < 32 because the authors found LoRA to be less tolerant of large batch sizes. This could not be mitigated by increasing the LoRA rank. In the TRL script, we could use `--per_device_train_batch_size` and `--gradient_accumulation_steps` to set the batch size. + +![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/4.png) + +## Takeaways + +Using TRL, you can efficiently implement LoRA adapters to match full fine-tuning performance, applying the core insights (targeting all weight matrices, choosing the right rank, and managing batch size and learning rate) without the heavy compute cost of FullFT. + +## Citation + +```bibtex +@article{schulman2025lora, + title = {{LoRA Without Regret}}, + author = {John Schulman and Thinking Machines Lab}, + year = 2025, + journal = {Thinking Machines Lab: Connectionism}, + doi = {10.64434/tml.20250929}, + note = {https://thinkingmachines.ai/blog/lora/} +} +``` diff --git a/ICL/RL/trl_source/docs/source/merge_model_callback.md b/ICL/RL/trl_source/docs/source/merge_model_callback.md new file mode 100644 index 0000000000000000000000000000000000000000..fd7241e7d15f070698410505ccdadc041726e466 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/merge_model_callback.md @@ -0,0 +1,3 @@ +# MergeModelCallback + +[[autodoc]] experimental.merge_model_callback.MergeModelCallback diff --git a/ICL/RL/trl_source/docs/source/minillm_trainer.md b/ICL/RL/trl_source/docs/source/minillm_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..6db88955dc38ef11e1065f237142147e0342e4e3 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/minillm_trainer.md @@ -0,0 +1,67 @@ +# MiniLLM Trainer + +[![All_models-MiniLLM-blue](https://img.shields.io/badge/All_models-MiniLLM-blue)](https://huggingface.co/models?other=minillm,trl) + +## Overview + +TRL supports the MiniLLM Trainer for distilling large language models into smaller ones using reverse KLD for better precision, quality, and performance, as described in the paper [Knowledge Distillation of Large Language Models](https://huggingface.co/papers/2306.08543) by [Yuxian Gu](https://huggingface.co/t1101675), [Li Dong](https://huggingface.co/unilm), [Furu Wei](https://huggingface.co/thegenerality), and Minlie Huang. +The abstract from the paper is the following: + +> Knowledge Distillation (KD) is a promising technique for reducing the high computational demand of large language models (LLMs). However, previous KD methods are primarily applied to white-box classification models or training small models to imitate black-box model APIs like ChatGPT. How to effectively distill the knowledge from white-box generative LLMs is still under-explored, which becomes more and more important with the prosperity of LLMs. In this work, we propose MiniLLM that distills smaller language models from generative larger language models. We first replace the forward Kullback-Leibler divergence (KLD) objective in the standard KD approaches with reverse KLD, which is more suitable for KD on generative language models, to prevent the student model from overestimating the low-probability regions of the teacher distribution. Then, we derive an effective optimization approach to learn this objective. Extensive experiments in the instruction-following setting show that the MiniLLM models generate more precise responses with the higher overall quality, lower exposure bias, better calibration, and higher long-text generation performance. Our method is also scalable for different model families with 120M to 13B parameters. We will release our code and model checkpoints at https://aka.ms/MiniLLM. + +This post-training method was contributed by [Yuxian Gu](https://huggingface.co/t1101675). + +It is a generalized version of [Think Machine Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals. + +$$ +\begin{align} +L_{\text{MiniLLM}}&=\alpha_1\mathbb{E}_{x\sim \pi_{\theta}}\sum_{t'=t}^{|x|}\frac{\gamma^{t'-t}}{\sum_{t'}\gamma^{t'-t}}\left[\log \frac{\pi_{\theta}(x_{t'+1}|x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}|x_{1..t'})}\right] \\ +&+ \alpha_2\mathbb{E}_{x\sim \pi_{\theta}} \text{KL}\left[\pi_\theta(\cdot|x_{1..t})||\pi_{\text{teacher}}(\cdot | x_{1..t})\right]. +\end{align} +$$ + +When \\( \alpha_1=1 \\), \\( \alpha_2=0 \\), \\( \gamma=0 \\), which corresponds to + +```python +from trl.experimental.minillm import MiniLLMConfig + +training_args = MiniLLMConfig( + rkl_advantage=True, + single_step_decomposition=False, + gamma=False +) +``` + +\\( L_{\text{MiniLLM}} \\) becomes the on-policy KD implemented in [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook/blob/5d08be6d130596b7bedd02197861c41fa81ea436/tinker_cookbook/distillation/train_on_policy.py#L88): + +$$ +L_{\text{tinker}}=\mathbb{E}_{x\sim \pi_{\theta}}\left[\log \frac{\pi_{\theta}(x_{t'+1}|x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}|x_{1..t'})}\right]. +$$ + +When \\( \alpha_1=0 \\), \\( \alpha_2=1 \\), which corresponds to + +```python +from trl.experimental.minillm import MiniLLMConfig + +training_args = MiniLLMConfig( + rkl_advantage=False, + single_step_decomposition=True +) +``` + +\\( L_{\text{MiniLLM}} \\) becomes the reverse KLD version of the GKD loss as in [GKD Trainer](./gkd.md): + +$$ +L_{\text{GKD-RKL}}=\mathbb{E}_{x\sim \pi_{\theta}} \text{KL}\left[\pi_\theta(\cdot|x_{1..t})||\pi_{\text{teacher}}(\cdot | x_{1..t})\right]. +$$ + +## MiniLLMTrainer + +[[autodoc]] experimental.minillm.MiniLLMTrainer + - train + - save_model + - push_to_hub + +## MiniLLMConfig + +[[autodoc]] experimental.minillm.MiniLLMConfig diff --git a/ICL/RL/trl_source/docs/source/model_utils.md b/ICL/RL/trl_source/docs/source/model_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..6cfdc7b571b19e076cef3c7e86b0620a356a01e7 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/model_utils.md @@ -0,0 +1,13 @@ +# Model Utilities + +## get_act_offloading_ctx_manager + +[[autodoc]] models.get_act_offloading_ctx_manager + +## disable_gradient_checkpointing + +[[autodoc]] models.utils.disable_gradient_checkpointing + +## create_reference_model + +[[autodoc]] create_reference_model diff --git a/ICL/RL/trl_source/docs/source/nash_md_trainer.md b/ICL/RL/trl_source/docs/source/nash_md_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..02ffa7bb6539a003689333c04860166a43d77053 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/nash_md_trainer.md @@ -0,0 +1,159 @@ +# Nash-MD Trainer + +[![model badge](https://img.shields.io/badge/All_models-Nash--MD-blue)](https://huggingface.co/models?other=nash-md,trl) + +## Overview + +Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi. + +The abstract from the paper is the following: + +> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Quick start + +This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`experimental.judges.PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here: + + + +Below is the script to train the model: + +```python +# train_nash_md.py +from datasets import load_dataset +from trl.experimental.judges import PairRMJudge +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +judge = PairRMJudge() +train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + +training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD") +trainer = NashMDTrainer( + model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_nash_md.py +``` + +Distributed across 8 GPUs, the training takes approximately 3 hours. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-NashMD
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-NashMD>:
+The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
+
+ +## Expected dataset type + +Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`experimental.nash_md.NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Usage tips + +### Use a reward model + +Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model: + +```diff +- from trl.experimental.judges import PairRMJudge ++ from transformers import AutoModelForSequenceClassification + +- judge = PairRMJudge() ++ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) + + trainer = NashMDTrainer( + ... +- judge=judge, ++ reward_funcs=reward_model, + ) +``` + +> [!WARNING] +> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training. + +### Encourage EOS token generation + +We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.nash_md.NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.nash_md.NashMDConfig`]: + +```python +training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) +``` + +### Logging Completions + +To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`]. + +```python +trainer = NashMDTrainer(..., eval_dataset=eval_dataset) +completions_callback = LogCompletionsCallback(trainer, num_prompts=8) +trainer.add_callback(completions_callback) +``` + +This callback logs the model's generated completions directly to Weights & Biases. + +![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png) + +## Example script + +We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) + +To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command: + +```bash +python examples/scripts/nash_md.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --judge pair_rm \ + --dataset_name trl-lib/ultrafeedback-prompt \ + --learning_rate 5.0e-7 \ + --output_dir Qwen2.5-0.5B-NashMD-PairRM \ + --warmup_steps 0.1 \ + --push_to_hub +``` + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +* `loss/kl`: The mean KL divergence between the model and reference data. +* `objective/entropy`: The mean entropy of the model and reference data. +* `loss/score`: The mean reinforce score loss. +* `rewards/chosen`: The mean scores (according to the reward model) of the model completions. +* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions. +* `rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion. +* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model. +* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions. +* `logps/chosen`: The mean log probabilities of the chosen completions. +* `logps/rejected`: The mean log probabilities of the reference completions. +* `val/model_contain_eos_token`: The amount of times the model's output contains the eos token. +* `val/ref_contain_eos_token`: The amount of times the mixture's output contains the eos token. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.nash_md.NashMDConfig`]. +* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.nash_md.NashMDConfig`]. + +## NashMDTrainer + +[[autodoc]] experimental.nash_md.NashMDTrainer + - train + - save_model + - push_to_hub + +## NashMDConfig + +[[autodoc]] experimental.nash_md.NashMDConfig diff --git a/ICL/RL/trl_source/docs/source/nemo_gym.md b/ICL/RL/trl_source/docs/source/nemo_gym.md new file mode 100644 index 0000000000000000000000000000000000000000..295c779e8e204c0365b48ee6defd079f013c3962 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/nemo_gym.md @@ -0,0 +1,293 @@ +# NeMo Gym Integration + +NVIDIA NeMo Gym is a library for building RL environments for large language models. This integration enables training models in NeMo Gym environments using TRL's GRPOTrainer with vLLM server mode. + +The integration supports multi-step and multi-turn rollouts, multi-environment training, and any NeMo Gym environment (thoroughly tested: workplace assistant, reasoning gym, MCQA, and math with judge). + +## Why NeMo Gym + +- **Production-Ready Scale**: Tested for frontier model training with diverse environments running in parallel across math, coding, tool use, reasoning, and more. +- **Multi-Verifier Training**: Supports algorithmic verification, LLM-as-a-judge, and custom verification logic in a single training run. +- **Decoupled Architecture**: Build agents and environments independently from the training loop—no RL framework expertise required. +- **OpenAI-Compatible API**: All environments use the standardized OpenAI Responses API for seamless integration with vLLM, OpenAI models, and other endpoints. + +## Available Environments + +NeMo Gym provides training-ready environments across multiple domains, including but not limited to: + +| Environment | Domain | Description | +|-------------|--------|-------------| +| Workplace Assistant | Agent | Multi-step tool calling in common office scenarios (calendar, email, and more) | +| Math with Judge | Math | Math problems with algorithmic or judge-based verification | +| Code Gen | Coding | Competitive programming problems with code execution | +| MCQA | Knowledge | Multiple-choice question answering | +| Instruction Following | Instruction Following | IFEval/IFBench style tasks | +| Reasoning Gym | Multiple | Single-step procedurally generated verifiable tasks across domains | + +For a complete list of available training environments, refer to the [NeMo Gym repository](https://github.com/NVIDIA-NeMo/Gym#-available-resource-servers). + +## Before You Start + +Complete these one-time setup steps before running training. + +### Install TRL and NeMo Gym + +1. **Install TRL with vLLM extras** + + ```bash + cd trl/ + uv venv + source .venv/bin/activate + uv sync --extra vllm + ``` + +1. **Install NeMo Gym** + + ```bash + # deactivate trl venv + deactivate + git clone https://github.com/NVIDIA-NeMo/Gym.git + cd Gym + uv venv --python 3.12 + source .venv/bin/activate + uv sync + ``` + +### Prepare a Dataset + +Many NeMo Gym datasets used to train Nemotron models are available on Hugging Face. Use `ng_prepare_data` to download and prepare datasets. This command: + +- Downloads the dataset from Hugging Face +- Validates the data format +- Adds an `agent_ref` field to each example that tells NeMo Gym which agent server should handle that example + +> **Note**: `train_multi_environment.py` adds the `agent_ref` field when loading datasets, so this step is optional if datasets are created another way. + +1. **Set Hugging Face Token** + + Create `env.yaml` in `Gym/` with your HF token: + + ```yaml + hf_token: + ``` + +1. **Prepare Dataset** + + ```bash + # Enter Gym and activate the venv + cd Gym + source .venv/bin/activate + + # Set config paths + config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ + resources_servers/workplace_assistant/configs/workplace_assistant.yaml" + + # Download data and prep for training + ng_prepare_data "+config_paths=[${config_paths}]" \ + +output_dirpath=data/workplace_assistant \ + +mode=train_preparation \ + +should_download=true \ + +data_source=huggingface + ``` + + This creates `train.jsonl` and `validation.jsonl` files in `data/workplace_assistant/`. + +To create a new environment, refer to the [environment creation guide](https://docs.nvidia.com/nemo/gym/latest/contribute/environments/new-environment.html). We suggest running an existing one first! + +#### Dataset Format + +NeMo Gym datasets are stored as JSONL. Each line contains a task with input messages, tool definitions, metadata such as ground truth for verification, and an agent server reference. The following example shows the workplace dataset structure. Metadata fields can differ between datasets, as long as the corresponding resources server uses the fields appropriately. + +```json +{ + "responses_create_params": { + "input": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "Move any of jinsoo's tasks that are in review to completed"} + ], + "tools": [...], + "parallel_tool_calls": false, + "temperature": 1 + }, + "ground_truth": [ + {"name": "project_management_update_task", "arguments": "{...}"}, + ... + ], + "category": "workbench_project_management", + "environment_name": "workbench", + "agent_ref": { + "type": "responses_api_agents", + "name": "workplace_assistant_simple_agent" + } +} +``` + +## Interactive Training + +For development and testing on a single node. + +### Set Up + +1. **Update Environment Config** + + Update `env.yaml` in `Gym/` to include model information: + + ```yaml + policy_base_url: http://127.0.0.1:8000/v1 + policy_api_key: EMPTY + policy_model_name: Qwen/Qwen2.5-1.5B-Instruct + hf_token: ... + ``` + +2. **Update Training Config** + + Update `examples/scripts/nemo_gym/config.yaml` to point to the dataset generated above, and any other optional modifications. + +### Run Training + +The following steps run in 3 terminals. It can also be ran with processes in the background, or using tmux. + +1. **Start NeMo Gym Servers** (Terminal 1) + + ```bash + cd Gym/ + source .venv/bin/activate + + config_paths="resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ + responses_api_models/vllm_model/configs/vllm_model_for_training.yaml" + + ng_run "+config_paths=[${config_paths}]" + ``` + + This starts: + - **Agent server**: Orchestrates rollouts using resource servers and model servers + - **Resources server**: Supports environment logic such as state-management, tool implementations, and task verification + - **Model server**: Adapts vLLM server requests to support NeMo Gym agents and on-policy RL training while ensuring OpenAI API compatibility + - **Head server**: Manages servers used in training enabling their discovery + +1. **Start TRL vLLM Server on GPU 0** (Terminal 2) + + ```bash + cd trl/ + source .venv/bin/activate + CUDA_VISIBLE_DEVICES=0 trl vllm-serve \ + --model Qwen/Qwen2.5-1.5B-Instruct \ + --max-model-len 16384 \ + --host 0.0.0.0 \ + --port 8000 + ``` + +1. **Run Training on GPU 1** (Terminal 3) + + ```bash + source trl/.venv/bin/activate + cd trl/examples/scripts/nemo_gym + export WANDB_API_KEY=... + uv add omegaconf + + CUDA_VISIBLE_DEVICES=1 python train_multi_environment.py --config config.yaml + ``` + +## Multi-Node Training with Slurm + +An example five-node training script is provided in `submit.sh`. Nodes one through four run the training algorithm, while node five runs vLLM inference for NeMo Gym agent rollouts. + +1. **Configure the Script** + + Update `submit.sh` with your Slurm account, partition, paths to your project directory, and updated training configs. + +1. **Submit the Job** + + ```bash + sbatch submit.sh + ``` + +1. **Monitor Training** + + ```bash + tail -f logs//* + ``` + +> **Tip**: Set up wandb logging for detailed training metrics. For more details on TRL's vLLM integration, refer to the vLLM integration page. + +## Multi-Environment Training + +Train on multiple NeMo Gym environments simultaneously. This allows learning diverse capabilities, such as tool calling and math reasoning, in a single training run. + +1. **Prepare Individual Datasets** + + Prepare datasets for each environment. The workplace assistant dataset was prepared above. Now lets create a dataset for the mini sudoku environment implemented by the reasoning gym resources server in NeMo Gym: + + ```bash + cd Gym + source .venv/bin/activate + uv add reasoning-gym + cd resources_servers/reasoning_gym + python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 2000 \ + --seed 42 \ + --output data/reasoning_gym/train_mini_sudoku.jsonl + + python scripts/create_dataset.py \ + --task mini_sudoku \ + --size 50 \ + --seed 24 \ + --output data/reasoning_gym/val_mini_sudoku.jsonl + ``` + +1. **Create Combined Dataset** + + Combine datasets into a single file with tasks from both environments: + + ```bash + cat data/workplace_assistant/train_workplace.jsonl data/reasoning_gym/train_mini_sudoku.jsonl | shuf > train_multi_env.jsonl + ``` + + > **Tip**: Ensure datasets are the same size before shuffling for an even blend of tasks. Repeat for the validation dataset. + +1. **Update Training Config** + + Update the config to point to the combined dataset: + + ```yaml + model_name: "Qwen/Qwen3-4B-Instruct-2507" + + dataset_path: "/path/to/data/train_multi_env.jsonl" + eval_dataset_path: "/path/to/data/val_multi_env.jsonl" + + task: "workplace-sudoku" # used in wandb run name + output_dir: "outputs/nemo_gym_multi_env" + + # ... rest of config same + ``` + +1. **Update ng_run** + + Whether training interactively or via Slurm, update the `ng_run` command to include config files from each resources server: + + ```bash + cd Gym + source .venv/bin/activate + + config_paths="responses_api_models/vllm_model/configs/vllm_model.yaml,\ + resources_servers/workplace_assistant/configs/workplace_assistant.yaml,\ + resources_servers/reasoning_gym/configs/reasoning_gym.yaml" + + ng_run "+config_paths=[${config_paths}]" + ``` + + This starts servers for both environments. The training script automatically routes each example to the correct agent server based on its `agent_ref` field. + +1. **Run Training** + + Update the Slurm submission script to use the new training config and both `ng_run` resources server configs, then submit the job as before. + + The training script reads `agent_ref` from each example's metadata, routes requests to the correct NeMo Gym agent server, and handles different agents and environments in the same batch. + +## Resources + +- [NeMo Gym GitHub](https://github.com/NVIDIA-NeMo/Gym) +- [NeMo Gym Documentation](https://docs.nvidia.com/nemo/gym/latest/) +- [Training Script](https://github.com/huggingface/trl/blob/main/examples/scripts/nemo_gym/train_multi_environment.py) +- [TRL GRPO Trainer](grpo_trainer) diff --git a/ICL/RL/trl_source/docs/source/online_dpo_trainer.md b/ICL/RL/trl_source/docs/source/online_dpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..f442d32e4a7766a94b6c58dffdfc451e282c24ba --- /dev/null +++ b/ICL/RL/trl_source/docs/source/online_dpo_trainer.md @@ -0,0 +1,270 @@ +# Online DPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-Online_DPO-blue)](https://huggingface.co/models?other=online-dpo,trl) + +## Overview + +Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel. + +The abstract from the paper is the following: + +> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator. + +This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching). + +## Quick start + +This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`experimental.judges.PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here: + + + +Below is the script to train the model: + +```python +# train_online_dpo.py +from datasets import load_dataset +from trl.experimental.judges import PairRMJudge +from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +judge = PairRMJudge() +train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + +training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO") +trainer = OnlineDPOTrainer( + model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_online_dpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online-dpo-qwen2.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-OnlineDPO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-OnlineDPO>:
+The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
+
+ +## Expected dataset type + +Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`experimental.online_dpo.OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Usage tips + +### Use a reward model + +Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model: + +```diff +- from trl.experimental.judges import PairRMJudge ++ from transformers import AutoModelForSequenceClassification + +- judge = PairRMJudge() ++ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) ++ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward") + + trainer = OnlineDPOTrainer( + ... +- judge=judge, ++ reward_funcs=reward_model, ++ reward_processing_class=reward_tokenizer, + ... + ) +``` + +### Encourage EOS token generation + +When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.online_dpo.OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.online_dpo.OnlineDPOConfig`]: + +```python +training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) +``` + +### Logging Completions + +To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`]. + +```python +trainer = OnlineDPOTrainer(..., eval_dataset=eval_dataset) +completions_callback = LogCompletionsCallback(trainer, num_prompts=8) +trainer.add_callback(completions_callback) +``` + +This callback logs the model's generated completions directly to Weights & Biases. + +![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png) + +## Example script + +We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) + +To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command: + +```bash +python examples/scripts/dpo_online.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --judge pair_rm \ + --dataset_name trl-lib/ultrafeedback-prompt \ + --learning_rate 5.0e-7 \ + --output_dir Qwen2.5-0.5B-Online-DPO-PairRM \ + --warmup_steps 0.1 \ + --push_to_hub +``` + +## Logged metrics + +While training and evaluating, we record the following reward metrics. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9) + +* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model. +* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model. +* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. +* `objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up. +* `objective/scores`: The mean scores returned by the reward model. +* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions. +* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions. +* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions. +* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model. +* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. +* `logps/chosen`: The mean log probabilities of the chosen completions. +* `logps/rejected`: The mean log probabilities of the rejected completions. +* `val/contain_eos_token`: The fraction of completions which contain an EOS token. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.online_dpo.OnlineDPOConfig`]. + +## Benchmark experiments + +To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + +```shell +# 1B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_steps 0.1 \ + --missing_eos_penalty 1.0 \ + --save_steps 0.1 \ + --push_to_hub + +# 2.8B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-2.8b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_steps 0.1 \ + --missing_eos_penalty 1.0 \ + --save_steps 0.1 \ + --push_to_hub + +# 6.9B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-6.9b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_steps 0.1 \ + --missing_eos_penalty 1.0 \ + --save_steps 0.1 \ + --push_to_hub +``` + +Checkpoints and experiment tracking are available at: + +* [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07) +* [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0) + +To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR. +For more information on how to use judges, see [Judges](judges). + +```bash +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 33.00% +python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 41.50% +python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 62.60% +python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 74.20% +``` + +We can then plot the RLHF scaling chart. + +```python +import matplotlib.pyplot as plt + +results = { + "SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316}, + "online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796}, + "offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701}, +} + + +plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o") +plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o") +plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o") +plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary") +plt.xscale("log") +plt.xlabel("Model size") +plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)") +plt.title("DPO scaling by model size") +plt.legend() +plt.xlim(5e8, 1.2e10) +plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"]) +plt.grid(True, which="both", ls="--", c="0.7") +plt.tight_layout() +plt.show() +``` + +The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended. + +## OnlineDPOTrainer + +[[autodoc]] experimental.online_dpo.OnlineDPOTrainer + - train + - save_model + - push_to_hub + +## OnlineDPOConfig + +[[autodoc]] experimental.online_dpo.OnlineDPOConfig diff --git a/ICL/RL/trl_source/docs/source/openenv.md b/ICL/RL/trl_source/docs/source/openenv.md new file mode 100644 index 0000000000000000000000000000000000000000..988a64e3b9d668896ccb4c71e1fac6fed7074c71 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/openenv.md @@ -0,0 +1,603 @@ +# OpenEnv Integration for Training LLMs with Environments + +[OpenEnv](https://github.com/meta-pytorch/OpenEnv) is an open-source framework from Meta's PyTorch team for defining, deploying, and interacting with environments in reinforcement learning (RL) and agentic workflows. It offers [Gymnasium-style APIs](https://gymnasium.farama.org) (e.g., `reset()` and `step()`) to interface with environments in a standard manner, and supports running these environments as backend servers (for example, via HTTP or containerised execution). You can find a collection of ready-to-use OpenEnv environments on the [Hugging Face Hub](https://huggingface.co/collections/openenv/environment-hub). + +In this guide, we’ll focus on **how to integrate OpenEnv with TRL**, but feel free to explore the links above to dive deeper into OpenEnv itself. + +> [!NOTE] +> You can explore ready-to-use example [scripts](example_overview#scripts) and [notebooks](example_overview#notebooks) in the Examples Overview. + +> [!NOTE] +> Explore the [OpenEnv docs](https://meta-pytorch.org/OpenEnv/) for more details. + +## Installation + +To use OpenEnv with TRL, install the environment package. You have two options: + +**Option A - Install from HF Space (recommended):** + +```bash +pip install git+https://huggingface.co/spaces/openenv/echo_env +``` + +> [!TIP] +> You can also install the core package from PyPI with `pip install "openenv-core[core]>=0.2.1"`, but note that environment-specific dependencies may need to be installed separately. + +**Option B - Clone OpenEnv repo (for development):** + +```bash +git clone https://github.com/meta-pytorch/OpenEnv.git +cd OpenEnv/envs/echo_env +pip install -e . +``` + +## Using `rollout_func` with OpenEnv environments + +TRL's [`GRPOTrainer`] supports _custom rollout logic_ through the `rollout_func` argument. This lets you override the trainer's default text-generation loop and directly interact with OpenEnv environments — for instance, to compute environment-driven rewards instead of relying solely on model-based signals. + +### Rollout Function Signature + +A rollout function must have the following signature: + +```python +def rollout_func( + prompts: list[str], + trainer: GRPOTrainer, +) -> dict[str, list]: + """ + Custom rollout function for generation and reward computation. + + Args: + prompts: List of prompts routed to the current process + trainer: Active GRPOTrainer (gives access to tokenizer, config and helper utilities) + + Returns: + Dictionary containing: + - prompt_ids: List of token IDs for each prompt + - completion_ids: List of token IDs for each completion + - logprobs: List of log probabilities for each token + - Any additional fields are forwarded to reward functions as kwargs + """ + pass +``` + +> [!NOTE] +> Any extra fields in the returned dictionary (beyond the required three) are automatically forwarded to your reward functions. This makes it easy to propagate signals such as environment rewards or auxiliary metrics from the rollout step. + +### Integration pattern + +The typical pattern when combining OpenEnv with TRL looks like this: + +1. Start or connect to an OpenEnv environment (e.g., a Dockerized env or HTTP endpoint). +2. Generate completions from your model — either via `trl.experimental.openenv.generate_rollout_completions` when using colocated vLLM, or by hitting your inference server when using vLLM in server mode. +3. Step through the environment using each completion to compute rewards or metrics. +4. Add environment results (e.g., `env_reward`) to the rollout result dict. +5. Access those rewards inside your reward function via `**kwargs`. + +By using OpenEnv in this loop, you can: + +* Train with realistic or interactive feedback (not just static reward functions). +* Plug in custom simulators, web APIs, or evaluators as environments. +* Pass structured reward signals back into RL training seamlessly. + +### vLLM Modes + +TRL supports two vLLM execution modes for generation: + +- **`colocate` mode** (default): vLLM runs in the same process as training. Requires 1 GPU. Use `trl.experimental.openenv.generate_rollout_completions` for generation. +- **`server` mode**: vLLM runs as a separate server process. Requires at least 2 GPUs (one for vLLM server, one for training), but is highly scalable: + - You can allocate multiple GPUs to the vLLM server for tensor parallelism (faster inference) + - You can run multiple training processes that share the same vLLM server + - You can use different GPU types for inference vs training (e.g., A100 for vLLM, H100 for training) + - The vLLM server can serve multiple experiments simultaneously + - Use `trl.experimental.openenv.generate_rollout_completions` which will communicate with the server via `vllm_server_url` + +Configure the mode via `GRPOConfig`: + +```python +# Colocate mode (1 GPU) +args = GRPOConfig( + use_vllm=True, + vllm_mode="colocate", + # ... other args +) + +# Server mode (2+ GPUs, scalable) +args = GRPOConfig( + use_vllm=True, + vllm_mode="server", + vllm_server_base_url="http://localhost:8000", + # ... other args +) + +# Example: Start vLLM server with multiple GPUs for tensor parallelism +# CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen3-1.7B --tensor-parallel-size 4 +``` + +## Running the Environments + +You can run OpenEnv environments in three different ways: + +- We can load the environment from the Hugging Face Hub and execute it as a Docker container. +- We can connect to a hosted environment running on the Hugging Face Hub. +- We can launch the environment directly using Uvicorn in Python. + + + + + +**Load from Hugging Face Hub** *(recommended)* + +We can use the [`from_hub`](https://meta-pytorch.org/OpenEnv/core/#core.http_env_client.HTTPEnvClient.from_hub) method to load the environment from the hub. This method will automatically start a Docker container for the environment on your local machine. [`openenv/echo-env`](https://huggingface.co/spaces/openenv/echo_env) is the repo_id of the space on the hub. + +```python +env = EchoEnv.from_hub("openenv/echo-env") +``` + +If you want to launch the environment manually, you can use the following command to pull and run the Docker container: + +```bash +docker run -d -p 8001:8000 --platform linux/amd64 registry.hf.space/openenv-echo-env:latest +``` + +And then you can connect to the environment using the following code: + +```python +env = EchoEnv(base_url="http://0.0.0.0:8001") +``` + +Here, we map the ports from 8001 to 8000 to make space for a vLLM server, but you will need to manage the ports for your local machine. + +> [!NOTE] +> You can find the Docker container for any space on the hub. +> +> * Open the space page on the hub. +> * Click the **⋮ (three dots)** menu. +> * Select **“Run locally.”** +> * Copy and execute the provided command in your terminal. +> +> ![open_env_launch_docker](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/open_env_launch_docker.png) + +> [!NOTE] +> You can also use the **Docker option** with `from_docker_image` by providing the image name.. +> For more details, refer to the official [OpenEnv documentation](https://meta-pytorch.org/OpenEnv/core/). + + + + +**Connect to a remote Hugging Face Space** + +You can connect to a hosted environment running on the Hugging Face Hub by passing the URL of the space to the `base_url` parameter of the environment class. + +```python +env = EchoEnv(base_url="https://openenv-echo-env.hf.space") +``` + +> [!NOTE] +> You can find the connection URL of any space on the hub. +> +> * Open the space page on the hub. +> * Click the **⋮ (three dots)** menu. +> * Select **“Embed this Space.”** +> * Copy the connection URL. + +> [!WARNING] +> **Currently**, it is recommended to **duplicate the Space to your own account** to avoid potential concurrency issues. + + + + + +**Local Python process** + +You can start the server manually as a local Python process. For more details about the available environments, refer to the [OpenEnv catalog](https://meta-pytorch.org/OpenEnv/environments/). + +```bash +hf download openenv/echo_env --repo-type=space --local-dir=echo_env +python -m uvicorn echo_env.src.envs.echo_env.server.app:app --host 0.0.0.0 --port 8001 +``` + +And then you can connect to the environment using the following code: + +```python +env = EchoEnv(base_url="http://0.0.0.0:8001") +``` + + + + + +## Environments Catalog + +Environment development is active and evolving. +The best way to explore the **current catalog of maintained environments** is by visiting the official OpenEnv [catalog](https://huggingface.co/collections/openenv/environment-hub). + +Custom environments are also supported. To learn how to create your own, check out the guide on [Building Your Own Environment with OpenEnv](https://meta-pytorch.org/OpenEnv/environment-builder/). + +Environments are tightly integrated with the Hub, allowing you to **push new environments directly** so the community can easily pull, reuse, and adapt them for their own use cases. + +## A simple example + +> [!NOTE] +> You can explore more ready-to-use example scripts in the [`examples/scripts/openenv/`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/) directory. + +The [echo.py](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) script demonstrates a minimal, end-to-end integration between TRL and OpenEnv. In this example, the [Echo environment](https://meta-pytorch.org/OpenEnv/environments/echo/) rewards completions based on their text length, encouraging the model to generate longer outputs. This pattern can be extended to any custom environment that provides structured feedback or task-based rewards: + +```python +from echo_env import EchoEnv, EchoAction +from trl import GRPOConfig, GRPOTrainer +from trl.experimental.openenv import generate_rollout_completions + +# Create HTTP client for Echo Environment +client = EchoEnv.from_hub("openenv/echo-env") + +""" +Alternatively, you can start the environment manually with Docker and connect to it: + +# Step 1: Start the Echo environment +docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest + +# Step 2: Connect the client to the running container +client = EchoEnv(base_url="http://0.0.0.0:8001") +""" + +def rollout_func(prompts: list[str], trainer: GRPOTrainer): + # 1. Generate completions using TRL's helper (works for colocated vLLM) + outputs = generate_rollout_completions(trainer, prompts) + tokenizer = trainer.processing_class + completions_text = [ + tokenizer.decode(out["completion_ids"], skip_special_tokens=True) for out in outputs + ] + + # 2. Step through the environment to get rewards + client.reset() + env_rewards = [] + for msg in completions_text: + env_result = client.step(EchoAction(message=msg)) + env_rewards.append(env_result.reward) + + # 3. Add environment rewards as extra field + return { + "prompt_ids": [out["prompt_ids"] for out in outputs], + "completion_ids": [out["completion_ids"] for out in outputs], + "logprobs": [out["logprobs"] for out in outputs], + "env_reward": env_rewards, + } + +def reward_from_env(completions, **kwargs): + """Extract environment rewards passed via rollout_func kwargs.""" + env_rewards = kwargs.get("env_reward", []) + return [float(reward) for reward in env_rewards] if env_rewards else [0.0] * len(completions) + +dataset = Dataset.from_dict({"prompt": ["You are an AI that interacts with an *Echo* environment. Word to echo:"] * 64}) + +# Setup trainer with custom rollout +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + reward_funcs=reward_from_env, + train_dataset=dataset, + rollout_func=rollout_func, # Use custom rollout + args=GRPOConfig( + use_vllm=True, + vllm_mode="colocate", # Use colocate mode (default) + num_train_epochs=1, + num_generations=8, + max_completion_length=2048, + per_device_train_batch_size=8, + gradient_accumulation_steps=4, + ), +) +trainer.train() +``` + +That's it! Now that you've seen the full example, let's unpack how the main pieces fit together. + +1. **Environment Client:** `EchoEnv` implements an HTTP interface to interact with the environment server. +2. **Custom rollout:** The `rollout_func` generates completions and steps through the environment to collect rewards. +3. **Extra fields:** The rollout adds `env_reward` to the result dictionary, which is automatically passed to reward functions. +4. **Reward function:** Extracts `env_reward` from `kwargs` to apply environment-computed rewards during training. + +> [!TIP] +> The trainer-aware rollout hook works in both vLLM server and colocate modes. Use `trl.experimental.openenv.generate_rollout_completions` so you reuse TRL's sampling configuration automatically. + +### Running the Example + +You can run the example in either colocate mode (1 GPU) or server mode (2 GPUs): + + + + + +**Colocate mode (1 GPU, recommended)** + +```bash +python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode colocate +``` + +This runs vLLM in the same process as training, requiring only a single GPU. + + + + + +**Server mode (2+ GPUs, scalable)** + +```bash +# Terminal 1: Start vLLM inference server +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000 + +# Terminal 2: Run GRPO training with OpenEnv +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/echo.py --env-mode space --env-host https://openenv-echo-env.hf.space --vllm-mode server --vllm-server-url http://localhost:8000 +``` + +This runs vLLM as a separate server process, useful when you want to: +- Share the inference server across multiple training jobs +- Use multiple GPUs for the vLLM server (via `--tensor-parallel-size`) +- Scale up training to many GPUs while sharing a single inference endpoint + + + + + +Alternatively, you can manually start the Echo environment in a Docker container before running the training: + +```bash +# Launch the Echo environment +docker run -d -p 8001:8001 registry.hf.space/openenv-echo-env:latest + +# Run training with docker-local mode +python examples/scripts/openenv/echo.py --env-mode docker-local --vllm-mode colocate +``` + +Below is the reward curve from training: + + + +## Advanced Example + +Let's level this up a bit by training a model to interact with a more complex environment. We'll use the game word guessing game [wordle](https://www.nytimes.com/games/wordle/index.html) from the [`TextArena`](https://meta-pytorch.org/OpenEnv/environments/textarena/) environment. + +> [!NOTE] +> You can explore the notebook version of this example [here](https://github.com/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb). + +### The TextArena Environment + +[TextArena](https://huggingface.co/papers/2504.11442) is an open-source collection of competitive text-based games designed to evaluate reasoning skills in LLMs using textual games like Wordle, Snake, Tic-Tac-Toe, and more. Research has shown that such games improve model performance on reasoning tasks. + +![image of TextArena](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/text_arena_evals.png) + +We will use the `TextArena` environment to train a model to play Wordle. The environment is a simple text based response environment that allows the model to interact with the game by making guesses and receive feedback on them. + +### Wordle + +Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. Also, it is a purely language based game that requires no external tools or knowledge. Furthermore, we found that models from 1 billion parameters and up are able to improve on wordle and only require 8 tokens to generate a guess, which makes the game a good benchmark to experiment with Reinforcement Learning environments without significant compute requirements. + +> [!NOTE] How does Wordle work? +> Wordle is a word guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess, the environment will provide feedback on the correctness of the guess. The player wins if they guess the word in 6 guesses or fewer. It challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment. +> +> For example, if the wordle environment returns the following feedback: +> +> ``` +> G U E S S +> X G Y X X +> ``` +> The model has guessed the word "GUESS" and the environment has provided feedback as the letters X, G, and Y. Referring to colors in the original game as blank, green, and yellow. From this feedback, the model should learn that the word "GUESS" is incorrect. The letter "E" is in the word, but in the wrong position. The letter "U" is correct and in the correct position. + +In the TextArena environment, a reward is only given when the model wins the game. The reward is 1.0 if the model wins, and 0.0 otherwise. This is not a very efficient reward signal for the model, so we have added a number of custom reward functions to the script to help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` allows you to add any custom reward function you want to the script. + +### Rollout Function + +The rollout function runs one full Wordle episode, prompting the model for a guess each turn and capturing both environment rewards and auxiliary signals such as letter coverage and repetition penalties. + +```python +def rollout_once( + trainer: GRPOTrainer, + env: TextArenaEnv, + tokenizer: AutoTokenizer, + dataset_prompt: str, + system_prompt: str, + max_turns: int, +) -> dict[str, list]: + result = env.reset() + observation = result.observation + + prompt_ids: list[int] = [] + completion_ids: list[int] = [] + logprobs: list[float] = [] + raw_rewards: list[float] = [] + green_scores: list[float] = [] + yellow_scores: list[float] = [] + repetition_scores: list[float] = [] + correct_scores: list[float] = [] + guess_counts: dict[str, int] = {} + + for _turn in range(max_turns): + # when the game is over the environment will return a done=True + if result.done: + break + + # set up the prompt for the model + base_prompt = observation.prompt or dataset_prompt + user_prompt = make_user_prompt(base_prompt, observation.messages) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + prompt_text = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + enable_thinking=False, + ) + + # Generate completion using trainer (works for both colocate and server modes) + rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0] + prompt_ids.extend(rollout_outputs["prompt_ids"]) + completion_ids.extend(rollout_outputs["completion_ids"]) + logprobs.extend(rollout_outputs["logprobs"]) + completion_text = rollout_outputs.get("text") or tokenizer.decode( + rollout_outputs["completion_ids"], skip_special_tokens=True + ) + + # extract the guess from the completion + guess = extract_guess(completion_text) + + # step the environment with the guess + result = env.step(TextArenaAction(message=guess)) + raw_rewards.append(float(result.reward or 0.0)) + observation = result.observation + correct_score = float(result.reward or 0.0) + feedback = extract_wordle_feedback(observation) + + # Update guess counts + previous_occurrences = guess_counts.get(guess, 0) + repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts)) + guess_counts[guess] = previous_occurrences + 1 + + # calculate custom reward signals from the feedback + if not feedback: + green_score = 0.0 + yellow_score = 0.0 + else: + green_count, yellow_count = extract_feedback_counts(feedback) + green_score = green_count / 5.0 + yellow_score = yellow_count / 5.0 + + repetition_scores.append(repetition_score) + green_scores.append(green_score) + yellow_scores.append(yellow_score) + correct_scores.append(correct_score) + + correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0) + + return { + "prompt_ids": prompt_ids, + "completion_ids": completion_ids, + "logprobs": logprobs, + "raw_rewards": raw_rewards, + "correct_reward": correct_reward_value, + "green_reward": green_scores[-1] if green_scores else 0.0, + "yellow_reward": yellow_scores[-1] if yellow_scores else 0.0, + "repetition_reward": repetition_scores[-1] if repetition_scores else 0.0, + } +``` + +The environment has a reward signal based on the completion of the game. We found that most models struggle to ever win the game, so we have added a number of custom reward functions to the script to help the model learn to play the game more iteratively. At first, the model will learn to cover new letters and avoid repeating guesses. As it improves, it will learn to win the game. + +### Reward Functions + +We log four reward streams that encourage the model to solve the puzzle, cover new letters, and avoid repeating guesses: + +- `reward_correct`: final win/loss signal from the environment. +- `reward_greens`: density of green letters in the last feedback. +- `reward_yellows`: density of yellow letters in the last feedback. +- `reward_repetition`: penalty for guessing the same token multiple times. + +```python +def reward_correct(completions: List[str], **kwargs: Optional[Dict]) -> List[float]: + rewards = kwargs.get("correct_reward") if kwargs else None + return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions) + + +def reward_greens(completions: List[str], **kwargs: Optional[Dict]) -> List[float]: + rewards = kwargs.get("green_reward") if kwargs else None + return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions) + + +def reward_yellows(completions: List[str], **kwargs: Optional[Dict]) -> List[float]: + rewards = kwargs.get("yellow_reward") if kwargs else None + return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions) + + +def reward_repetition(completions: List[str], **kwargs: Optional[Dict]) -> List[float]: + rewards = kwargs.get("repetition_reward") if kwargs else None + return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions) +``` + +### Training the Model + +The training script wires the custom rollout and rewards into `GRPOTrainer`. The CLI exposes the configuration used during development as defaults, so you can override endpoints or hyperparameters at launch time. + +```python +parser = argparse.ArgumentParser() +# ... add CLI arguments with sensible defaults ... +cli_args = parser.parse_args() + +trainer = GRPOTrainer( + model=cli_args.model_id, + processing_class=tokenizer, + reward_funcs=[ + reward_correct, + reward_greens, + reward_yellows, + reward_repetition, + ], + train_dataset=dataset, + args=grpo_config, + rollout_func=lambda prompts, trainer: rollout_func( + env=env, + tokenizer=tokenizer, + prompts=prompts, + trainer=trainer, + cli_args=cli_args, + system_prompt=system_prompt, + ), +) +trainer.train() +``` + +### Running the Advanced Example + +You can run the Wordle example in either colocate mode (1 GPU) or server mode (2 GPUs): + + + + + +**Colocate mode (1 GPU, recommended)** + +```bash +python examples/scripts/openenv/wordle.py --vllm-mode colocate +``` + +This runs vLLM in the same process as training, requiring only a single GPU. + + + + + +**Server mode (2+ GPUs, scalable)** + +```bash +# Terminal 1: Start vLLM inference server +CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen3-1.7B --host 0.0.0.0 --port 8000 + +# Terminal 2: Run GRPO training with OpenEnv +CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py --vllm-mode server --vllm-server-url http://localhost:8000 +``` + +This runs vLLM as a separate server process, useful when you want to: +- Share the inference server across multiple training jobs +- Use multiple GPUs for the vLLM server (via `--tensor-parallel-size`) +- Scale up training to many GPUs while sharing a single inference endpoint + + + + + +You can also manually start the TextArena environment in a Docker container before running the training: + +```bash +# Launch the TextArena environment +docker run -d -p 8001:8001 registry.hf.space/burtenshaw-textarena:latest +``` + +Then connect to it using `--env-mode docker-local--env-host localhost --env-port 8001`. + +### Results + +The resulting model improves its performance on the game, both by reducing the number of repetitions and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the coverage of the model's guesses and the coverage of correct Y and G letters. + + + +We experimented with larger models like `gpt-oss-20b` and found that the model was able to consistently win the game. However, this requires a lot of compute to train the model. Why not try this out yourself? diff --git a/ICL/RL/trl_source/docs/source/orpo_trainer.md b/ICL/RL/trl_source/docs/source/orpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..555f0858316a3a2ff87b111e1100f3a69993e7f4 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/orpo_trainer.md @@ -0,0 +1,131 @@ +# ORPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl) [![model badge](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment) + +## Overview + +Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes). + +The abstract from the paper is the following: + +> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B). + +It studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT. + +Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory. + +The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo). + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt). + +## Quick start + +This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_orpo.py +from datasets import load_dataset +from trl.experimental.orpo import ORPOConfig, ORPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO") +trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_orpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. + +![orpo qwen2 reward margin](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/orpo-qwen2-reward-margin.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-ORPO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-ORPO>:
+It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
+C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
+
+Here are some other factors to consider when choosing a programming language for a project:
+
+ • Language proficiency: A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
+ • Ease of use: There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
+ • Code readability: A clear and concise codebase should be easy to read and understand, especially when working with large projects.
+ • Tool and framework support: There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
+ • Accessibility: Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
+ • Version control: As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.
+
+
+ +## Expected dataset type + +ORPO requires a [preference dataset](dataset_formats#preference). The [`experimental.orpo.ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +Although the [`experimental.orpo.ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. + +## Example script + +We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) + +To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: + +```bash +accelerate launch examples/scripts/orpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_train_epochs 1 \ + --output_dir Qwen2-0.5B-ORPO +``` + +## Usage tips + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta +- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta +- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards +- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses +- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))` +- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses + +## ORPOTrainer + +[[autodoc]] experimental.orpo.ORPOTrainer + - train + - save_model + - push_to_hub + +## ORPOConfig + +[[autodoc]] experimental.orpo.ORPOConfig diff --git a/ICL/RL/trl_source/docs/source/others.md b/ICL/RL/trl_source/docs/source/others.md new file mode 100644 index 0000000000000000000000000000000000000000..bd89447e7b877f5a24818099510de64ffa772aa0 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/others.md @@ -0,0 +1,9 @@ +# Other + +## profiling_decorator + +[[autodoc]] extras.profiling.profiling_decorator + +## profiling_context + +[[autodoc]] extras.profiling.profiling_context diff --git a/ICL/RL/trl_source/docs/source/paper_index.md b/ICL/RL/trl_source/docs/source/paper_index.md new file mode 100644 index 0000000000000000000000000000000000000000..b4a73346074e7c338826966d3a9a19b4fbbdf433 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/paper_index.md @@ -0,0 +1,1450 @@ +# Paper Index + +> [!WARNING] +> Section under construction. Feel free to contribute! See https://github.com/huggingface/trl/issues/4407. + +## Group Relative Policy Optimization + +Papers relating to the [`GRPOTrainer`]. + +### DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models + +**📜 Paper**: https://huggingface.co/papers/2402.03300 + +Introduces Group Relative Policy Optimization (GRPO) and shows strong math-reasoning gains from math-centric pretraining plus group-relative PPO-style optimization. Used in TRL via [`GRPOTrainer`]. + +```python +from trl import GRPOConfig, GRPOTrainer + +# The paper doesn't specify its hyperparameters, so here we provide hyperparameters from "DeepSeek-R1 incentivizes reasoning in LLMs through reinforcement learning" instead. +training_args = GRPOConfig( + loss_type="grpo", + beta=0.001, # "the KL coefficient to 0.001" + epsilon=10.0, # "the GRPO clip ratio ϵ to 10" + num_generations=16, # "For each question, we sample 16 outputs..." + max_completion_length=32_768, # "...with a maximum length of 32,768" + steps_per_generation=16, # "To accelerate training, each rollout generates 8,192 outputs, which are randomly split into 16 minibatches" + # "resulting in a training batch size of 512". One way to achieve this setting with 1 device is per_device_train_batch_size=4, gradient_accumulation_steps=128 + per_device_train_batch_size=4, + gradient_accumulation_steps=128, +) +trainer = GRPOTrainer( + ..., + args=training_args, +) +``` + +### DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning + +**📜 Paper**: https://huggingface.co/papers/2501.12948 + +DeepSeek-R1 achieves reasoning performance comparable to OpenAI-o1 through a multi-stage pipeline that transitions from pure reinforcement learning (RL) to a refined, human-aligned model. Unlike its predecessor, DeepSeek-R1-Zero, which used pure RL on a base model, R1 follows a structured four-stage evolution: +1. Cold Start: The base model is fine-tuned on a small set of high-quality, long Chain-of-Thought (CoT) data to provide a stable starting point. +2. Reasoning-Oriented RL: Large-scale RL is applied to enhance performance in math, coding, and logic, using rule-based rewards and a language consistency reward to reduce language mixing. +3. Rejection Sampling & SFT: The RL checkpoint generates 600k reasoning samples via rejection sampling, which are combined with 200k non-reasoning (general) samples to create a new dataset for a second round of Supervised Fine-Tuning. +4. RL for all Scenarios: A final RL stage aligns the model with human preferences (helpfulness and harmlessness) across all domains while maintaining reasoning strength. + +Distillation: Empowering Small Models + +A key contribution of the paper is demonstrating that reasoning patterns can be distilled from a large model (DeepSeek-R1) into smaller dense models (e.g., Qwen and Llama series). Distillation was found to be more effective for small models than training them with pure RL from scratch. + + +You can use the GRPOTrainer to replicate the reasoning-heavy stages of this pipeline. +```python +from trl import GRPOConfig, GRPOTrainer + +# Example configuration for a reasoning-oriented GRPO stage +# Based on the Open-R1 recipe for Qwen-7B +training_args = GRPOConfig( + learning_rate=4.0e-5, + max_prompt_length=4096, + max_completion_length=32768, # Support for long Chain-of-Thought + num_generations=16, # Sample 16 outputs per prompt for group relative advantage + beta=0.001, # KL coefficient + use_vllm=True, # Use vLLM backend for accelerated rollout generation +) + +trainer = GRPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + reward_funcs=[accuracy_reward, format_reward], # R1-Zero used rule-based rewards +) + +trainer.train() +``` + + +### Group Sequence Policy Optimization + +**📜 Paper**: https://huggingface.co/papers/2507.18071 + +GSPO is a GRPO variant that computes importance sampling weights at the sequence level instead of per-token. To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + importance_sampling_level="sequence", + loss_type="grpo", + beta=0.0, # GSPO set KL regularization to zero: https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306 + epsilon=3e-4, # GSPO paper (v2), section 5.1 + epsilon_high=4e-4, # GSPO paper (v2), section 5.1 + gradient_accumulation_steps=1, + steps_per_generation=4, # partition rollout batch into 4 mini-batches. GSPO paper (v2), section 5.1. Must be 4 times gradient_accumulation_steps +) +``` + +Note that this method only has an effect when training goes slightly off-policy—for example, when `steps_per_generation > gradient_accumulation_steps` or `num_iterations > 1`. Otherwise, it is effectively equivalent to no modification. + +TRL also provide an experimental implementation of GSPO-token, see [Experimental - GSPO-Token](experimental#gspo-token). + +#### Policy ratio: GRPO vs. GSPO + +In GSPO, the policy ratio is defined at the sequence-level. In other words, it is the ratio between the probability of the current policy generating a sequence over the old policy generating that same sequence. + +The sequence likelihood is defined as: + +$$ +\pi_\theta (o_i | q) = \prod_{t=1}^{|o_i|} \pi_\theta (o_{i,t} | q, o_{i, < t} ), +$$ + +where \\( \pi_\theta \\) is the policy \\( \pi \\) with parameters \\(\theta\\), \\( o_i \\) is the \\( i \\)-th output sequence \\( o \\) and \\(o_{i,t}\\) is the \\( t \\)-th token in this sequence, \\( q \\) is the input query. The sequence likelihood ratio \\( s_i (\theta) \\) is defined as: + +$$ +s_i (\theta) = \left(\frac{\pi_\theta (o_i | q)}{\pi_{\theta_{old}} (o_i | q)} \right)^{\frac{1}{|o_i|}} +$$ + +The exponent \\( \frac{1}{|o_i|} \\) represents a sequence-length normalization, minimizing the influence of sequence length in sequence likelihood. In other terms, it computes the geometric mean of token probabilities, ensuring a fair comparison across sequences of varying lengths. + +While GSPO defines the policy ratio at the sequence level, GRPO operates at the token level. Specifically, GRPO computes an importance ratio for each token in the sequence: + +$$ +w_{i,t}(\theta) = \frac{\pi_\theta (o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}} (o_{i,t} | q, o_{i,< t})} +$$ + +This token-level ratio is then combined with a shared advantage \\( \hat{A}_i \\), and the GRPO objective clips and optimizes each token independently across the sequence. + +### DAPO: An Open-Source LLM Reinforcement Learning System at Scale + +**📜 Paper**: https://huggingface.co/papers/2503.14476 + +The DAPO algorithm includes 5 key components: + +- Overlong Filtering +- Clip-Higher +- Soft Overlong Punishment +- Token-level Loss +- Dynamic Sampling (⚠️ Not supported in TRL) + +To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig, GRPOTrainer + +training_args = GRPOConfig( + # Overlong Filtering + mask_truncated_completions=True, + # Token-level Loss + loss_type="dapo", + # Clip-Higher + epsilon_high=0.28, # DAPO paper: section 4.1 + epsilon=0.2, # DAPO paper: section 4.1 + # Other parameters used + per_device_train_batch_size=512, # mini-batch size for training in the paper, DAPO paper: section 4.1 + num_generations=16, # number of sample responses in the paper, DAPO paper: section 4.1 + max_completion_length=20480, # maximum number of tokens for generation in the paper, DAPO paper: section 4.1 + beta=0.0, # section 2.3, DAPO paper + +) +# Soft Overlong Punishment +sop_reward = get_soft_overlong_punishment(max_completion_len=20480, soft_punish_cache=4096) # DAPO paper: section 4.1 +trainer = GRPOTrainer( + ..., + args=training_args, + reward_funcs=[..., sop_reward], +) +``` + +### Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning + +**📜 Paper**: https://huggingface.co/papers/2506.01939 + +A minority of tokens with high entropy act as reasoning "forks" in the CoT path, driving exploration and performance gains for RLVR, while low-entropy majority tokens contribute little or even impede learning. RLVR mainly adjusts high-entropy tokens, largely preserving the base model’s overall entropy patterns. Thus landing on the 80/20 rule, training on only 20% of the tokens with the highest entropy is comparable or supasses full-gradient updates for Qwen3 models. + +The paper's main results use vanilla DAPO (⚠️ Dynamic Sampling is not supported in TRL). To replicate the main results, use the following configuration: + +```python +from trl import GRPOConfig, GRPOTrainer +from trl.rewards import get_soft_overlong_punishment + +training_args = GRPOConfig( + # --- vanilla DAPO parameters (80/20 rule: section 5.2) --- # + # Overlong Filtering + mask_truncated_completions=True, + # Token-level Loss + loss_type="dapo", + # Clip-Higher + epsilon_high=0.28, # DAPO paper: section 4.1 + epsilon=0.2, # DAPO paper: section 4.1 + # Other parameters used + per_device_train_batch_size=512, # mini-batch size for training in the paper, DAPO paper: section 4.1 + num_generations=16, # number of sample responses in the paper, DAPO paper: section 4.1 + max_completion_length=20480, # maximum number of tokens for generation in the paper, DAPO paper: section 4.1 + beta=0.0, # section 2.3, DAPO paper + # --- Gradients on the highest entropy tokens --- # + top_entropy_quantile=0.2 +) +# Soft Overlong Punishment +sop_reward = get_soft_overlong_punishment(max_completion_len=20480, soft_punish_cache=4096) # DAPO paper: section 4.1 +trainer = GRPOTrainer( + ..., + args=training_args, + reward_funcs=[..., sop_reward], +) +``` + +### Dr. GRPO: Understanding R1-Zero-Like Training: A Critical Perspective + +**📜 Paper**: https://huggingface.co/papers/2503.20783 + +A study of R1-Zero training identifies pretraining effects on RL performance and proffers Dr. GRPO to enhance token efficiency, achieving superior accuracy on AIME 2024. To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + loss_type="dr_grpo", + per_device_train_batch_size=1, # train_batch_size_per_device in the Training section of the repository + num_generations=8, # num_samples in the Training section of the repository + max_completion_length=3000, # generate_max_length in the Training section of the repository + beta=0.0, # beta in the Training section of the repository +) +``` + +### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO) + +**📜 Paper**: https://huggingface.co/papers/2508.08221 + +The authors of this paper find that the combination of: + +1. scaling rewards by the standard deviation computed over the entire batch and +2. aggregating loss over the total number of tokens + +can unlock the learning capability of critic-free policies using vanilla PPO loss. Their results demonstrate that this simple combination consistently improves performance, surpassing strategies like GRPO and [DAPO](https://huggingface.co/papers/2503.14476). + +TRL supports using these learnings to train a GRPO model by: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ... + scale_rewards="batch", + loss_type="dapo", + # Other parameters used + beta=0.0, # = init_kl_coef in the paper + top_p=0.99, + top_k=100, + temperature=0.99, + num_generations=8, # = num_return_sequences in the paper + num_iterations=1, # = ppo_epochs in the paper + per_device_train_batch_size=4, + gradient_accumulation_steps=32, + steps_per_generation=8, # (rollout_batch_size*num_return_sequences) / (per_device_train_batch_size*gradient_accumulation_steps) +) +``` + +Note that when using gradient accumulation, the loss is aggregated over the total number of tokens in the batch, but not over the accumulated batch. For more details, see the [GRPO Trainer - Loss types](grpo_trainer#loss_types). + +### Truncated Importance Sampling + +**📰 Blog**: https://fengyao.notion.site/off-policy-rl + +Online policy learning methods commonly use an optimized inference framework for rollout generation (e.g vLLM) that is separate from the training backend. This introduces a rollout-training mismatch, exemplified in the following PPO objective: + +$$ +\small{ +\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})} +\Bigl[ +\min\Bigl( +\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A, +\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A +\Bigr) +\Bigr] +} +$$ + +Despite \\( \textcolor{red}{\pi_{\text{inference}}} \\) and \\( \textcolor{blue}{\pi_{\text{training}}} \\) sharing the same model parameters \\( \theta \\), they can produce significantly different token probabilities. This unexpected behavior implicitly breaks the on-policy assumption, and silently turns training off-policy. + +Truncated Importance Sampling (TIS) addresses this issue by adapting the model update via importance-sampling correction. The gradient computation of the aforementioned PPO objective becomes + +$$ +\small{ +\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})} +\Bigl[ +\underbrace{\min(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}{\textcolor{red}{\pi_{\text{inference}}}(a, \theta_{\mathrm{old}})}, C)}_{\text{truncated importance ratio}} \cdot +\nabla_\theta +\min\Bigl( +\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A, +\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A +\Bigr) +\Bigr] +} +$$ + +where \\( C \\) is a hyper-parameter. TIS is implemented in GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `truncate`, such as `"sequence_truncate"` or `"token_truncate"`. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ... + use_vllm=True, + vllm_importance_sampling_correction=True, # default True + vllm_importance_sampling_mode="sequence_truncate", # or "token_truncate" + vllm_importance_sampling_cap=2.0, # hyper-parameter C +) +``` + +### Masked Importance Sampling + +**📰 Blog**: https://ringtech.notion.site/icepop + +**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + +Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold \\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update. + + +$$ +\small{ +\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})} +\Bigl[ +\underbrace{\mathbf{1}\left[ +\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})} +{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})} +\le C +\right] +\cdot +\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})} +{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})}}_{\text{masked importance ratio}} \cdot +\nabla_\theta +\min\Bigl( +\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A, +\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A +\Bigr) +\Bigr] +} +$$ + +MIS is implemented for GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `"mask"`, such as `"sequence_mask"` or `"token_mask"`. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ... + use_vllm=True, + vllm_importance_sampling_correction=True, # default True + vllm_importance_sampling_mode="sequence_mask", # or "token_mask" + vllm_importance_sampling_cap=2.0, # hyper-parameter C +) +``` + +### Sequence-level Importance Sampling + +**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + +The theoretically principled way to correct for the training-inference distribution shift is importance sampling, as introduced in the two papers above [Truncated Importance Sampling](#truncated-importance-sampling) and [Masked Importance Sampling](#masked-importance-sampling). However, the choice of formulation is crucial for keeping the gradient unbiased and ensuring stable training. + +This work shows that sequence-level importance sampling is the sound approach for addressing the training–inference mismatch. Although token-level importance sampling achieves lower variance than a sequence-level ratio, it introduces bias and is therefore argued to be unsuitable for autoregressive models. The token-level gradient estimator is + +$$ +\mathbb{E}_{x\sim\mathcal{D},\, y\sim \pi^{\text{inference}}_\theta(\cdot|x)} +\Bigg[ + R(x,y)\,\cdot\, + \sum_{t=0}^{|y|-1} + \frac{\pi^{\text{training}}_\theta(y_t\,|\,x, y_{ 0`: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + beta=0.001, # the paper doesn't specify the value used, so we use the value from "DeepSeek-R1 incentivizes reasoning in LLMs through reinforcement learning" + use_bias_correction_kl=True, +) +``` + +- The **Off-Policy Masking**, which stabilizes training by ignoring sequences where the policy performs poorly (negative advantage) **and** has drifted significantly from the old policy (high KL divergence). + +The off-policy binary mask \\(\textcolor{red}{M_{i,t}}\\) is defined as: + +$$ +\textcolor{red}{M_{i,t}} = \begin{cases} +0 & \text{if } \hat{A}_{i,t} < 0 \quad \text{and} \quad \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \log \frac{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i, \textcolor{blue}{\delta} \\ +1 & \text{otherwise} +\end{cases} +$$ + +This mask is then applied to the GRPO loss as follows: + +$$ +\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) \textcolor{red}{M_{i,t}} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right] +$$ + +To enable this feature, use the `off_policy_mask_threshold` (corresponding to \\( \textcolor{blue}{\delta} \\)) in the [`GRPOConfig`]: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + off_policy_mask_threshold=0.5, +) +``` + +While the paper doesn't specify a \\( \textcolor{blue}{\delta} \\) value used, a good starting point could be \\( \textcolor{blue}{\delta} = 0.5 \\). If training seems too conservative or too many sequences are masked, you can increase the value. +For reference, \\( \textcolor{blue}{\delta} = 1.0 \\) corresponds to an average log-ratio divergence of 1 nat per token, i.e. on sequences where this threshold is exceeded, the old policy was on average \\( e^1 \approx 2.7 \\) times more likely to generate these tokens than the current policy. + +### GDPO: Group reward-Decoupled Normalization Policy Optimization for Multi-reward RL Optimization + +**📜 Paper**: https://huggingface.co/papers/2601.05242 + +GDPO is a reinforcement learning optimization method designed for multi-reward training. While existing approaches commonly apply Group Relative Policy Optimization (GRPO) in multi-reward settings, the authors show that this leads to reward advantages collapse, reducing training signal resolution and causing unstable or failed convergence. GDPO resolves this issue by decoupling reward normalization across individual rewards, preserving their relative differences and enabling more faithful preference optimization. To enable GDPO for multi-reward RL training, simply set: + +For a group of \\( N \\) rewards and \\( G \\) samples per group, GDPO normalizes each reward independently: + +$$ +A_n^{(i,j)} = \frac{r_n^{(i,j)} - \text{mean}\{r_n^{(i,1)}, \ldots, r_n^{(i,G)}\}}{\text{std}\{r_n^{(i,1)}, \ldots, r_n^{(i,G)}\} + \epsilon} +$$ + +The normalized group advantage is then aggregated across rewards: + +$$ +A^{(i,j)} = \sum_{n=1}^{N} w_n A_n^{(i,j)} +$$ + +The final per-batch normalization produces: + +$$ +\hat{A}^{(i,j)} = \frac{A^{(i,j)} - \text{mean}_{i',j'}\{A^{(i',j')}\}}{\text{std}_{i',j'}\{A^{(i',j')}\} + \epsilon} +$$ + +Here, \\( \text{mean}_{i',j'}\{A^{(i',j')}\} \\) and \\( \text{std}_{i',j'}\{A^{(i',j')}\} \\) denote statistics over all groups in the batch. + +```python +from trl import GRPOConfig + + +training_args = GRPOConfig( + ..., + multi_objective_aggregation="normalize_then_sum", +) +``` + +Note that this method only has an effect when training involve more than one reward function. + +The authors provide a easy-to-use, slurm-free training example that enable the community to quickly validate GDPO’s effectiveness over GRPO, see [Experiment-"Aha" moment](https://github.com/NVlabs/GDPO/tree/main/trl-GDPO). + +### Length-Unbiased Sequence Policy Optimization: Revealing and Controlling Response Length Variation in RLVR + +**📜 Paper**: https://huggingface.co/papers/2602.05261 + +Length-Unbiased Sequence Policy Optimization (LUSPO) modifies GSPO by scaling each sequence's loss by its length. This corrects GSPO's gradient bias that penalizes longer responses. To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + loss_type="luspo", + importance_sampling_level="sequence", + epsilon=2e-3, # section 5.1 of the paper + epsilon_high=2.5e-3, # section 5.1 of the paper +) +``` + +### INTELLECT-2: A Reasoning Model Trained Through Globally Decentralized Reinforcement Learning + +**📜 Paper**: https://huggingface.co/papers/2505.07291 + +INTELLECT-2 is the first globally distributed reinforcement learning training run of a 32 billion parameter language model using fully asynchronous RL across a dynamic, heterogeneous swarm of permissionless compute contributors. The authors propose modifications to the standard GRPO training recipe, including two-sided GRPO clipping for increased training stability. To reproduce the paper's setting, use this configuration: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + delta=4, # δ in section 4.1 of the paper + epsilon=0.2, # ε in section 4.1 of the paper + beta=0.001, # KL divergence coefficient in section 4.1 of the paper + num_generations=16, # responses per prompt in section 4.1 of the paper + learning_rate=3e-7, # section 4.1 of the paper +) +``` + +## Direct Policy Optimization + +Papers relating to the [`DPOTrainer`] + +### Direct Preference Optimization: Your Language Model is Secretly a Reward Model + +**📜 Paper**: https://huggingface.co/papers/2305.18290 + +Direct Preference Optimization (DPO) fine-tunes language models more efficiently and with better performance compared to reinforcement learning from human feedback (RLHF), by directly optimizing policy training based on human preferences. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="sigmoid", # losses in Appendix B of the paper + per_device_train_batch_size=64, # batch size in Appendix B of the paper + learning_rate=1e-6, # learning rate in Appendix B of the paper + beta=0.1, # beta in Appendix B of the paper +) +``` + +### A General Theoretical Paradigm to Understand Learning from Human Preferences + +**📜 Paper**: https://huggingface.co/papers/2310.12036 + +A new general objective, \\( \Psi \\)PO, bypasses both key approximations in reinforcement learning from human preferences, allowing for theoretical analysis and empirical superiority over DPO. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="ipo", # Section 5.1 of the paper + per_device_train_batch_size=90, # mini-batch size in Section C.1 of the paper + learning_rate=1e-2, # learning rate in Section C.1 of the paper +) +``` + +These parameters only appear in the [published version](https://proceedings.mlr.press/v238/gheshlaghi-azar24a/gheshlaghi-azar24a.pdf) + +### SLiC-HF: Sequence Likelihood Calibration with Human Feedback + +**📜 Paper**: https://huggingface.co/papers/2305.10425 + +Sequence Likelihood Calibration (SLiC) is shown to be an effective and simpler alternative to Reinforcement Learning from Human Feedback (RLHF) for learning from human preferences in language models. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="hinge", # Section 2 of the paper + per_device_train_batch_size=512, # batch size in Section 3.2 of the paper + learning_rate=1e-4, # learning rate in Section 3.2 of the paper +) +``` + +These parameters only appear in the [published version](https://openreview.net/pdf?id=0qSOodKmJaN) + +### Towards Efficient and Exact Optimization of Language Model Alignment + +**📜 Paper**: https://huggingface.co/papers/2402.00856 + +Efficient exact optimization (EXO) method is proposed to align language models with human preferences, providing a guaranteed and efficient alternative to reinforcement learning and direct preference optimization. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="exo_pair", # Section 3.2 of the paper + per_device_train_batch_size=64, # batch size in Section B of the paper + learning_rate=1e-6, # learning rate in Section B of the paper + beta=0.1, # $\beta_r$ in Section B of the paper +) +``` + +### Noise Contrastive Alignment of Language Models with Explicit Rewards + +**📜 Paper**: https://huggingface.co/papers/2402.05369 + +A framework using Noise Contrastive Estimation enhances language model alignment with both scalar rewards and pairwise preferences, demonstrating advantages over Direct Preference Optimization. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="nca_pair", # Section 4.1 of the paper + per_device_train_batch_size=32, # batch size in Section C of the paper + learning_rate=5e-6, # learning rate in Section C of the paper + beta=0.01, # $\alpha$ in Section C of the paper +) +``` + +### Provably Robust DPO: Aligning Language Models with Noisy Feedback + +**📜 Paper**: https://huggingface.co/papers/2403.00409 + +The paper introduces a robust direct preference optimization (rDPO) framework to address noise in preference-based feedback for language models, proving its sub-optimality gap and demonstrating its effectiveness through experiments. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="robust", # Section 3.1 of the paper + per_device_train_batch_size=16, # batch size in Section B of the paper + learning_rate=1e-3, # learning rate in Section B of the paper + beta=0.01, # $\beta$ in Section B of the paper, + max_length=512, # max length in Section B of the paper + label_smoothing=0.1 # label smoothing $\epsilon$ in section 6 of the paper + +) +``` + +### Binary Classifier Optimization for Large Language Model Alignment + +**📜 Paper**: https://huggingface.co/papers/2404.04656 + +Theoretical analysis and a new algorithm, Binary Classifier Optimization, explain and enhance the alignment of large language models using binary feedback signals. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="bco_pair", # Section 4 of the paper + per_device_train_batch_size=128, # batch size in Section C of the paper + learning_rate=5e-7, # learning rate in Section C of the paper + beta=0.01, # $\beta$ in Section C of the paper, +) +``` + +For the unpaired version, the user should utilize [`experimental.bco.BCOConfig`] and [`experimental.bco.BCOTrainer`]. + +### Learn Your Reference Model for Real Good Alignment + +**📜 Paper**: https://huggingface.co/papers/2404.09656 + +Trust Region DPO (TR-DPO) updates the reference policy during training, demonstrating effectiveness against DPO on the Anthropic HH and TLDR datasets, outperforming DPO by up to 19% measured by automatic evaluation with GPT-4, improving coherence, correctness, level of detail, helpfulness, and harmlessness. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + sync_ref_model=True, # enable TR-DPO (Section 3 of the paper) + ref_model_mixup_alpha=0.6, # α soft update weight (Table 1 of the paper) + ref_model_sync_steps=512, # τ update frequency in steps (Table 1 of the paper) + beta=0.05, # β temperature (Table 1 of the paper) + learning_rate=1e-6, # learning rate (Table 2 of the paper) + num_train_epochs=1, # Table 2 of the paper + max_length=1024, # max tokens length (Table 2 of the paper) + max_grad_norm=2, # max gradient norm (Table 2 of the paper) + warmup_steps=100, # warm-up steps (Table 2 of the paper) +) +``` + +### Self-Play Preference Optimization for Language Model Alignment + +**📜 Paper**: https://huggingface.co/papers/2405.00675 + +A self-play method called SPPO for language model alignment achieves state-of-the-art performance by approximating Nash equilibrium policy in a constant-sum game setting, outperforming other approaches with limited data. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="sppo_hard", # Section 3 of the paper + per_device_train_batch_size=64, # batch size in Section C of the paper + learning_rate=5e-7, # learning rate in Section C of the paper +) +``` + +### Provably Mitigating Overoptimization in RLHF: Your SFT Loss is Implicitly an Adversarial Regularizer + +**📜 Paper**: https://huggingface.co/papers/2405.16436 + +Regularized Preference Optimization (RPO) mitigates overoptimization in RLHF by fusing the DPO loss with the SFT loss, provably preventing the policy from choosing actions with spurious high proxy rewards. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type=["sigmoid", "sft"], # RPO loss = DPO + SFT (Section 5 of the paper) + loss_weights=[1.0, 0.005], # η=0.005 SFT weight in Appendix E.1 of the paper + beta=0.01, # β in Appendix E.1 of the paper + learning_rate=5e-7, # learning rate in Appendix E.1 of the paper + num_train_epochs=1, # Appendix E.1 of the paper +) +``` + +### Distributional Preference Alignment of LLMs via Optimal Transport + +**📜 Paper**: https://huggingface.co/papers/2406.05882 + +Alignment via Optimal Transport (AOT) aligns large language models distributionally by penalizing violations of stochastic dominance between positive and negative sample distributions, achieving state-of-the-art performance on alignment benchmarks. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="aot", # Section 3 of the paper +) +``` + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="aot_unpaired", # Section 3 of the paper +) +``` + +There is no additional hyperparameter in the paper. + +### Discovering Preference Optimization Algorithms with and for Large Language Models + +**📜 Paper**: https://huggingface.co/papers/2406.08414 + +An LLM-driven method automatically discovers performant preference optimization algorithms, leading to a new algorithm called DiscoPOP that blends logistic and exponential losses. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="discopop", # Section 3 of the paper + per_device_train_batch_size=64, # batch size in Section B.1 of the paper + learning_rate=5e-7, # learning rate in Section B.1 of the paper + beta=0.05, # $\beta$ in Section B.1 of the paper, + discopop_tau=0.05 # $\tau$ in Section E of the paper +) +``` + +### Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment + +**📜 Paper**: https://huggingface.co/papers/2408.06266 + +CLAIR and APO enhance LLM alignment through more contrastive preference pairs and controlled alignment objectives, improving model performance close to GPT4-turbo. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="apo_zero", # Section 4 of the paper + per_device_train_batch_size=64, # batch size in Section B.1 of the paper + learning_rate=2e-7, # learning rate in Section 5.2 of the paper + beta=0.1, # $\beta$ in Section 5.2 of the paper, +) +``` + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type="apo_down", # Section 4 of the paper + per_device_train_batch_size=64, # batch size in Section B.1 of the paper + learning_rate=2e-7, # learning rate in Section 5.2 of the paper + beta=0.1, # $\beta$ in Section 5.2 of the paper, +) +``` + +These parameters only appear in the [published version](https://aclanthology.org/2025.tacl-1.22.pdf) + +### Statistical Rejection Sampling Improves Preference Optimization + +**📜 Paper**: https://huggingface.co/papers/2309.06657 + +Proposes **RSO**, selecting stronger preference pairs via statistical rejection sampling to boost offline preference optimization; complements DPO/SLiC. They also introduce a new loss defined as: + +$$ +\mathcal{L}_{\text{hinge-norm}}(\pi_\theta) += \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} +\left[ +\max\left(0,\; 1 - \left[\gamma \log \frac{\pi_\theta(y_w \mid x)}{\pi_\text{ref}(y_w \mid x)} - \gamma \log \frac{\pi_\theta(y_l \mid x)}{\pi_\text{ref}(y_l \mid x)}\right]\right) +\right] +$$ + +To train with RSO-filtered data and the hinge-norm loss, you can use the following code: + +```python +from trl import DPOConfig, DPOTrainer + +dataset = ... + +def rso_accept(example): # replace with your actual filter/score logic + return example["rso_keep"] + +train_dataset = train_dataset.filter(rso_accept) + +training_args = DPOConfig( + loss_type="hinge", + beta=0.05, # correspond to gamma in the paper +) + +trainer = DPOTrainer( + ..., + args=training_args, + train_dataset=train_dataset, +) +trainer.train() + +``` + +### Enhancing the Reasoning Ability of Multimodal Large Language Models via Mixed Preference Optimization + +**📜 Paper**: https://huggingface.co/papers/2411.10442 + +Introduces Mixed Preference Optimization (MPO) to improve multimodal reasoning in MLLMs, addressing distribution shift and weak Chain-of-Thought (CoT) after standard pre-training and SFT. The paper contributes (1) MMPR, an automated pipeline for high-quality multimodal preference data, and (2) MPO, a combined preference objective (pairwise + BCO-style + SFT) that boosts CoT. InternVL2-8B-MPO reaches 67.0 on MathVista (+8.7 over InternVL2-8B), comparable to the 10× larger InternVL2-76B. Used in TRL via [`DPOConfig`] with composite loss. To reproduce the paper's setting, use this configuration: + +```python +from trl import DPOConfig + +training_args = DPOConfig( + loss_type=["sigmoid", "bco_pair", "sft"], # ℒ = w_p·ℒ_p + w_q·ℒ_q + w_g·ℒ_g (Section 3.2 of the paper) + loss_weights=[0.8, 0.2, 1.0], # w_p, w_q, w_g loss weights (Section 7 of the paper) + learning_rate=5e-6, # learning rate (Section 7 of the paper) +) +``` + +## Kahneman–Tversky Optimization + +Papers relating to the [`experimental.kto.KTOTrainer`] + +### KTO: Model Alignment as Prospect Theoretic Optimization + +**📜 Paper**: https://huggingface.co/papers/2402.01306 + +KTO derives an alignment objective from prospect theory and learns directly from **binary** human feedback (liked/disliked), matching or surpassing DPO-style methods while handling imbalanced/noisy signals well. +To reproduce the paper's setting, you can use the default configuration of [`experimental.kto.KTOTrainer`]: + +```python +from trl.experimental.kto import KTOConfig, KTOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +trainer = KTOTrainer( + model=model, + processing_class=tokenizer, + args=KTOConfig(), + train_dataset=..., +) +trainer.train() +``` + +## Supervised Fine-Tuning + +Papers relating to the [`SFTTrainer`] + +### EMA Without the Lag: Bias-Corrected Iterate Averaging Schemes + +**📜 Paper**: https://huggingface.co/papers/2508.00180 + +Bias-Corrected Exponential Moving Average (BEMA) improves the stability and efficiency of language model fine-tuning by reducing stochasticity and eliminating bias. To use BEMA with SFT as described in the paper, you can use the [`BEMACallback`]: + +```python +from trl import BEMACallback, SFTTrainer + +trainer = SFTTrainer( + ... + callbacks=[BEMACallback()], +) +``` + +### On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification + +**📜 Paper**: https://huggingface.co/papers/2508.05629 + +Dynamic Fine-Tuning (DFT) improves the generalization of Large Language Models (LLMs) by dynamically rescaling gradients, outperforming standard Supervised Fine-Tuning (SFT) and showing competitive results in offline reinforcement learning. + +$$ +\mathcal{L}_{\text{DFT}}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ - \sum_{t=1}^{|y|} \textcolor{red}{\text{sg}\big(\pi_\theta(y_t \mid y_{ 0 (optimism coefficient) and β > 0 (KL regularization) in Algorithm 1 but does not specify numerical values. The following configuration uses TRL defaults: + +```python +from trl.experimental.xpo import XPOConfig + +training_args = XPOConfig( + alpha=1e-5, # α exploration bonus weight, α ≥ 0 where α=0 reduces to online DPO (TRL default) + beta=0.1, # β KL regularization coefficient (TRL default) +) +``` + +## Distillation + +Papers relating to training a student model with the help of a teacher model. + +### On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes + +**📜 Paper**: https://huggingface.co/papers/2306.13649 + +Introduces Generalized Knowledge Distillation (GKD), which addresses distribution mismatch in KD for auto-regressive models by training the student on its own generated outputs with teacher feedback, instead of a fixed set of sequences. GKD supports flexible loss functions (e.g. beyond KL when the student cannot match the teacher) and integrates with RL fine-tuning (RLHF). The paper reports results on summarization, translation, arithmetic reasoning, and instruction-tuning. Used in TRL via [`experimental.gkd.GKDTrainer`]. To reproduce the paper's setting, use this configuration: + +```python +from trl.experimental.gkd import GKDConfig + +# XSum summarization task (Table A.1 of the paper) +training_args = GKDConfig( + lmbda=0.5, # λ student data fraction (Section 3 of the paper) + beta=0.5, # β Generalized JSD interpolation, 0=KL, 1=reverse KL (Section 3 of the paper) + temperature=1.0, # student training temperature (Appendix A of the paper) + max_steps=40000, # training steps (Table A.1 of the paper) + learning_rate=3e-4, # learning rate (Table A.1 of the paper) + per_device_train_batch_size=32, # batch size (Table A.1 of the paper) + warmup_steps=2000, # warm-up steps (Table A.1 of the paper) + max_new_tokens=64, # max output tokens (Table A.1 of the paper) +) +``` + +### On-Policy Distillation + +**📰 Blog**: https://thinkingmachines.ai/blog/on-policy-distillation/ + +On-Policy Distillation involves a student model generating rollouts for each batch of training data. We subsequently obtain the probability distributions for each token of the rollouts from both the student and teacher models. The student model is then optimized to minimize the negative Kullback-Leibler (KL) divergence between its own token distributions and those of the teacher model. + +| Method | Sampling | Reward signal | +|-------------------------|------------|---------------| +| Supervised finetuning | off-policy | dense | +| Reinforcement learning | on-policy | sparse | +| On-policy distillation | on-policy | dense | + +On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to restore generalization capabilities lost during SFT. + +Additionally on-policy distillation is more compute efficient and is less prone to overfitting when trained with limited data. + +To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]: + +```python +from trl.experimental.gkd import GKDConfig + +training_args = GKDConfig( + lmbda=1.0, # student produces rollouts for all batches + beta=1.0, # to ensure reverse-kl as the loss function + teacher_model_name_or_path="teacher-model", # specify the teacher model + +) +``` + +Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration: + +```python +from trl.experimental import GOLDConfig + +config = GOLDConfig( + lmbda=1.0, # student produces rollouts for all batches + beta=1.0, # to ensure reverse-kl as the loss function + teacher_model_name_or_path="teacher-model", # specify the teacher model + +) +``` + +### Knowledge Distillation of Large Language Models + +**📜 Paper**: https://huggingface.co/papers/2306.08543 + +MiniLLM is the first on-policy knowledge distillation method, which minimizes the sequence-level reverse KLD between the teacher and the student model and is optimized by reinforcement learning. + +It is a generalized version of [Think Machine Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals. + +Alternatively, you can use the [`experimental.MiniLLMTrainer`] and [`experimental.MiniLLMConfig`] to perform MiniLLM distillation as follows: + +```python +from datasets import load_dataset +from trl.experimental.minillm import MiniLLMTrainer + +dataset = load_dataset("trl-lib/tldr", split="train") + +trainer = MiniLLMTrainer( + model="Qwen/Qwen3-0.6B", + teacher_model="Qwen/Qwen3-1.7B", + train_dataset=dataset, +) +trainer.train() +``` + +For more details, see the [MiniLLM Trainer documentation](minillm) documentation. + +## Distributed Training + +### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models + +**📜 Paper**: https://huggingface.co/papers/1910.02054 + +ZeRO (Zero Redundancy Optimizer) eliminates memory redundancies in data- and model-parallel training by partitioning optimizer states, gradients, and parameters across devices while retaining low communication volume and high computational granularity. This allows for the efficient training of large models that would otherwise not fit in GPU memory. + +TRL supports ZeRO via the [DeepSpeed integration](deepspeed_integration). To use it, provide a DeepSpeed configuration file with your desired settings, + +```yaml +# config.yaml +distributed_type: DEEPSPEED +num_processes: 2 +deepspeed_config: + zero_stage: 3 +``` + +and launch the training script using `accelerate launch --config_file config_file`. + +```sh +accelerate launch --config_file config.yaml train.py +``` + +## Proximal Policy Optimization + +Papers relating to the [`experimental.ppo.PPOTrainer`] + +### Proximal Policy Optimization Algorithms + +**📜 Paper**: https://huggingface.co/papers/1707.06347 + +Introduces Proximal Policy Optimization (PPO): policy gradient methods that alternate between collecting rollouts and optimizing a clipped surrogate objective over multiple minibatch epochs. PPO retains benefits of trust-region methods (e.g. TRPO) with simpler implementation and strong empirical sample efficiency, and was validated on robotics and Atari benchmarks. Used in TRL via [`experimental.ppo.PPOTrainer`]. To use PPO with TRL, use this configuration: + +```python +from trl.experimental.ppo import PPOConfig + +training_args = PPOConfig( + cliprange=0.2, # ε clipping range (Section 3 and Table 3 of the paper, Mujoco setting) + num_ppo_epochs=4, # K epochs of minibatch updates (TRL default; paper uses K=10 Mujoco, K=3 Atari) + gamma=1.0, # γ discount factor (TRL default for LLM tasks; paper uses γ=0.99) + lam=0.95, # λ GAE parameter (Table 3 of the paper, Mujoco setting) + kl_coef=0.05, # KL penalty coefficient (Section 4 of the paper discusses adaptive KL) + vf_coef=0.1, # c₁ value function loss weight (Equation 9 of the paper) +) +``` diff --git a/ICL/RL/trl_source/docs/source/papo_trainer.md b/ICL/RL/trl_source/docs/source/papo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..b9ec5aa66ed8e97de062b4e5b66a458f381c7e37 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/papo_trainer.md @@ -0,0 +1,20 @@ +# PAPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-PAPO-blue)](https://huggingface.co/models?other=papo,trl) + +TRL supports the Perception-Aware Policy Optimization (PAPO) as described in the paper [Perception-Aware Policy Optimization for Multimodal Reasoning](https://huggingface.co/papers/2507.06448) by [Zhenhailong Wang](https://huggingface.co/mikewang), Xuehang Guo, Sofia Stoica, [Haiyang Xu](https://huggingface.co/xhyandwyy), Hongru Wang, Hyeonjeong Ha, Xiusi Chen, Yangyi Chen, Ming Yan, Fei Huang, Heng Ji + +The abstract from the paper is the following: + +> Reinforcement Learning with Verifiable Rewards (RLVR) has proven to be a highly effective strategy for endowing Large Language Models (LLMs) with robust multi-step reasoning abilities. However, its design and optimizations remain tailored to purely textual domains, resulting in suboptimal performance when applied to multimodal reasoning tasks. In particular, we observe that a major source of error in current multimodal reasoning lies in the perception of visual inputs. To address this bottleneck, we propose Perception-Aware Policy Optimization (PAPO), a simple yet effective extension of GRPO that encourages the model to learn to perceive while learning to reason, entirely from internal supervision signals. Notably, PAPO does not rely on additional data curation, external reward models, or proprietary models. Specifically, we introduce the Implicit Perception Loss in the form of a KL divergence term to the GRPO objective, which, despite its simplicity, yields significant overall improvements (4.4%) on diverse multimodal benchmarks. The improvements are more pronounced, approaching 8.0%, on tasks with high vision dependency. We also observe a substantial reduction (30.5%) in perception errors, indicating improved perceptual capabilities with PAPO. We conduct comprehensive analysis of PAPO and identify a unique loss hacking issue, which we rigorously analyze and mitigate through a Double Entropy Loss. Overall, our work introduces a deeper integration of perception-aware supervision into RLVR learning objectives and lays the groundwork for a new RL framework that encourages visually grounded reasoning. Project page: https://mikewangwzhl.github.io/PAPO. + +## PAPOTrainer + +[[autodoc]] experimental.papo.PAPOTrainer + - train + - save_model + - push_to_hub + +## PAPOConfig + +[[autodoc]] experimental.papo.PAPOConfig diff --git a/ICL/RL/trl_source/docs/source/peft_integration.md b/ICL/RL/trl_source/docs/source/peft_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..f0c7b262f17b970f1f26e8bbbcd39b6562786fc5 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/peft_integration.md @@ -0,0 +1,826 @@ +# PEFT Integration + +TRL supports [PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) methods for memory-efficient model training. PEFT enables fine-tuning large language models by training only a small number of additional parameters while keeping the base model frozen, significantly reducing computational costs and memory requirements. + +This guide covers how to use PEFT with different TRL trainers, including LoRA, QLoRA, and prompt tuning techniques. + +For a complete working example, see the [SFT with LoRA/QLoRA notebook](https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb). + +## Installation + +To use PEFT with TRL, install the required dependencies: + +```bash +pip install trl[peft] +``` + +For QLoRA support (4-bit and 8-bit quantization), also install: + +```bash +pip install bitsandbytes +``` + +## Quick Start + +All TRL trainers support PEFT through the `peft_config` argument. The simplest way to enable PEFT is by using the command-line interface with the `--use_peft` flag: + +```bash +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-SFT-LoRA +``` + +Alternatively, you can pass a PEFT config directly in your Python code: + +```python +from peft import LoraConfig +from trl import SFTTrainer + +# Configure LoRA +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +# Configure training - note the higher learning rate for LoRA (10x base rate) +training_args = SFTConfig( + learning_rate=2.0e-4, # 10x the base rate (2.0e-5) for LoRA + ... +) + +# Create trainer with PEFT +trainer = SFTTrainer( + model=model, + train_dataset=dataset, + peft_config=peft_config, +) +``` + +## Three Ways to Configure PEFT + +TRL provides three different methods to configure PEFT, each suited for different use cases: + +### 1. Using CLI Flags (Simplest) + +The easiest way to enable PEFT is to use the `--use_peft` flag with the command-line interface. This method is ideal for quick experiments and standard configurations: + +```bash +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --output_dir Qwen2-0.5B-SFT-LoRA +``` + +**Pros**: Quick setup, no code required + +**Cons**: Limited to LoRA, fewer customization options + +### 2. Passing peft_config to Trainer (Recommended) + +For more control, pass a PEFT configuration directly to the trainer. This is the recommended approach for most use cases: + +```python +from peft import LoraConfig +from trl import SFTConfig, SFTTrainer + +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], +) + +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, # Pass config here +) +``` + +**Pros**: Full control, supports all PEFT methods (LoRA, Prompt Tuning, etc.) + +**Cons**: Requires Python code + +### 3. Applying PEFT to Model Directly (Advanced) + +For maximum flexibility, you can apply PEFT to your model before passing it to the trainer: + +```python +from peft import LoraConfig, get_peft_model +from transformers import AutoModelForCausalLM +from trl import SFTConfig, SFTTrainer + +# Load base model +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B") + +# Apply PEFT configuration +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) +model = get_peft_model(model, peft_config) + +# Pass PEFT-wrapped model to trainer +trainer = SFTTrainer( + model=model, # Already has PEFT applied + args=training_args, + train_dataset=dataset, + # Note: no peft_config needed here +) +``` + +**Pros**: Maximum control, useful for custom model architectures or complex setups + +**Cons**: More verbose, requires understanding of PEFT internals + +## Learning Rate Considerations + +When using LoRA or other PEFT methods, you typically need to use a **higher learning rate** (approximately 10x) compared to full fine-tuning. This is because PEFT methods train only a small fraction of parameters, requiring a larger learning rate to achieve similar parameter updates. + +**Recommended learning rates:** + +| Trainer | Full Fine-Tuning | With LoRA (10x) | +|---------|------------------|-----------------| +| **SFT** | `2.0e-5` | `2.0e-4` | +| **DPO** | `5.0e-7` | `5.0e-6` | +| **GRPO** | `1.0e-6` | `1.0e-5` | +| **Prompt Tuning** | N/A | `1.0e-2` to `3.0e-2` | + +> **Why 10x?** LoRA adapters have significantly fewer trainable parameters than the full model. A higher learning rate compensates for this reduced parameter count, ensuring effective training. For detailed explanation, see [this blog post](https://thinkingmachines.ai/blog/lora/). + +For additional best practices on using LoRA effectively, refer to the [LoRA Without Regret](lora_without_regret) documentation. + +## PEFT with Different Trainers + +TRL's trainers support PEFT configurations for various training paradigms. Below are detailed examples for each major trainer. + + + + +### Supervised Fine-Tuning (SFT) + +The `SFTTrainer` is used for supervised fine-tuning on instruction datasets. + +#### With LoRA + +```bash +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-4 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-SFT-LoRA +``` + +#### Python Example + +```python +from peft import LoraConfig +from trl import SFTConfig, SFTTrainer + +# Configure LoRA +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "v_proj"], # Optional: specify target modules +) + +# Configure training with higher learning rate for LoRA +training_args = SFTConfig( + learning_rate=2.0e-4, # 10x the base rate for LoRA + ... +) + +# Create trainer with PEFT config +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, # Pass PEFT config here +) + +trainer.train() +``` + + + + +### Direct Preference Optimization (DPO) + +The `DPOTrainer` implements preference learning from human feedback. + +#### With LoRA + +```bash +python trl/scripts/dpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --learning_rate 5.0e-6 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-DPO-LoRA +``` + +#### Python Example + +```python +from peft import LoraConfig +from trl import DPOConfig, DPOTrainer + +# Configure LoRA +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +# Configure training with higher learning rate for LoRA +training_args = DPOConfig( + learning_rate=5.0e-6, # 10x the base rate for DPO with LoRA + ... +) + +# Create trainer with PEFT config +trainer = DPOTrainer( + model=model, + ref_model=None, # Not needed when using PEFT + args=training_args, + train_dataset=dataset, + peft_config=peft_config, # Pass PEFT config here +) + +trainer.train() +``` + +**Note:** When using PEFT with DPO, you don't need to provide a separate reference model (`ref_model`). The trainer automatically uses the frozen base model as the reference. + + + + +### Group Relative Policy Optimization (GRPO) + +The `GRPOTrainer` optimizes policies using group-based rewards. + +#### With LoRA + +```bash +python trl/scripts/grpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/math-reasoning \ + --learning_rate 1.0e-5 \ + --per_device_train_batch_size 2 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-GRPO-LoRA +``` + +#### Python Example + +```python +from peft import LoraConfig +from trl import GRPOConfig, GRPOTrainer + +# Configure LoRA +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +# Configure training with higher learning rate for LoRA +training_args = GRPOConfig( + learning_rate=1.0e-5, # 10x the base rate for GRPO with LoRA + ... +) + +# Create trainer with PEFT config +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B", # Can pass model name or loaded model + args=training_args, + train_dataset=dataset, + peft_config=peft_config, # Pass PEFT config here +) + +trainer.train() +``` + + + + +### Proximal Policy Optimization (PPO) + +#### Multi-Adapter RL Training + +You can use a single base model with multiple PEFT adapters for the entire PPO algorithm - including retrieving reference logits, computing active logits, and calculating rewards. This approach is useful for memory-efficient RL training. + +> [!WARNING] +> This feature is experimental and convergence has not been extensively tested. We encourage the community to share feedback and report any issues. + +**Requirements** + +Install PEFT and optionally bitsandbytes for 8-bit models: + +```bash +pip install peft bitsandbytes +``` + +**Training Workflow** + +The multi-adapter approach requires three stages: + +1. **Supervised Fine-Tuning (SFT)**: Train a base model on your target domain (e.g., IMDB dataset) using `SFTTrainer` +2. **Reward Model Training**: Train a reward model adapter using PEFT and `RewardTrainer` (see [reward modeling example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)) +3. **PPO Training**: Fine-tune new adapters using PPO with the reward adapter + +> [!IMPORTANT] +> Use the same base model (architecture and weights) for stages 2 & 3. + +**Basic Usage** + +After training your reward adapter and pushing it to the Hub: + +```python +from peft import LoraConfig +from trl.experimental.ppo import PPOTrainer, AutoModelForCausalLMWithValueHead + +model_name = "huggyllama/llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# Configure PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +# Load model with reward adapter +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, +) + +trainer = PPOTrainer(model=model, ...) +``` + +In your training loop, compute rewards using: + +```python +rewards = trainer.model.compute_reward_score(**inputs) +``` + +**Advanced Features** + +**Quantized Base Models** + +For memory-efficient training, load the base model in 8-bit or 4-bit while keeping adapters in float32: + +```python +from transformers import BitsAndBytesConfig + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), +) +``` + +## QLoRA: Quantized Low-Rank Adaptation + +QLoRA combines 4-bit quantization with LoRA to enable fine-tuning of very large models on consumer hardware. This technique can reduce memory requirements by up to 4x compared to standard LoRA. + +### How QLoRA Works + +1. **4-bit Quantization**: The base model is loaded in 4-bit precision using `bitsandbytes` +2. **Frozen Weights**: The quantized model weights remain frozen during training +3. **LoRA Adapters**: Only the LoRA adapter parameters are trained in higher precision +4. **Memory Efficiency**: Enables fine-tuning of models like Llama-70B on a single consumer GPU + +### Using QLoRA with TRL + +Simply combine `load_in_4bit=True` with PEFT configuration: + +#### Command Line + +```bash +python trl/scripts/sft.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --dataset_name trl-lib/Capybara \ + --load_in_4bit \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --output_dir Llama-2-7b-QLoRA +``` + +#### Python Example + +```python +import torch + +from peft import LoraConfig +from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from trl import SFTConfig, SFTTrainer + +# Configure 4-bit quantization +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, +) + +# Load model with quantization +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=bnb_config, + device_map="auto", +) + +# Configure LoRA +peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +# Configure training with higher learning rate for LoRA +training_args = SFTConfig( + learning_rate=2.0e-4, # 10x the base rate for QLoRA + ... +) + +# Create trainer with PEFT config +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, +) + +trainer.train() +``` + +### QLoRA Configuration Options + +The `BitsAndBytesConfig` provides several options to optimize memory and performance: + +```python +import torch + +from transformers import BitsAndBytesConfig + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", # or "fp4" + bnb_4bit_compute_dtype=torch.bfloat16, # Compute dtype for 4-bit base models + bnb_4bit_use_double_quant=True, # Nested quantization for additional memory savings +) +``` + +**Configuration Parameters:** +- `bnb_4bit_quant_type`: Quantization data type (`"nf4"` or `"fp4"`). NF4 is recommended. +- `bnb_4bit_compute_dtype`: The dtype used for computation. Use `bfloat16` for better training stability. +- `bnb_4bit_use_double_quant`: Enable nested quantization to save additional ~0.4 bits per parameter. + +### 8-bit Quantization + +For slightly higher precision with reduced memory savings, you can use 8-bit quantization: + +```python +from transformers import BitsAndBytesConfig, AutoModelForCausalLM + +bnb_config = BitsAndBytesConfig(load_in_8bit=True) + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + quantization_config=bnb_config, + device_map="auto", +) +``` + +Or via command line: + +```bash +python trl/scripts/sft.py \ + --model_name_or_path meta-llama/Llama-2-7b-hf \ + --load_in_8bit \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +``` + +## Prompt Tuning + +Prompt tuning is another PEFT technique that learns soft prompts (continuous embeddings) prepended to the input, while keeping the entire model frozen. This is particularly effective for large models. + +### How Prompt Tuning Works + +1. **Virtual Tokens**: Adds learnable continuous embeddings (virtual tokens) to the input +2. **Frozen Model**: The entire base model remains frozen +3. **Task-Specific Prompts**: Each task learns its own prompt embeddings +4. **Extreme Efficiency**: Only the prompt embeddings are trained (typically 8-20 tokens) + +### Using Prompt Tuning with TRL + +```python +from peft import PromptTuningConfig, PromptTuningInit, TaskType +from trl import SFTConfig, SFTTrainer + +# Configure Prompt Tuning +peft_config = PromptTuningConfig( + task_type=TaskType.CAUSAL_LM, + prompt_tuning_init=PromptTuningInit.TEXT, + num_virtual_tokens=8, + prompt_tuning_init_text="Classify if the tweet is a complaint or not:", + tokenizer_name_or_path="Qwen/Qwen2-0.5B", +) + +# Configure training with higher learning rate for Prompt Tuning +training_args = SFTConfig( + learning_rate=2.0e-2, # Prompt Tuning typically uses 1e-2 to 3e-2 + ... +) + +# Create trainer with PEFT config +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, # Pass PEFT config here +) + +trainer.train() +``` + +### Prompt Tuning Configuration + +```python +from peft import PromptTuningConfig, PromptTuningInit, TaskType + +peft_config = PromptTuningConfig( + task_type=TaskType.CAUSAL_LM, # Task type + prompt_tuning_init=PromptTuningInit.TEXT, # Initialize from text + num_virtual_tokens=8, # Number of virtual tokens + prompt_tuning_init_text="Your initialization text here", + tokenizer_name_or_path="model_name", +) +``` + +**Configuration Parameters:** +- `task_type`: The task type (`TaskType.CAUSAL_LM` for language modeling) +- `prompt_tuning_init`: Initialization method (`TEXT`, `RANDOM`) +- `num_virtual_tokens`: Number of virtual tokens to prepend (typically 8-20) +- `prompt_tuning_init_text`: Text to initialize the virtual tokens (when using `TEXT` init) +- `tokenizer_name_or_path`: Tokenizer for initializing from text + +### Prompt Tuning vs LoRA + +| Feature | Prompt Tuning | LoRA | +|---------|---------------|------| +| **Parameters Trained** | ~0.001% | ~0.1-1% | +| **Memory Usage** | Minimal | Low | +| **Training Speed** | Fastest | Fast | +| **Model Modification** | None | Adapter layers | +| **Best For** | Large models, many tasks | General fine-tuning | +| **Learning Rate** | Higher (1e-2 to 3e-2) | Standard (1e-4 to 3e-4) | + +## Advanced PEFT Configurations + +### LoRA Configuration Parameters + +```python +from peft import LoraConfig + +peft_config = LoraConfig( + r=16, # LoRA rank + lora_alpha=32, # LoRA scaling factor + lora_dropout=0.05, # Dropout probability + bias="none", # Bias training strategy + task_type="CAUSAL_LM", # Task type + target_modules=["q_proj", "v_proj"], # Modules to apply LoRA + modules_to_save=None, # Additional modules to train +) +``` + +**Key Parameters:** +- `r`: LoRA rank (typical values: 8, 16, 32, 64). Higher rank = more parameters but potentially better performance. +- `lora_alpha`: Scaling factor (typically 2x the rank). Controls the magnitude of LoRA updates. +- `lora_dropout`: Dropout probability for LoRA layers (typical: 0.05-0.1). +- `target_modules`: Which modules to apply LoRA to. Common choices: + - `["q_proj", "v_proj"]`: Attention query and value (memory efficient) + - `["q_proj", "k_proj", "v_proj", "o_proj"]`: All attention projections + - `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`: All linear layers +- `modules_to_save`: Additional modules to fully train (e.g., `["embed_tokens", "lm_head"]`) + +### Target Module Selection + +You can specify which modules to apply LoRA to. Common patterns: + +```python +# Minimal (most memory efficient) +target_modules=["q_proj", "v_proj"] + +# Attention only +target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] + +# All linear layers (best performance, more memory) +target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] +``` + +### Using Command-Line Arguments + +TRL scripts accept PEFT parameters via command line: + +```bash +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --lora_target_modules q_proj v_proj \ + --output_dir output +``` + +Available flags: +- `--use_peft`: Enable PEFT +- `--lora_r`: LoRA rank (default: 16) +- `--lora_alpha`: LoRA alpha (default: 32) +- `--lora_dropout`: LoRA dropout (default: 0.05) +- `--lora_target_modules`: Target modules (space-separated) +- `--lora_modules_to_save`: Additional modules to train +- `--use_rslora`: Enable Rank-Stabilized LoRA +- `--use_dora`: Enable Weight-Decomposed LoRA (DoRA) +- `--load_in_4bit`: Enable 4-bit quantization (QLoRA) +- `--load_in_8bit`: Enable 8-bit quantization + +## Saving and Loading PEFT Models + +### Saving + +After training, save your PEFT adapters: + +```python +# Save the adapters +trainer.save_model("path/to/adapters") + +# Or manually +model.save_pretrained("path/to/adapters") +``` + +This saves only the adapter weights (~few MB) rather than the full model (~several GB). + +### Loading + +Load a PEFT model for inference: + +```python +from transformers import AutoModelForCausalLM +from peft import PeftModel + +# Load base model +base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B") + +# Load PEFT adapters +model = PeftModel.from_pretrained(base_model, "path/to/adapters") + +# Optionally merge adapters into base model for faster inference +model = model.merge_and_unload() +``` + +### Pushing to Hub + +You can easily share your PEFT adapters on the Hugging Face Hub: + +```python +# Push adapters to Hub +model.push_to_hub("username/model-name-lora") + +# Load from Hub +from peft import PeftModel +model = PeftModel.from_pretrained(base_model, "username/model-name-lora") +``` + +## Multi-GPU Training + +PEFT works seamlessly with TRL's multi-GPU support through `accelerate`: + +```bash +# Configure accelerate +accelerate config + +# Launch training +accelerate launch trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +``` + +For QLoRA with multiple GPUs, the base model is automatically sharded: + +```bash +accelerate launch trl/scripts/sft.py \ + --model_name_or_path meta-llama/Llama-2-70b-hf \ + --load_in_4bit \ + --use_peft \ + --lora_r 32 +``` + +### Naive Pipeline Parallelism (NPP) for Large Models + +For very large models (>60B parameters), TRL supports Naive Pipeline Parallelism (NPP), which distributes the model and adapters across multiple GPUs. The activations and gradients are communicated across GPUs, supporting both `int8` and other data types. + +![NPP](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-npp.png) + +**How to Use NPP** + +Load your model with a custom `device_map` to split it across multiple devices: + +```python +from transformers import AutoModelForCausalLM +from peft import LoraConfig + +# Create custom device map (see accelerate documentation) +device_map = { + "model.embed_tokens": 0, + "model.layers.0": 0, + # ... distribute layers across GPUs + "lm_head": 0, # Must be on GPU 0 +} + +model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-70b-hf", + device_map=device_map, + peft_config=lora_config, +) +``` + +> [!IMPORTANT] +> - Keep the `lm_head` module on the first GPU (device 0) to avoid errors +> - See this [tutorial on device maps](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) for proper configuration +> - Run training scripts directly (not with `accelerate launch`): `python script.py` +> - Data Parallelism is not yet supported with NPP + +## Resources + +### TRL Examples and Notebooks + +- **[SFT with LoRA/QLoRA Notebook](https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)** - Complete working example showing both LoRA and QLoRA implementations +- **[TRL Examples Directory](https://github.com/huggingface/trl/tree/main/examples)** - Collection of training scripts demonstrating PEFT with different trainers +- **[TRL Cookbook Recipes](https://github.com/huggingface/cookbook/tree/main/notebooks/transformers)** - Step-by-step guides for common PEFT training scenarios + +### Documentation + +- [PEFT Documentation](https://huggingface.co/docs/peft) - Official PEFT library documentation +- [TRL Documentation](https://huggingface.co/docs/trl) - Complete TRL documentation with trainer guides +- [LoRA Without Regret](lora_without_regret) - Best practices for using LoRA effectively + +### Research Papers + +- [LoRA Paper](https://huggingface.co/papers/2106.09685) - Original LoRA methodology and results +- [QLoRA Paper](https://huggingface.co/papers/2305.14314) - Efficient finetuning with 4-bit quantization +- [Prompt Tuning Paper](https://huggingface.co/papers/2104.08691) - The Power of Scale for Parameter-Efficient Prompt Tuning diff --git a/ICL/RL/trl_source/docs/source/ppo_trainer.md b/ICL/RL/trl_source/docs/source/ppo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..32abf28e24b3e1f53a8e1b28275596b81182006d --- /dev/null +++ b/ICL/RL/trl_source/docs/source/ppo_trainer.md @@ -0,0 +1,258 @@ +# PPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl) + +TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347). + +References: + +- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences) +- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback) +- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo) +- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031) + +## Get started + +To just run a PPO script to make sure the trainer can run, you can run the following command to train a PPO model with a dummy reward model. + +```bash +python examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 64 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path EleutherAI/pythia-1b-deduped \ + --reward_model_path EleutherAI/pythia-1b-deduped \ + --missing_eos_penalty 1.0 +``` + +## Explanation of the logged metrics + +The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35) + +- `eps`: Tracks the number of episodes per second. +- `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy. +- `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy. +- `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. +- `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. +- `objective/scores`: The mean scores returned by the reward model / environment. +- `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`. +- `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes. +- `loss/policy_avg`: The average policy loss, indicating how well the policy is performing. +- `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward. +- `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function. +- `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are. +- `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed. +- `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes. +- `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses. +- `lr`: lr: The current learning rate used by the optimizer. +- `episode`: episode: The current episode count in the training process. + +## Cookbook + +- Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. +- Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it. +- Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. +- Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. +- Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. + +## What is my model doing exactly? + +To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations. + +![ppov2_completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif) + +In the logs the sampled generations look like + +```txt +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ +┃ query ┃ model response ┃ score ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ +│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │ +│ │ I don't know how to get rid of │ │ +│ TITLE: How do you get someone │ those feelings. I'm │ │ +│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │ +│ │ │ │ +│ POST: Hi, │ │ │ +│ I'm 22, and I have been with my │ │ │ +│ girlfriend for 5 years now. We │ │ │ +│ recently moved together. We've │ │ │ +│ always loved each other │ │ │ +│ intensely. │ │ │ +│ │ │ │ +│ Problem, I recently started to │ │ │ +│ have feelings for an other │ │ │ +│ person (a friend). This person │ │ │ +│ has had a boyfriend for now 3 │ │ │ +│ years, and has absolutely no │ │ │ +│ ideas. Those feelings were so │ │ │ +│ strong, it was hard to hide │ │ │ +│ them. After 2 months of me │ │ │ +│ being distant and really sad, │ │ │ +│ my girlfriend forced me to say │ │ │ +│ what was bothering me. I'm not │ │ │ +│ a good liar, and now she knows. │ │ │ +│ │ │ │ +│ We decided to give us a week │ │ │ +│ alone, I went to my parents. │ │ │ +│ │ │ │ +│ Now, I'm completely lost. I │ │ │ +│ keep on thinking about this │ │ │ +│ person, and I hate that. I │ │ │ +│ would like for those feelings │ │ │ +│ to go away, to leave me alone. │ │ │ +│ But I can't. │ │ │ +│ │ │ │ +│ What do I do? It's been 3 │ │ │ +│ months now, and I'm just │ │ │ +│ desperate. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +├─────────────────────────────────┼─────────────────────────────────┼──────────┤ +│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │ +│ │ TV. I blasted Gangnam Style on │ │ +│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │ +│ with a loud TV. │ up as high as it could │ │ +│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │ +│ POST: She was in her living │ │ │ +│ room, watching TV. This was at │ │ │ +│ about 8:30 in the morning, and │ │ │ +│ she was exercising. She turned │ │ │ +│ the TV up extra loud to hear it │ │ │ +│ over her excercycle, and woke │ │ │ +│ me up. I went in there asking │ │ │ +│ for her to turn it down. She │ │ │ +│ said she didn't have to; I │ │ │ +│ explained that I always used │ │ │ +│ headphones so she didn't have │ │ │ +│ to deal with my noise and that │ │ │ +│ she should give me a little │ │ │ +│ more respect, given that I paid │ │ │ +│ rent at the time. │ │ │ +│ │ │ │ +│ She disagreed. I went back to │ │ │ +│ my room, rather pissed off at │ │ │ +│ the lack of equality. I had no │ │ │ +│ lock on my door; but I had a │ │ │ +│ dresser right next to it, so I │ │ │ +│ pulled one of the drawers out │ │ │ +│ enough so that it caused the │ │ │ +│ door to not be openable. Then, │ │ │ +│ I turned my speakers up really │ │ │ +│ loud and blasted Gangnam Style │ │ │ +│ on repeat, with the bass │ │ │ +│ cranked up as high as it could │ │ │ +│ go. │ │ │ +│ │ │ │ +│ If you hate Gangnam Style for │ │ │ +│ being overplayed, you will see │ │ │ +│ why I chose that particular │ │ │ +│ song. I personally don't mind │ │ │ +│ it. But here's the thing about │ │ │ +│ my bass; it vibrates the walls, │ │ │ +│ making one hell of a lot of │ │ │ +│ noise. Needless to say, my mom │ │ │ +│ was not pleased and shut off │ │ │ +│ the internet. But it was oh so │ │ │ +│ worth it. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +└─────────────────────────────────┴─────────────────────────────────┴──────────┘ +``` + +## Implementation details + +This PPO implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + +## Benchmark experiments + +To validate the PPO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + +```shell +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/ppo/ppo_tldr.py \ + --output_dir models/minimal/ppo_tldr \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 4 \ + --total_episodes 1000000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --local_rollout_forward_batch_size 16 \ + --missing_eos_penalty 1.0 \ + --stop_token eos +``` + +Checkpoints and experiment tracking are available at: + +- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr) +- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35) + +To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR. +For more information on how to use judges, see [Judges](judges). + +```bash +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 33.00% +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 64.70% +``` + +The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended. + +Metrics: + +![PPO v2](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2.png) + +```bash +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/loss/value_avg&metrics=train/val/clipfrac_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \ + "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \ + --env-ids models/minimal/ppo_tldr \ + --pc.ncols 4 \ + --pc.ncols-legend 1 \ + --pc.xlabel "Episode" \ + --output-filename benchmark/trl/pr-1540/ppo \ + --scan-history +``` + +## PPOTrainer + +[[autodoc]] experimental.ppo.PPOTrainer + - train + - save_model + - push_to_hub + +## PPOConfig + +[[autodoc]] experimental.ppo.PPOConfig + +## PreTrainedModelWrapper + +[[autodoc]] experimental.ppo.PreTrainedModelWrapper + +## AutoModelForCausalLMWithValueHead + +[[autodoc]] experimental.ppo.AutoModelForCausalLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## AutoModelForSeq2SeqLMWithValueHead + +[[autodoc]] experimental.ppo.AutoModelForSeq2SeqLMWithValueHead + - __init__ + - forward + - generate + - _init_weights diff --git a/ICL/RL/trl_source/docs/source/prm_trainer.md b/ICL/RL/trl_source/docs/source/prm_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..29d73e819b377661191a93e495aeeb00dcf81bba --- /dev/null +++ b/ICL/RL/trl_source/docs/source/prm_trainer.md @@ -0,0 +1,122 @@ +# PRM Trainer + +[![model badge](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl) + +> [!WARNING] +> PRM Trainer is an experimental API which is subject to change at any time. + +## Overview + +Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins. + +The abstract from the paper is the following: + +> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use processbased supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions. + +This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss). + +## Quick start + +This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_prm.py +from datasets import load_dataset +from trl.experimental.prm import PRMConfig, PRMTrainer +from transformers import AutoModelForTokenClassification, AutoTokenizer + +model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B") +train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]") + +training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd") +trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_prm.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script. + +```python +from datasets import load_dataset +from transformers import pipeline + +pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd") +dataset = load_dataset("trl-lib/math_shepherd") +example = { + "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?", + "completions": [ + "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.", + "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.", + "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20", + ], + "labels": [True, False, False], +} + + +separator = "\n" # It's important to use the same separator as the one used during training + +for idx in range(1, len(example["completions"]) + 1): + steps = example["completions"][0:idx] + text = separator.join((example["prompt"], *steps)) + separator # Add a separator between the prompt and each steps + pred_entity = pipe(text)[-1]["entity"] + pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity] + label = example["labels"][idx - 1] + print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}") +``` + +```text +Step 1 Predicted: True Label: True +Step 2 Predicted: False Label: False +Step 3 Predicted: False Label: False +``` + +It's a win! + +## Expected dataset type + +PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision). +The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step. + +The [`experimental.prm.PRMTrainer`] only supports [standard](dataset_formats#standard) dataset format. + +## Example script + +We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) + +To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command: + +```bash +accelerate launch examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/math_shepherd \ + --num_train_epochs 1 \ + --output_dir Qwen2-0.5B-Reward-Math-Sheperd +``` + +## PRMTrainer + +[[autodoc]] experimental.prm.PRMTrainer + - train + - save_model + - push_to_hub + +## PRMConfig + +[[autodoc]] experimental.prm.PRMConfig diff --git a/ICL/RL/trl_source/docs/source/ptt_integration.md b/ICL/RL/trl_source/docs/source/ptt_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..3148215c14ec11801d63acd3e73dd175a3657b7a --- /dev/null +++ b/ICL/RL/trl_source/docs/source/ptt_integration.md @@ -0,0 +1,193 @@ +# Post-Training Toolkit Integration + +[Post-Training Toolkit](https://github.com/microsoft/post-training-toolkit) is a diagnostic and observability layer for RLHF training runs. Add one callback to any TRL trainer and get **auto-metrics**, **crash postmortems**, and **literature-backed heuristics**—without writing glue code. + +It was built to operationalize the debugging patterns we found most useful when running post-training at scale. + +## Usage + +1. First, install Post-Training Toolkit: + +```bash +pip install post-training-toolkit +``` + +2. Add one callback to your trainer. That's it! + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl import DPOTrainer + +trainer = DPOTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl.experimental.ppo import PPOTrainer + +trainer = PPOTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl import SFTTrainer + +trainer = SFTTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl.experimental.orpo import ORPOTrainer + +trainer = ORPOTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl import KTOTrainer + +trainer = KTOTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl.experimental.cpo import CPOTrainer + +trainer = CPOTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +```python +from post_training_toolkit import DiagnosticsCallback +from trl import GRPOTrainer + +trainer = GRPOTrainer( + model=model, + args=training_args, + callbacks=[DiagnosticsCallback()], # ← Just add this + ... +) +trainer.train() +``` + + + + +## What You Get + +**Example output:** +```text +[HIGH] DPO loss stuck at ~0.693 (random chance). Model may not be learning preferences. + Ref: Rafailov et al. (2023) 'DPO', Section 4.2 + +[RECOMMENDED] Increase learning rate 2-5x, check data quality, or reduce beta. +``` + +## Example Demo + +See a full working example with auto-stop in action: + +📂 **[demo/live_demo.ipynb](https://github.com/microsoft/post-training-toolkit/blob/main/demo/notebooks/demo_live_output.ipynb)** + +📂 **[demo/scripts/custom_heuristic.py](https://github.com/microsoft/post-training-toolkit/blob/main/demo/scripts/custom_heuristic_demo.py)** + + +### 1. Auto-Metrics +The callback automatically captures algorithm-specific metrics, backed by the latest research and industry push: + +| Trainer | Key Metrics Captured | +|---------|---------------------| +| **DPO** | loss, win_rate, reward_margin, logps_chosen/rejected | +| **PPO** | policy_loss, value_loss, entropy, clip_fraction, KL | +| **GRPO** | group rewards, advantages, policy loss, KL | +| **SFT** | loss, perplexity, accuracy | +| **ORPO** | sft_loss, odds_ratio_loss, log_odds_ratio | +| **KTO** | kl, logps for desirable/undesirable | + + +### 2. Crash Postmortems +If training crashes or gets interrupted, you get a `postmortem.json` with full context: + +```json +{ + "exit_reason": "exception", + "last_step": 847, + "timestamp": "2025-12-17T19:26:04Z", + "final_metrics": {"dpo_loss": 0.693, "win_rate": 0.52} +} +``` + +No more "what step did it die on?" + +### 3. Auto-Stop on Critical Issues + +Enable automatic training termination when critical issues are detected: + +```python +callback = DiagnosticsCallback(stop_on_critical=True) +``` + +## Distributed Training +Works automatically with multi-GPU setups. Zero configuration needed: + +```bash +accelerate launch --num_processes 8 train.py +``` + +Automatically detects stragglers, aggregates metrics across ranks, and tracks memory balance. diff --git a/ICL/RL/trl_source/docs/source/quickstart.md b/ICL/RL/trl_source/docs/source/quickstart.md new file mode 100644 index 0000000000000000000000000000000000000000..6661762af93325f944debd63da61ec2d8d5728ff --- /dev/null +++ b/ICL/RL/trl_source/docs/source/quickstart.md @@ -0,0 +1,140 @@ +# Quickstart + +TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO). + +## Quick Examples + +Get started instantly with TRL's most popular trainers. Each example uses compact models for quick experimentation. + +### Supervised Fine-Tuning + +```python +from trl import SFTTrainer +from datasets import load_dataset + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=load_dataset("trl-lib/Capybara", split="train"), +) +trainer.train() +``` + +### Group Relative Policy Optimization + +```python +from trl import GRPOTrainer +from datasets import load_dataset +from trl.rewards import accuracy_reward + +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # Start from SFT model + train_dataset=load_dataset("trl-lib/DeepMath-103K", split="train"), + reward_funcs=accuracy_reward, +) +trainer.train() +``` + +### Direct Preference Optimization + +```python +from trl import DPOTrainer +from datasets import load_dataset + +trainer = DPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # Use your SFT model + ref_model="Qwen/Qwen2.5-0.5B-Instruct", # Original base model + train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"), +) +trainer.train() +``` + +### Reward Modeling + +```python +from trl import RewardTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +trainer = RewardTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", + train_dataset=dataset, +) +trainer.train() +``` + +## Command Line Interface + +Skip the code entirely - train directly from your terminal: + +```bash +# SFT: Fine-tune on instructions +trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/Capybara + +# DPO: Align with preferences +trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized + +# Reward: Train a reward model +trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized +``` + +## What's Next? + +### 📚 Learn More + +- [SFT Trainer](sft_trainer) - Complete SFT guide +- [DPO Trainer](dpo_trainer) - Preference alignment +- [GRPO Trainer](grpo_trainer) - Group relative policy optimization + +### 🚀 Scale Up + +- [Distributed Training](distributing_training) - Multi-GPU setups +- [Memory Optimization](reducing_memory_usage) - Efficient training +- [PEFT Integration](peft_integration) - LoRA and QLoRA + +### 💡 Examples + +- [Example Scripts](https://github.com/huggingface/trl/tree/main/examples) - Production-ready code +- [Community Tutorials](community_tutorials) - External guides + +## Troubleshooting + +### Out of Memory? + +Reduce batch size and enable optimizations: + + + + +```python +training_args = SFTConfig( + per_device_train_batch_size=1, # Start small + gradient_accumulation_steps=8, # Maintain effective batch size +) +``` + + + + +```python +training_args = DPOConfig( + per_device_train_batch_size=1, # Start small + gradient_accumulation_steps=8, # Maintain effective batch size +) +``` + + + + +### Loss not decreasing? + +Try adjusting the learning rate: + +```python +training_args = SFTConfig(learning_rate=2e-5) # Good starting point +``` + +For more help, open an [issue on GitHub](https://github.com/huggingface/trl/issues). diff --git a/ICL/RL/trl_source/docs/source/rapidfire_integration.md b/ICL/RL/trl_source/docs/source/rapidfire_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..711dd1b791c736ec22f8d3bd648212dd4a3552b0 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/rapidfire_integration.md @@ -0,0 +1,388 @@ +# RapidFire AI Integration + +RapidFire AI is an open-source experiment execution framework that enables concurrent training of multiple TRL configurations on the same GPU(s) through intelligent chunk-based scheduling. + +## Key Features + +- **16-24× higher experimentation throughput** compared to sequential training. +- **Almost no code changes** - drop-in configuration wrappers around TRL's and PEFT's existing configs. +- **Interactive Control Operations** - real-time control to stop, resume, clone, and modify training runs in flight +- **Automatic multi-GPU orchestration** with intelligent scheduling +- **Full compatibility** with transformers, PEFT, SFTTrainer, DPOTrainer, and GRPOTrainer +- **Full MLflow Integration**: Automatic experiment tracking and visualization +- **Production-Ready**: Already used in production environments with complete working examples. + +### Problem It Solves + +When fine-tuning or post-training with TRL, AI developers often need to: +- Try different hyperparameter configurations +- Compare different LoRA settings +- Test different prompt schemes +- Run ablation studies + + +**Current approach**: Train each config one after another → slow and inefficient process + +**With RapidFire AI**: Train all configs in one go even on a single GPU → 16-24× faster process + +### How It Works + +RapidFire AI employs **adaptive chunk-based scheduling**: + +``` +GPU Timeline (Single GPU): +Chunk 1: [Config A] → [Config B] → [Config C] → [Config D] +Chunk 2: [Config A] → [Config B] → [Config C] → [Config D] +Chunk 3: [Config A] → [Config B] → [Config C] → [Config D] +``` + +This enables: +- Early comparison of configurations on same data subsets incrementally +- Efficient GPU utilization and minimizing idle times +- Real-time and automated experiment metrics tracking +- Dynamic control over runs in flight to incentivize more experimentation + + +## Installation + +### Prerequisites + +- Python 3.12.x +- NVIDIA GPU with Compute Capability 7.x or 8.x +- CUDA Toolkit 11.8+ +- PyTorch 2.7.1+ + +### pip install + +```bash +pip install rapidfireai +``` + +Once installed, authenticate with Hugging Face and initialize RapidFire AI: + +```bash +# Authenticate with Hugging Face +huggingface-cli login --token YOUR_TOKEN + +# Workaround for current issue: https://github.com/huggingface/xet-core/issues/527 +pip uninstall -y hf-xet + +# Initialize RapidFire AI +rapidfireai init + +# Start the RapidFire AI server +rapidfireai start +``` + +The dashboard will be available at `http://0.0.0.0:3000` where you can monitor and control experiments in real-time. + +## Quick Start: SFT Training with Multiple Configs + +Here's a complete example showing how to train multiple SFT configurations concurrently: + +```python +from rapidfireai import Experiment +from rapidfireai.automl import List, RFGridSearch, RFModelConfig, RFLoraConfig, RFSFTConfig +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Load dataset +dataset = load_dataset("bitext/Bitext-customer-support-llm-chatbot-training-dataset") +train_dataset = dataset["train"].select(range(128)).shuffle(seed=42) +eval_dataset = dataset["train"].select(range(100, 124)).shuffle(seed=42) + +# Define data formatting function +def formatting_function(row): + return { + "prompt": [ + {"role": "system", "content": "You are a helpful customer support assistant."}, + {"role": "user", "content": row["instruction"]}, + ], + "completion": [ + {"role": "assistant", "content": row["response"]} + ] + } + +# Initialize experiment +experiment = Experiment(experiment_name="sft-customer-support") + +# Define multiple LoRA configurations to compare +peft_configs = List([ + RFLoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, + target_modules=["q_proj", "v_proj"], bias="none"), + RFLoraConfig(r=32, lora_alpha=64, lora_dropout=0.1, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], bias="none") +]) + +# Define multiple training configurations +# 2 base configs × 2 PEFT configs = 4 total training runs +config_set = List([ + RFModelConfig( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + peft_config=peft_configs, + training_args=RFSFTConfig( # Wraps TRL's SFTConfig + learning_rate=1e-3, + per_device_train_batch_size=4, + max_steps=128, + fp16=True, + ), + model_type="causal_lm", + model_kwargs={"device_map": "auto", "torch_dtype": "auto", "use_cache": False}, + formatting_func=formatting_function, + ), + RFModelConfig( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + peft_config=peft_configs, + training_args=RFSFTConfig( + learning_rate=1e-4, # Different learning rate + per_device_train_batch_size=4, + max_steps=128, + fp16=True, + ), + model_type="causal_lm", + model_kwargs={"device_map": "auto", "torch_dtype": "auto", "use_cache": False}, + formatting_func=formatting_function, + ) +]) + +# Define model creation function +def create_model(model_config): + model = AutoModelForCausalLM.from_pretrained( + model_config["model_name"], + **model_config["model_kwargs"] + ) + tokenizer = AutoTokenizer.from_pretrained(model_config["model_name"]) + return (model, tokenizer) + +# Create grid search over all configurations +config_group = RFGridSearch(configs=config_set, trainer_type="SFT") + +# Run all 4 configurations concurrently with chunk-based scheduling +experiment.run_fit(config_group, create_model, train_dataset, eval_dataset, + num_chunks=4, seed=42) + +# End experiment +experiment.end() +``` + +### What Happens During Execution + +When you run this example: + +1. **Config Expansion**: 2 base configurations × 2 PEFT configs = 4 total training runs +2. **Chunk-based Scheduling**: Training data is divided into chunks, and all 4 configs train concurrently +3. **GPU Swapping**: Models are swapped in/out of GPU memory based on chunk boundaries +4. **Real-time Tracking**: All metrics visible in the dashboard at `http://localhost:3000` +5. **Interactive Control**: Stop, resume, or clone any configuration from the dashboard + +This delivers **16-24× higher throughput** compared to training each configuration sequentially! + +## Supported TRL Trainers + +### SFTTrainer + +Use `RFSFTConfig` as a drop-in replacement for `SFTConfig`: + +```python +from rapidfireai.automl import RFSFTConfig + +training_args = RFSFTConfig( + learning_rate=5e-5, + per_device_train_batch_size=4, + num_train_epochs=3, + max_length = 512, + # ... all other SFTConfig parameters supported +) +``` + +**Example Notebook**: [SFT for Customer Support](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/rf-tutorial-sft-chatqa-lite.ipynb) + +### DPOTrainer + +Use `RFDPOConfig` as a drop-in replacement for `DPOConfig`: + +```python +from rapidfireai.automl import RFDPOConfig + +training_args = RFDPOConfig( + beta=0.1, + loss_type="sigmoid", + max_length=1024, + learning_rate=5e-4, + # ... all other DPOConfig parameters supported +) +``` + +**Example Notebook**: [DPO for Preference Alignment](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/rf-tutorial-dpo-alignment-lite.ipynb) + +### GRPOTrainer + +Use `RFGRPOConfig` as a drop-in replacement for `GRPOConfig`: + +```python +from rapidfireai.automl import RFGRPOConfig + +training_args = RFGRPOConfig( + learning_rate=5e-6, + num_generations=8, + max_completion_length=256, + # ... all other GRPOConfig parameters supported +) +``` + +**Example Notebook**: [GRPO for Math Reasoning](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/rf-tutorial-grpo-mathreasoning-lite.ipynb) + +## Core Concepts + +### Chunk-Based Concurrent Training + +RapidFire AI divides training data into chunks and alternates between configurations: + +``` +GPU Timeline (Single GPU): +Chunk 1: [Config A] → [Config B] → [Config C] → [Config D] +Chunk 2: [Config A] → [Config B] → [Config C] → [Config D] +Chunk 3: [Config A] → [Config B] → [Config C] → [Config D] +... +``` + +This approach maximizes GPU utilization and enables early comparison of configurations while maintaining training stability through automatic checkpointing. + +### Interactive Control Operations (IC Ops) + +Through the RapidFire AI dashboard, you can dynamically control running experiments: + +- **Stop**: Pause a configuration (checkpointed automatically) +- **Resume**: Continue from last checkpoint +- **Clone**: Duplicate a configuration with modifications +- **Clone & Warm Start**: Clone and initialize from parent's weights +- **Delete**: Remove failed or unwanted runs + +This enables adaptive experimentation where you can stop underperforming configs early and clone promising ones with tweaked hyperparameters. + +### Multi-Config Experimentation + +Use `RFGridSearch` or `RFRandomSearch` to automatically generate configuration combinations: + +```python +# Grid search: tests all combinations +config_group = RFGridSearch(configs=config_list, trainer_type="SFT") + +# Random search: samples N configurations +config_group = RFRandomSearch(configs=config_list, trainer_type="DPO", num_samples=10) +``` + +## Advanced Features + +### PEFT/LoRA Integration + +Full support for parameter-efficient fine-tuning: + +```python +from rapidfireai.automl import RFLoraConfig +from peft import TaskType + +lora_config = RFLoraConfig( + task_type=TaskType.CAUSAL_LM, + r=64, + lora_alpha=64, + lora_dropout=0.1, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + bias="none" +) +``` + +### Custom Reward Functions (GRPO) + +Define multiple reward functions for GRPO training: + +```python +def correctness_reward(prompts, completions, answer, **kwargs): + """Reward for correct answers""" + responses = [completion[0]['content'] for completion in completions] + extracted = [extract_answer(r) for r in responses] + return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)] + +def format_reward(completions, **kwargs): + """Reward for proper formatting""" + import re + pattern = r".*?\s*.*?" + responses = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, r) for r in responses] + return [0.5 if match else 0.0 for match in matches] + +# Use in model config +config = RFModelConfig( + reward_funcs=[correctness_reward, format_reward], + # ... other parameters +) +``` + +### Multi-GPU Support + +RapidFire AI automatically detects and utilizes all available GPUs. No special configuration needed - the scheduler automatically distributes configurations across GPUs. + +## Best Practices + +### Tuning Chunk Granularity + +The `num_chunks` parameter controls swap frequency: + +```python +# Fewer chunks = less overhead, less frequent comparison +experiment.run_fit(..., num_chunks=2) + +# More chunks = more overhead, more frequent comparison +experiment.run_fit(..., num_chunks=16) +``` + +**Rule of thumb**: Start with `num_chunks=4` and adjust based on dataset size and number of configurations. + +### Memory Management + +For large models, use quantization: + +```python +from transformers import BitsAndBytesConfig +import torch + +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", +) + +model_kwargs = { + "quantization_config": bnb_config, + "device_map": "auto", +} +``` + +## Performance Benchmarks + +Based on internal benchmarks comparing sequential vs. RapidFire AI concurrent training: + +| Scenario | Sequential Time | RapidFire AI Time | Speedup | +|----------|----------------|-------------------|---------| +| 4 configs, 1 GPU | 120 min | 7.5 min | 16× | +| 8 configs, 1 GPU | 240 min | 12 min | 20× | +| 4 configs, 2 GPUs | 60 min | 4 min | 15× | +| 8 configs, 4 GPUs | 60 min | 3 min | 20× | + +*Benchmarks performed on NVIDIA A100 40GB with TinyLlama-1.1B and Llama-3.2-1B models* + +## Troubleshooting + +For troubleshooting guidance, see the [RapidFire AI Troubleshooting Guide](https://oss-docs.rapidfire.ai/en/latest/troubleshooting.html). + +## Additional Resources +- **Colab Notebook**: [RapidFire AI in Google Colab](http://tinyurl.com/rapidfireai-colab) +- **Documentation**: [oss-docs.rapidfire.ai](https://oss-docs.rapidfire.ai) +- **GitHub**: [RapidFireAI/rapidfireai](https://github.com/RapidFireAI/rapidfireai) +- **PyPI**: [pypi.org/project/rapidfireai](https://pypi.org/project/rapidfireai/) +- **Discord**: [Join our Discord](https://discord.gg/6vSTtncKNN) +- **Tutorial Notebooks**: [GitHub Repository](https://github.com/RapidFireAI/rapidfireai/tree/main/tutorial_notebooks) + +Learn more about RapidFire AI in their [official repository](https://github.com/RapidFireAI/rapidfireai) and [documentation](https://oss-docs.rapidfire.ai). + diff --git a/ICL/RL/trl_source/docs/source/reducing_memory_usage.md b/ICL/RL/trl_source/docs/source/reducing_memory_usage.md new file mode 100644 index 0000000000000000000000000000000000000000..d50476ffbc0a2d040b1c952ea187fc63d272c936 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/reducing_memory_usage.md @@ -0,0 +1,330 @@ +# Reducing Memory Usage + +Training workflows can often be optimized to **reduce memory consumption**, and TRL provides several built-in features to help achieve this. + +Below, we outline these techniques and recommend experimenting with different combinations to figure out which configuration works best for your specific setup. + +Each method includes examples for the supported trainers. If you're unsure whether a technique is compatible with your trainer, please take a look at the corresponding trainer documentation. + +For additional strategies, such as **gradient checkpointing**, which is supported across all trainers, see the [`transformers` performance guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing). + +## Truncation + +Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short. + +![Truncation prompt-completion](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png) + +To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case. + + + + +DPO truncation is controlled via `max_length`, which truncates the combined prompt+completion sequence. + +![DPO truncation](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png) + +To set the truncation parameter, use the following code snippet: + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., max_length=...) +``` + +> [!WARNING] +> The legacy `max_prompt_length` and `max_completion_length` parameters are deprecated and will be removed; instead, filter or pre-truncate overlong prompts/completions in your dataset before training. + + + + +SFT truncation is applied to the input sequence via the `max_length` parameter. + +![Truncation input ids](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png) + +To set the truncation parameter, use the following code snippet: + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., max_length=...) +``` + + + + +### How to choose the `max_length` value? + +If `max_length` is too small, a significant portion of your tokens will be discarded and won't contribute to training. If it's too large, memory usage can spike, potentially leading to out-of-memory (OOM) errors. Without packing or padding-free, a large `max_length` may also result in inefficient training, as many tokens will be padding. + +To help you choose an appropriate value, we provide a utility to visualize the sequence length distribution in your dataset. + + + +## Packing + +> [!TIP] +> This technique is available only for **SFT** training and setups that use **FlashAttention** (or its variants). + +[Truncation](#truncation) has several drawbacks: + +1. **Loss of information**: Key data at the end of a sequence may be discarded. +2. **Choosing truncation length**: Too short loses data; too long undermines efficiency. + +Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths. + +![Packing](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_3.png) + +Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` in the [`SFTConfig`]. + +> [!TIP] +> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in [`SFTConfig`]. + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., packing=True, max_length=512) +``` + +## PEFT for parameter-efficient fine-tuning + +Parameter-Efficient Fine-Tuning (PEFT) methods like LoRA are among the most effective techniques for reducing memory usage during training. Instead of training all model parameters, PEFT methods train only a small number of adapter parameters, significantly reducing memory requirements and enabling fine-tuning of larger models on limited hardware. + +For comprehensive details on using PEFT with TRL, including various adapter methods, quantization options, and advanced configurations, see [PEFT Integration](peft_integration). + +To use PEFT for reducing memory usage: + +```python +from datasets import load_dataset +from peft import LoraConfig +from trl import SFTTrainer + +dataset = load_dataset("trl-lib/Capybara", split="train") + +peft_config = LoraConfig() + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=dataset, + peft_config=peft_config, +) +``` + +PEFT can be combined with other memory reduction techniques such as quantization (4-bit or 8-bit) for even greater memory savings. See [PEFT Integration](peft_integration) for quantization examples. + +## Liger for reducing peak memory usage + +[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. + +For more information, see [Liger Kernel Integration](liger_kernel_integration). + +To use Liger for reducing peak memory usage, use the following code snippet: + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl.experimental.kto import KTOConfig + +training_args = KTOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl.experimental.gkd import GKDConfig + +training_args = GKDConfig(..., use_liger_kernel=True) +``` + + + + +## Padding-free + +Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact. + +![Padding-free](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png) + +> [!WARNING] +> It's highly recommended to use padding-free batching with **FlashAttention 2** or **FlashAttention 3**. Otherwise, you may encounter batch contamination issues. + + + + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) +``` + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) +``` + + + + +## Activation offloading + +Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time. + +To enable activation offloading in your SFT training configuration: + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., activation_offloading=True) +``` + +Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors that would be inefficient. For performance optimization, it can, via a flag (which is true by default), use CUDA streams to overlap computation with CPU-GPU transfers. + +## Padding Sequences to a Multiple + +> [!TIP] +> This technique is supported for **SFT** and **Reward** trainers currently. + +When enabled, this option ensures that all sequences are **padded to a multiple** of the specified value. +This can improve computational efficiency on some hardware by aligning sequence lengths to memory-friendly boundaries. + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., pad_to_multiple_of=2048) +``` + + + + +```python +from trl import RewardConfig + +training_args = RewardConfig(..., pad_to_multiple_of=2048) +``` + + + + +## Disabling model gathering for generation in online methods + +When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to OOM errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204). + +If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter: + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl.experimental.online_dpo import OnlineDPOConfig + +training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl.experimental.ppo import PPOConfig + +training_args = PPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig(..., ds3_gather_for_generation=False) +``` + + + + +This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds. + +## vLLM sleep mode + +When using **vLLM** as the generation backend for online training methods, you can enable _sleep mode_ to offload vLLM parameters and cache to CPU RAM during the optimization step and reload them back to GPU VRAM when needed for weight synchronization and generation. + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., vllm_enable_sleep_mode=True) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig(..., vllm_enable_sleep_mode=True) +``` + + + + +Offloading the vLLM weights and cache helps keep GPU memory usage low, which can be particularly beneficial when training large models or using limited GPU resources. However, waking the vLLM engine from sleep mode introduces some host–device transfer latency, which may slightly impact training speed. + +## Gradient checkpointing + +Gradient checkpointing trades compute for memory by not storing all intermediate activations during the forward pass, recomputing them during the backward pass instead. + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., gradient_checkpointing=True) +``` + +> [!NOTE] +> Gradient checkpointing is enabled by default in all trainers to optimize memory usage. You can disable it by setting `gradient_checkpointing=False` if needed. + +For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing). diff --git a/ICL/RL/trl_source/docs/source/reward_trainer.md b/ICL/RL/trl_source/docs/source/reward_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..e96563a07fbaadce6a67b60c97c889d070b09c31 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/reward_trainer.md @@ -0,0 +1,238 @@ +# Reward Modeling + +[![model badge](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl) + +## Overview + +TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models. + +This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada). + +## Quick start + +This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), large-scale, fine-grained, diverse preference dataset. + +```python +from trl import RewardTrainer +from datasets import load_dataset + +trainer = RewardTrainer( + model="Qwen/Qwen3-0.6B", + train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"), +) +trainer.train() +``` + + + +## Expected dataset type and format + +[`RewardTrainer`] supports [preference](dataset_formats#preference) datasets type (both implicit and explicit prompt). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +```python +# Standard preference (implicit prompt) +{"chosen": "The sky is blue.", + "rejected": "The sky is green."} + +# Conversational preference (implicit prompt) +{"chosen": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}]} + +# Standard preference (explicit prompt) +{"prompt": "The sky is", + "chosen": " blue.", + "rejected": " green."} + +# Conversational preference (explicit prompt) +{"prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}]} +``` + +If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset: + +```python +from datasets import load_dataset +import json + +dataset = load_dataset("lmarena-ai/arena-human-preference-55k") + +# Filter out ties +dataset = dataset.filter(lambda example: example["winner_tie"] == 0) + +# Create 'chosen' and 'rejected' fields based on the winner column +def response_a_b_to_chosen_rejected(example): + if example["winner_model_a"] == 1: + example["chosen"] = example["response_a"] + example["rejected"] = example["response_b"] + else: + example["chosen"] = example["response_b"] + example["rejected"] = example["response_a"] + return example + +dataset = dataset.map(response_a_b_to_chosen_rejected) + +# Convert to conversational format +def make_conversation(example): + prompt = json.loads(example["prompt"])[0] # '["What color is the sky?"]' -> "What color is the sky?" + chosen = json.loads(example["chosen"])[0] + rejected = json.loads(example["rejected"])[0] + return { + "chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}], + "rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}], + } + + +dataset = dataset.map(make_conversation) + +# Keep only necessary columns +dataset = dataset.select_columns(["chosen", "rejected"]) + +print(next(iter(dataset["train"]))) +``` + +```json +{ + "chosen": [ + {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"}, + {"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."}, + ], + "rejected": [ + {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"}, + {"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."}, + ], +} +``` + +## Looking deeper into the training method + +Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences. + +This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**. + +### Preprocessing and tokenization + +During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference). +The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization. + +### Computing the loss + +Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)−r(x, y^-)) \\), where \\( σ \\) is the sigmoid function. + +The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences: + +$$ +\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right]. +$$ + +> [!TIP] +> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`. + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +* `global_step`: The total number of optimizer steps taken so far. +* `epoch`: The current epoch number, based on dataset iteration. +* `num_tokens`: The total number of tokens processed so far. +* `loss`: The average loss over the last logging interval. +* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval. +* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval. +* `mean_reward`: The average reward score assigned by the model over the last logging interval. +* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval. +* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval. +* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used. +* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping. + +## Customization + +### Model initialization + +You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to + +```python +model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16) +``` + +you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`]. + +```python +from trl import RewardConfig + +training_args = RewardConfig( + model_init_kwargs={"dtype": torch.bfloat16}, +) +``` + +Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1. + +### Train adapters with PEFT + +We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model. + +```python +from datasets import load_dataset +from trl import RewardTrainer +from peft import LoraConfig + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +trainer = RewardTrainer( + "Qwen/Qwen3-4B", + train_dataset=dataset, + peft_config=LoraConfig(modules_to_save=["score"]) # important to include the score head when base model is not a sequence classification model +) + +trainer.train() +``` + +You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed. + +```python +from datasets import load_dataset +from trl import RewardTrainer +from peft import AutoPeftModelForCausalLM + +model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True) +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = RewardTrainer( + model=model, + train_dataset=dataset, +) + +trainer.train() +``` + +> [!TIP] +> When training adapters, you typically use a higher learning rate (≈1e‑3) since only new parameters are being learned. +> +> ```python +> RewardConfig(learning_rate=1e-3, ...) +> ``` + +## Tool Calling with Reward Modeling + +The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include: + +* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages) +* The list of available tools in the `tools` column, typically provided as JSON schemas + +For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section. + +## RewardTrainer + +[[autodoc]] RewardTrainer + - train + - save_model + - push_to_hub + +## RewardConfig + +[[autodoc]] RewardConfig + +## DataCollatoForPreference + +[[autodoc]] trainer.reward_trainer.DataCollatorForPreference diff --git a/ICL/RL/trl_source/docs/source/rewards.md b/ICL/RL/trl_source/docs/source/rewards.md new file mode 100644 index 0000000000000000000000000000000000000000..52752205f377ff51064ccde868762879da6720e1 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/rewards.md @@ -0,0 +1,19 @@ +# Reward Functions + +This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`]. + +## accuracy_reward + +[[autodoc]] rewards.accuracy_reward + +## reasoning_accuracy_reward + +[[autodoc]] rewards.reasoning_accuracy_reward + +## think_format_reward + +[[autodoc]] rewards.think_format_reward + +## get_soft_overlong_punishment + +[[autodoc]] rewards.get_soft_overlong_punishment diff --git a/ICL/RL/trl_source/docs/source/rloo_trainer.md b/ICL/RL/trl_source/docs/source/rloo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..ef7db32d6a1fe77e9a02653dc9db5df84d6332ae --- /dev/null +++ b/ICL/RL/trl_source/docs/source/rloo_trainer.md @@ -0,0 +1,618 @@ +# RLOO Trainer + +[![model badge](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl) + +## Overview + +TRL supports the RLOO Trainer for training language models, as described in the paper [Back to Basics: Revisiting REINFORCE Style +Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740) by [Arash Ahmadian](https://huggingface.co/ArashAhmadian), Chris Cremer, [Matthias Gallé](https://huggingface.co/mgalle), [Marzieh Fadaee](https://huggingface.co/MarziehFadaee), [Julia Kreutzer](https://huggingface.co/JuliaKreutzerCohere), [Ahmet Üstün](https://huggingface.co/ahmetu) and [Sara Hooker](https://huggingface.co/sarahooker). + +The abstract from the paper is the following: + +> AI alignment in the shape of Reinforcement Learning from Human Feedback (RLHF) is increasingly treated as a crucial ingredient for high performance large language models. Proximal Policy Optimization (PPO) has been positioned by recent literature as the canonical method for the RL part of RLHF However, it involves both high computational cost and sensitive hyperparameter tuning. We posit that most of the motivational principles that led to the development of PPO are less of a practical concern in RLHF and advocate for a less computationally expensive method that preserves and even increases performance. We revisit the formulation of alignment from human preferences in the context of RL. Keeping simplicity as a guiding principle, we show that many components of PPO are unnecessary in an RLHF context and that far simpler REINFORCE-style optimization variants outperform both PPO and newly proposed “RL-free” methods such as DPO and RAFT. Our work suggests that careful adaptation to LLMs alignment characteristics enables benefiting from online RL optimization at low cost. + +This post-training method was contributed by [Costa Huang](https://github.com/vwxyzjn) and later refactored by [Shirin Yamani](https://huggingface.co/ShirinYamani). + +## Quick start + +This example demonstrates how to train a model using the RLOO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K). You can view the data in the dataset here: + + + +Below is the script to train the model. + +```python +# train_rloo.py +from datasets import load_dataset +from trl import RLOOTrainer +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = RLOOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=accuracy_reward, + train_dataset=dataset, +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_rloo.py +``` + +## Looking deeper into the RLOO method + +RLOO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind RLOO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how RLOO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**. + +![RLOO](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/rloo.png) + +### Generating completions + +At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)). + +### Computing the reward + +In RLOO, the reward consists of two components: the reward provided by the reward model (or reward function) and a KL penalty that discourages the policy from deviating too far from a fixed reference policy + +1. For each of the \\( G \\) generated sequences \\( o_i = (o_{i,1}, \dots, o_{i,T}) \\) conditioned on a query \\( q \\), we compute a scalar reward using a reward model \\( R(o_i, q) \\). +2. Concurrently, we estimate the KL divergence between the current policy \\( \pi_\theta \\) and the fixed reference policy \\( \pi_{\text{ref}} \\) over the sequence. The KL estimate for sequence \\( o_i \\) is: + +$$ +\mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta\|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_\theta(o_{i,t} \mid q, o_{i, 0 \\) controls the strength of the KL penalty. + +> [!TIP] +> In a purely online setting (`num_iterations = 1`, default), the data are generated by the current policy. In this case, the KL penalty is computed directly using the current policy. +> +> In the more general setting (e.g., multiple gradient steps per batch), the data are instead generated by an earlier snapshot \\( \pi_{\text{old}} \\). To keep the penalty consistent with the sampling distribution, the KL is defined with respect to this policy: +> +> $$ +> \mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \,\|\, \pi_{\text{ref}}\right]. +> $$ +> +> Equivalently, for a sampled sequence $o$, the Monte Carlo estimate is +> +> $$ +> \mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_{\text{old}}(o_{i,t} \mid q, o_{i, $$ + +### Computing the advantage + +Once the rewards for each completion have been computed, we calculate a baseline as the average reward of all other samples in the same batch, excluding the current sample. This baseline is used to reduce the variance of the policy gradient estimate. The advantage for each completion is then obtained as the difference between its own reward and this leave-one-out baseline. + +Formally, for a batch of G completions, the baseline for completion is: +$$ +b_i = \frac{1}{G-1} \sum_{j \neq i} r_j +$$ + +and then the advantage for each completion is computed as the difference between its reward and the baseline: + +$$ +A_i = r_i - b_i +$$ + +### Computing the loss + +The REINFORCE loss is simply defined as: + +$$ +\mathcal{L}_{\text{RLOO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \hat{A}_i \, \log \pi_\theta(o_i \mid q) +$$ + +In practice, performing multiple gradient steps on the same batch makes the actions effectively off-policy relative to the current parameters. To correct for this, we introduce the importance sampling ratio. To prevent excessively large updates when the policy changes between sampling and gradient steps, we clip this ratio: + +$$ +\mathcal{L}_{\text{RLOO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \min \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)} \hat{A}_i, \, \text{clip}\left(\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)}, 1-\epsilon, 1+\epsilon\right) \hat{A}_i \right) +$$ + +In a fully online, single-step setting (default), \\( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)} = 1 \\) and this reduces to standard REINFORCE. + +## Logged metrics + +While training and evaluating, we record the following reward metrics: + +- `num_tokens`: The total number of tokens processed so far, including both prompts and completions. +- `step_time`: The average time (in seconds) taken per training step (including generation). +- `completions/mean_length`: The average length of generated completions. +- `completions/min_length`: The minimum length of generated completions. +- `completions/max_length`: The maximum length of generated completions. +- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS. +- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS. +- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS. +- `completions/clipped_ratio`: The ratio of truncated (clipped) completions. +- `reward/{reward_func_name}/mean`: The average reward from a specific reward function. +- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function. +- `reward`: The overall average reward after summing rewards across functions (unweighted). +- `reward_std`: The standard deviation of summed rewards across functions (unweighted), computed over the full batch. +- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect). +- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.) +- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero. +- `clip_ratio/region_mean`: The ratio of sequence probabilities where the RLOO objective is clipped to stay within the trust region: \\( \text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \quad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)} \\). A higher value means more samples are clipped, which constrains how much the policy $\pi_\theta$ can change. +- `clip_ratio/low_mean`: The average ratio of sequence probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\). +- `clip_ratio/low_min`: The minimum ratio of sequence probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\). +- `clip_ratio/high_mean`: The average ratio of sequence probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\). +- `clip_ratio/high_max`: The maximum ratio of sequence probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\). + +## Customization + +### Speed up training with vLLM-powered generation + +Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with + +```shell +pip install trl[vllm] +``` + +We support two ways of using vLLM during training: **server mode** and **colocate mode**. + +#### 🔌 Option 1: Server mode + +In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference. + +1. **Start the vLLM server**: + + ```bash + trl vllm-serve --model + ``` + +2. **Enable server mode in your training script**: + + ```python + from trl import RLOOConfig + + training_args = RLOOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted + ) + ``` + +> [!WARNING] +> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable. + +#### 🧩 Option 2: Colocate mode + +In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. + +```python +from trl import RLOOConfig + +training_args = RLOOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + +> [!TIP] +> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`RLOOConfig`] to avoid underutilization or out-of-memory errors. +> +> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation: +> +> +> +> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability. +> +> If you still find you are getting out-of-memory errors set `vllm_enable_sleep_mode` to True and the vllm parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode). + +> [!TIP] +> By default, RLOO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly. + +For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods). + +### RLOO at scale: train a 70B+ Model on multiple nodes + +When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include: + +- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration). +- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training). +- **vLLM**: See the previous section on how to use vLLM to speed up generation. + +Below is an example SLURM script to train a 70B model with RLOO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation. + +```sh +#!/bin/bash +#SBATCH --nodes=5 +#SBATCH --gres=gpu:8 + +# Get the list of allocated nodes +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) + +# Assign the first 4 nodes for training and the 5th node for vLLM +TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training +VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM + +# Run training on the first 4 nodes (Group 1) +srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + --num_processes 32 \ + --num_machines 4 \ + --main_process_ip ${NODELIST[0]} \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + train_rloo.py \ + --server_ip $VLLM_NODE & + +# Run vLLM server on the 5th node (Group 2) +srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 & + +wait +``` + +```python +import argparse + +from datasets import load_dataset +from trl import RLOOTrainer, RLOOConfig + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP") + args = parser.parse_args() + + # Example dataset from TLDR + dataset = load_dataset("trl-lib/tldr", split="train") + + # Dummy reward function: count the number of unique characters in the completions + def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + + training_args = RLOOConfig( + output_dir="Qwen2.5-72B-RLOO", + per_device_train_batch_size=4, + bf16=True, + use_vllm=True, + vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X + ) + + trainer = RLOOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset) + trainer.train() + +if __name__=="__main__": + main() +``` + +### Using a custom reward function + +The [`RLOOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements: + +Reward functions can be either synchronous Python callables or asynchronous `async def` coroutines. When you provide multiple asynchronous reward functions, they are awaited concurrently (run in parallel via `asyncio.gather`) so their latency overlaps. + +1. **Input arguments**: + - The function must accept the following as keyword arguments: + - `prompts` (contains the prompts), + - `completions` (contains the generated completions), + - `completion_ids` (contains the tokenized completions), + - `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress. + - All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument. + + The easiest way to comply with this requirement is to use `**kwargs` in the function signature. + - Depending on the dataset format, the input will vary: + - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings. + - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries. + +2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion. + +#### Example 1: Reward longer completions + +Below is an example of a reward function for a standard format that rewards longer completions: + +```python +def reward_func(completion_ids, **kwargs): + """Reward function that assigns higher scores to longer completions (in terms of token count).""" + return [float(len(ids)) for ids in completion_ids] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it +>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it +>>> completion_ids = [[6303, 13], [304, 279, 12884, 13]] +>>> reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids) +[2.0, 4.0] +``` + +#### Example 1.1: Reward longer completions (based on the number of characters) + +Same as the previous example, but this time the reward function is based on the number of characters instead of tokens. + +```python +def reward_func(completions, **kwargs): + """Reward function that assigns higher scores to longer completions (in terms of character count).""" + return [float(len(completion)) for completion in completions] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] +>>> completions = [" blue.", " in the sky."] +>>> completion_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it +>>> reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids) +[6.0, 12.0] +``` + +#### Example 2: Reward completions with a specific format + +Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +It is designed for a conversational format, where prompts and completions consist of structured messages. + +```python +import re + +def format_reward_func(completions, **kwargs): + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?.*?$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] +``` + +You can test this function as follows: + +```python +>>> prompts = [ +... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}], +... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}], +... ] +>>> completions = [ +... [{"role": "assistant", "content": "The sum of 1 and 2 is 3, which we multiply by 4 to get 12.(1 + 2) * 4 = 12"}], +... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}], +... ] +>>> format_reward_func(prompts=prompts, completions=completions) +[1.0, 0.0] +``` + +#### Example 3: Reward completions based on a reference + +Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`. + +```python +import re + +def reward_func(completions, ground_truth, **kwargs): + # Regular expression to capture content inside \boxed{} + matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions] + contents = [match.group(1) if match else "" for match in matches] + # Reward 1 if the content is the same as the ground truth, 0 otherwise + return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)] +``` + +You can test this function as follows: + +```python +>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."] +>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."] +>>> ground_truth = ["2", "5"] +>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth) +[1.0, 0.0] +``` + +#### Example 4: Multi-task reward functions + +Below is an example of using multiple reward functions in the [`RLOOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works. + +```python +from datasets import Dataset +from trl import RLOOTrainer + +# Define a dataset that contains both math and coding problems +dataset = Dataset.from_list( + [ + {"prompt": "What is 2+2?", "task": "math"}, + {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"}, + {"prompt": "What is 3*4?", "task": "math"}, + {"prompt": "Write a function that returns the product of two numbers.", "task": "code"}, + ] +) + +# Math-specific reward function +def math_reward_func(prompts, completions, task, **kwargs): + rewards = [] + for prompt, completion, t in zip(prompts, completions, task): + if t == "math": + # Calculate math-specific reward + correct = check_math_solution(prompt, completion) + reward = 1.0 if correct else -1.0 + rewards.append(reward) + else: + # Return None for non-math tasks + rewards.append(None) + return rewards + +# Coding-specific reward function +def coding_reward_func(prompts, completions, task, **kwargs): + rewards = [] + for prompt, completion, t in zip(prompts, completions, task): + if t == "coding": + # Calculate coding-specific reward + works = test_code_solution(prompt, completion) + reward = 1.0 if works else -1.0 + rewards.append(reward) + else: + # Return None for non-coding tasks + rewards.append(None) + return rewards + +# Use both task-specific reward functions +trainer = RLOOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=[math_reward_func, coding_reward_func], + train_dataset=dataset, +) + +trainer.train() +``` + +In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`RLOOTrainer`] will continue with the valid functions and tasks. This allows the [`RLOOTrainer`] to handle multiple reward functions with different applicability. + +Note that the [`RLOOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function. + +#### Example 5: Asynchronous reward functions + +Custom reward functions can also be defined as `async def` coroutines. This is useful if your reward depends on slow I/O (for example, calling a remote service). When you pass multiple async reward functions, [`RLOOTrainer`] executes them concurrently so their latency overlaps. + +Below is a minimal example of an async reward function that simulates an I/O-bound operation: + +```python +import asyncio + +async def async_reward_func(prompts, completions, **kwargs): + # Simulate an I/O-bound call (e.g., HTTP request, database lookup) + await asyncio.sleep(0.01) + # Simple toy reward: 1.0 if the completion is non-empty, else 0.0 + return [1.0 if completion else 0.0 for completion in completions] +``` + +#### Passing the reward function to the trainer + +To use your custom reward function, pass it to the [`RLOOTrainer`] as follows: + +```python +from trl import RLOOTrainer + +trainer = RLOOTrainer( + reward_funcs=reward_func, + ..., +) +``` + +You can pass several reward functions as a list; this list may include both synchronous and asynchronous functions: + +```python +from trl import RLOOTrainer + +trainer = RLOOTrainer( + reward_funcs=[reward_func, async_reward_func1, async_reward_func2], + ..., +) +``` + +and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config. + +Note that [`RLOOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. + +## Vision-Language Model (VLM) Training + +RLOO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images. + +### Supported Models + +Tested with: + +- **Gemma3** — e.g., `google/gemma-3-4b-it` +- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf` +- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct` +- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct` +- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + +> [!TIP] +> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + +### Quick Start + +Use [rloo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified): + +```bash +accelerate launch \ + --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/rloo_vlm.py \ + --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \ + --output_dir rloo-Qwen2.5-VL-3B-Instruct \ + --learning_rate 1e-5 \ + --dtype bfloat16 \ + --max_completion_length 1024 \ + --use_vllm \ + --vllm_mode colocate \ + --use_peft \ + --lora_target_modules "q_proj", "v_proj" \ + --log_completions +``` + +### Configuration Tips + +- Use LoRA on vision-language projection layers +- Enable 4-bit quantization to reduce memory usage +- VLMs are memory-intensive — start with smaller batch sizes +- Most models are compatible with vLLM (`server` and `colocate` modes) + +### Dataset Format + +Each training sample should include: + +- `prompt`: Text formatted via the processor's chat template +- `image`/`images`: PIL Image or list of PIL Images + +The trainer automatically handles image-to-tensor conversion via the model’s image processor. + +## RLOOTrainer + +[[autodoc]] RLOOTrainer + - train + - save_model + - push_to_hub + +## RLOOConfig + +[[autodoc]] RLOOConfig + +## References + +1. [RLOO Paper](https://openreview.net/pdf?id=r1lgTGL5DE) +2. [Paper Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740) +3. [Paper - REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models](https://huggingface.co/papers/2501.03262) +4. [Blog Post - Putting RL back in RLHF](https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo) +5. [Blog Post - Unraveling RLHF and Its Variants: Progress and Practical Engineering Insights](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05) +6. [Youtube - RLOO: A Cost-Efficient Optimization for Learning from Human Feedback in LLMs](https://www.youtube.com/watch?v=86asXGPK6RU&ab_channel=BuzzRobot) + +## Migration Guide from the old implementation (0.21 and below) + +With the release of version 0.22.0, we have revamped the [`RLOOTrainer`] to be more aligned with other online trainers in the library, like [`GRPOTrainer`]. This new implementation introduces several changes to the configuration parameters and overall structure of the trainer. +Below is a summary of the key changes for [`RLOOConfig`]: + +| TRL ≤ 0.21.x | TRL ≥ 0.22.0 | +| --- | --- | +| `rloo_k` | renamed to `num_generations` | +| `cliprange` | renamed to `epsilon` | +| `kl_coef` | renamed to `beta` | +| `exp_name` | renamed to `run_name`. Use `run_name = f"{exp_name}__{seed}__{int(time.time())}"` to replicate old behavior | +| `normalize_reward` | renamed to `normalize_advantages`. Note: this always normalized advantages (despite the old name) | +| `num_ppo_epochs` | renamed to `num_iterations` (default: `1`) | +| `token_level_kl` | **removed** – KL is now computed only at the sequence level | +| `dataset_num_proc` | **removed** – it was unused | +| `num_mini_batches` | renamed to `steps_per_generation` | +| `total_episodes` | use `max_steps=total_episodes / gradient_accumulation_steps` instead | +| `local_rollout_forward_batch_size` | **removed** – now automatically set to `per_device_train_batch_size` (or `per_device_eval_batch_size` during evaluation) | +| `num_sample_generations` | **removed** – use `logging_steps` to control generation logging frequency | +| `response_length` | renamed to `max_completion_length` (default: `256`) | +| `stop_token` | **removed** | +| `stop_token_id` | **removed** – use `processing_class.eos_token_id` instead | +| `missing_eos_penalty` | **removed** – replicate with a custom reward function checking if `eos_token_id` is in `completion_ids` | + +Below is a summary of the key changes for [`RLOOTrainer`]: + +| TRL ≤ 0.21.x | TRL ≥ 0.22.0 | +| --- | --- | +| `config` | renamed to `args` | +| `reward_model` | renamed to `reward_funcs`, which now supports both reward models and custom reward functions | +| `policy` | renamed to `model` | +| `ref_policy` | **removed** – the reference model is now created automatically from `model` | +| `data_collator` | **removed** | diff --git a/ICL/RL/trl_source/docs/source/script_utils.md b/ICL/RL/trl_source/docs/source/script_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..1ecb73756530d0fcdd87f668dd8385ebfbce1536 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/script_utils.md @@ -0,0 +1,24 @@ +# Scripts Utilities + +## ScriptArguments + +[[autodoc]] ScriptArguments + +## TrlParser + +[[autodoc]] TrlParser + - parse_args_and_config + - parse_args_into_dataclasses + - set_defaults_with_config + +## get_dataset + +[[autodoc]] get_dataset + +## DatasetConfig + +[[autodoc]] scripts.utils.DatasetConfig + +## DatasetMixtureConfig + +[[autodoc]] DatasetMixtureConfig diff --git a/ICL/RL/trl_source/docs/source/sft_trainer.md b/ICL/RL/trl_source/docs/source/sft_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..d98f7828b9fd199087db006d7935a4c64a1e588a --- /dev/null +++ b/ICL/RL/trl_source/docs/source/sft_trainer.md @@ -0,0 +1,339 @@ +# SFT Trainer + +[![All_models-SFT-blue](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl) [![smol_course-Chapter_1-yellow](https://img.shields.io/badge/smol_course-Chapter_1-yellow)](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning) + +## Overview + +TRL supports the Supervised Fine-Tuning (SFT) Trainer for training language models. + +This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada). + +## Quick start + +This example demonstrates how to train a language model using the [`SFTTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara), a compact, diverse multi-turn dataset to benchmark reasoning and generalization. + +```python +from trl import SFTTrainer +from datasets import load_dataset + +trainer = SFTTrainer( + model="Qwen/Qwen3-0.6B", + train_dataset=load_dataset("trl-lib/Capybara", split="train"), +) +trainer.train() +``` + + + +## Expected dataset type and format + +SFT supports both [language modeling](dataset_formats#language-modeling) and [prompt-completion](dataset_formats#prompt-completion) datasets. The [`SFTTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +```python +# Standard language modeling +{"text": "The sky is blue."} + +# Conversational language modeling +{"messages": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}]} + +# Standard prompt-completion +{"prompt": "The sky is", + "completion": " blue."} + +# Conversational prompt-completion +{"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}]} +``` + +If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset: + +```python +from datasets import load_dataset + +dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en") + +def preprocess_function(example): + return { + "prompt": [{"role": "user", "content": example["Question"]}], + "completion": [ + {"role": "assistant", "content": f"{example['Complex_CoT']}{example['Response']}"} + ], + } + +dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"]) +print(next(iter(dataset["train"]))) +``` + +```json +{ + "prompt": [ + { + "content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?", + "role": "user", + } + ], + "completion": [ + { + "content": "Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.", + "role": "assistant", + } + ], +} +``` + +## Looking deeper into the SFT method + +Supervised Fine-Tuning (SFT) is the simplest and most commonly used method to adapt a language model to a target dataset. The model is trained in a fully supervised fashion using pairs of input and output sequences. The goal is to minimize the negative log-likelihood (NLL) of the target sequence, conditioning on the input. + +This section breaks down how SFT works in practice, covering the key steps: **preprocessing**, **tokenization** and **loss computation**. + +### Preprocessing and tokenization + +During training, each example is expected to contain a **text field** or a **(prompt, completion)** pair, depending on the dataset format. For more details on the expected formats, see [Dataset formats](dataset_formats). +The [`SFTTrainer`] tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization. + +### Computing the loss + +![sft_figure](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft_figure.png) + +The loss used in SFT is the **token-level cross-entropy loss**, defined as: + +$$ +\mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{ [!TIP] +> The paper [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification](https://huggingface.co/papers/2508.05629) proposes an alternative loss function, called **Dynamic Fine-Tuning (DFT)**, which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [`SFTConfig`]. For more details, see [Paper Index - Dynamic Fine-Tuning](paper_index#on-the-generalization-of-sft-a-reinforcement-learning-perspective-with-reward-rectification). + +### Label shifting and masking + +During training, the loss is computed using a **one-token shift**: the model is trained to predict each token in the sequence based on all previous tokens. Specifically, the input sequence is shifted right by one position to form the target labels. +Padding tokens (if present) are ignored in the loss computation by applying an ignore index (default: `-100`) to the corresponding positions. This ensures that the loss focuses only on meaningful, non-padding tokens. + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +* `global_step`: The total number of optimizer steps taken so far. +* `epoch`: The current epoch number, based on dataset iteration. +* `num_tokens`: The total number of tokens processed so far. +* `loss`: The average cross-entropy loss computed over non-masked tokens in the current logging interval. +* `entropy`: The average entropy of the model's predicted token distribution over non-masked tokens. +* `mean_token_accuracy`: The proportion of non-masked tokens for which the model’s top-1 prediction matches the ground truth token. +* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used. +* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping. + +## Customization + +### Model initialization + +You can directly pass the kwargs of the [`~transformers.AutoModelForCausalLM.from_pretrained()`] method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to + +```python +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16) +``` + +you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`SFTConfig`]. + +```python +from trl import SFTConfig + +training_args = SFTConfig( + model_init_kwargs={"dtype": torch.bfloat16}, +) +``` + +Note that all keyword arguments of [`~transformers.AutoModelForCausalLM.from_pretrained()`] are supported. + +### Packing + +[`SFTTrainer`] supports _example packing_, where multiple examples are packed in the same input sequence to increase training efficiency. To enable packing, simply pass `packing=True` to the [`SFTConfig`] constructor. + +```python +training_args = SFTConfig(packing=True) +``` + +For more details on packing, see [Packing](reducing_memory_usage#packing). + +### Train on assistant messages only + +To train on assistant messages only, use a [conversational](dataset_formats#conversational) dataset and set `assistant_only_loss=True` in the [`SFTConfig`]. This setting ensures that loss is computed **only** on the assistant responses, ignoring user or system messages. + +```python +training_args = SFTConfig(assistant_only_loss=True) +``` + +![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png) + +> [!WARNING] +> This functionality is only available for chat templates that support returning the assistant tokens mask via the `{% generation %}` and `{% endgeneration %}` keywords. For an example of such a template, see [HugggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82). + +### Train on completion only + +To train on completion only, use a [prompt-completion](dataset_formats#prompt-completion) dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set `completion_only_loss=False` in the [`SFTConfig`]. + +![train_on_completion](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_completion.png) + +> [!TIP] +> Training on completion only is compatible with training on assistant messages only. In this case, use a [conversational](dataset_formats#conversational) [prompt-completion](dataset_formats#prompt-completion) dataset and set `assistant_only_loss=True` in the [`SFTConfig`]. + +### Train adapters with PEFT + +We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model. + +```python +from datasets import load_dataset +from trl import SFTTrainer +from peft import LoraConfig + +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = SFTTrainer( + "Qwen/Qwen3-0.6B", + train_dataset=dataset, + peft_config=LoraConfig() +) + +trainer.train() +``` + +You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`SFTTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed. + +```python +from datasets import load_dataset +from trl import SFTTrainer +from peft import AutoPeftModelForCausalLM + +model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True) +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = SFTTrainer( + model=model, + train_dataset=dataset, +) + +trainer.train() +``` + +> [!TIP] +> When training adapters, you typically use a higher learning rate (≈1e‑4) since only new parameters are being learned. +> +> ```python +> SFTConfig(learning_rate=1e-4, ...) +> ``` + +### Train with Liger Kernel + +Liger Kernel is a collection of Triton kernels for LLM training that boosts multi-GPU throughput by 20%, cuts memory use by 60% (enabling up to 4× longer context), and works seamlessly with tools like FlashAttention, PyTorch FSDP, and DeepSpeed. For more information, see [Liger Kernel Integration](liger_kernel_integration). + +### Rapid Experimentation for SFT + +RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple SFT configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you **see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight** without restarting. For more information, see [RapidFire AI Integration](rapidfire_integration). + +### Train with Unsloth + +Unsloth is an open‑source framework for fine‑tuning and reinforcement learning that trains LLMs (like Llama, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 70% less VRAM, while providing a streamlined, Hugging Face–compatible workflow for training, evaluation, and deployment. For more information, see [Unsloth Integration](unsloth_integration). + +## Instruction tuning example + +**Instruction tuning** teaches a base language model to follow user instructions and engage in conversations. This requires: + +1. **Chat template**: Defines how to structure conversations into text sequences, including role markers (user/assistant), special tokens, and turn boundaries. Read more about chat templates in [Chat templates](https://huggingface.co/docs/transformers/chat_templating#templates). +2. **Conversational dataset**: Contains instruction-response pairs + +This example shows how to transform the [Qwen 3 0.6B Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) model into an instruction-following model using the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara) and a chat template from [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B). The SFT Trainer automatically handles tokenizer updates and special token configuration. + +```python +from trl import SFTConfig, SFTTrainer +from datasets import load_dataset + +trainer = SFTTrainer( + model="Qwen/Qwen3-0.6B-Base", + args=SFTConfig( + output_dir="Qwen3-0.6B-Instruct", + chat_template_path="HuggingFaceTB/SmolLM3-3B", + ), + train_dataset=load_dataset("trl-lib/Capybara", split="train"), +) +trainer.train() +``` + +> [!WARNING] +> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply [`clone_chat_template()`], as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in [`SFTConfig`]; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`. + +Once trained, your model can now follow instructions and engage in conversations using its new chat template. + +```python +>>> from transformers import pipeline +>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000") +>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n" +>>> response = pipe(prompt) +>>> response[0]["generated_text"] +'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.' +``` + +Alternatively, use the structured conversation format (recommended): + +```python +>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}] +>>> response = pipe(prompt) +>>> response[0]["generated_text"] +[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}] +``` + +## Tool Calling with SFT + +The [`SFTTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include: + +* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages) +* The list of available tools in the `tools` column, typically provided as JSON schemas + +For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section. + +## Training Vision Language Models + +[`SFTTrainer`] fully supports training Vision-Language Models (VLMs). To train a VLM, you need to provide a dataset with an additional `images` column containing the images to be processed. For more information on the expected dataset structure, see the [Dataset Format — Vision Dataset](dataset_formats#vision-dataset) section. +An example of such a dataset is the [LLaVA Instruct Mix](https://huggingface.co/datasets/trl-lib/llava-instruct-mix). + +```python +from trl import SFTConfig, SFTTrainer +from datasets import load_dataset + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-VL-3B-Instruct", + args=SFTConfig(max_length=None), + train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"), +) +trainer.train() +``` + +> [!TIP] +> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens. +> +> ```python +> SFTConfig(max_length=None, ...) +> ``` +> +> Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset. + +## SFTTrainer + +[[autodoc]] SFTTrainer + - train + - save_model + - push_to_hub + +## SFTConfig + +[[autodoc]] SFTConfig + +## DataCollatorForLanguageModeling + +[[autodoc]] trainer.sft_trainer.DataCollatorForLanguageModeling + +## DataCollatorForVisionLanguageModeling + +[[autodoc]] trainer.sft_trainer.DataCollatorForVisionLanguageModeling diff --git a/ICL/RL/trl_source/docs/source/speeding_up_training.md b/ICL/RL/trl_source/docs/source/speeding_up_training.md new file mode 100644 index 0000000000000000000000000000000000000000..ec34a27454195b273d9db2cb28bd96f8a73594c0 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/speeding_up_training.md @@ -0,0 +1,204 @@ +# Speeding Up Training + +This guide covers various methods to accelerate training in TRL. Each technique includes minimal examples with links to more comprehensive documentation. + +## vLLM for fast generation in online methods + +[Online methods](index#online-methods) such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time. +To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. For more details, see [vLLM Integration](vllm_integration). + +To use [vLLM](https://github.com/vllm-project/vllm), first install it using: + +```bash +pip install trl[vllm] +``` + + + + +First, start a vLLM server by running: + +```bash +trl vllm-serve --model +``` + +Then, run the training script and pass `use_vllm=True` in the training arguments. + +```python +from trl.experimental.online_dpo import OnlineDPOConfig + +training_args = OnlineDPOConfig(..., use_vllm=True) +``` + + + + +First, start a vLLM server by running: + +```bash +trl vllm-serve --model +``` + +Then, run the training script and pass `use_vllm=True` in the training arguments. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_vllm=True) +``` + +You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration). + +> [!WARNING] +> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`. +> +> Set GPUs **0-3** for vLLM generation: +> +> ```sh +> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model +> ``` +> +> And GPUs **4-7** for training: +> +> ```sh +> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +> ``` + + + + +First, start a vLLM server by running: + +```bash +trl vllm-serve --model +``` + +Then, run the training script and pass `use_vllm=True` in the training arguments. + +```python +from trl import RLOOConfig + +training_args = RLOOConfig(..., use_vllm=True) +``` + +You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration). + +> [!WARNING] +> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`. +> +> Set GPUs **0-3** for vLLM generation: +> +> ```sh +> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model +> ``` +> +> And GPUs **4-7** for training: +> +> ```sh +> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +> ``` + + + + +## Optimized attention implementations + +TRL supports various optimized attention implementations that can significantly speed up training while reducing memory usage. You can use either a pre-optimized kernels directly from the [Kernels Hub](kernels_hub) or a manually built attention backend. + + + + +You can use pre-optimized attention kernels from the Hub without manual compilation: + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"}) +``` + +Other options include `kernels-community/vllm-flash-attn3` and `kernels-community/paged-attention`. + +Optimized attention works across all TRL trainers. For more details, see [Kernels Hub Integration](kernels_hub). + + + + +> [!WARNING] +> Manually building optimized attention backends is complex and time-consuming. It's never recommended unless absolutely necessary. Consider using Kernels from the Hub instead, as described in the previous section. + +If you have manually installed an optimized attention backend like Flash Attention 2, you can specify it in the training arguments: + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "flash_attention_2"}) +``` + + + + +## Liger Kernel for memory optimization + +Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%. + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl.experimental.kto import KTOConfig + +training_args = KTOConfig(..., use_liger_kernel=True) +``` + + + + +```python +from trl.experimental.gkd import GKDConfig + +training_args = GKDConfig(..., use_liger_kernel=True) +``` + + + + +For more information, see [Liger Kernel Integration](liger_kernel_integration). + +## Mixed precision training + +Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality. + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., bf16=True) # or fp16=True for older GPUs +``` + +Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. Mixed precision training is supported across all TRL trainers. diff --git a/ICL/RL/trl_source/docs/source/trackio_integration.md b/ICL/RL/trl_source/docs/source/trackio_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..4e93120fe19a8dfe3443e791bff8d60350908b8b --- /dev/null +++ b/ICL/RL/trl_source/docs/source/trackio_integration.md @@ -0,0 +1,67 @@ +# Trackio Integration + +[Trackio](https://huggingface.co/docs/trackio) is a lightweight, free experiment tracking library built on top of **🤗 Datasets** and **🤗 Spaces**. It is the **recommended tracking solution for TRL** and comes natively integrated with all trainers. + +To enable logging, simply set `report_to="trackio"` in your training config: + +```python +from trl import SFTConfig # works with any trainer config (e.g. DPOConfig, GRPOConfig, etc.) + +training_args = SFTConfig( + ..., + report_to="trackio", # enable Trackio logging +) +``` + +## Organizing Your Experiments with Run Names and Projects + +By default, Trackio will generate a name to identify each run. However, we highly recommend setting a descriptive `run_name` to make it easier to organize experiments. For example: + +```python +from trl import SFTConfig + +training_args = SFTConfig( + ..., + report_to="trackio", + run_name="sft_qwen3-4b_lr2e-5_bs128", # descriptive run name +) +``` + +You can also group related experiments by project by setting the following environment variable: + +```bash +export TRACKIO_PROJECT="my_project" +``` + +## Hosting Your Logs on 🤗 Spaces + +Trackio has local-first design, meaning your logs stay on your machine. If you’d like to host them and deploy a dashboard on **🤗 Spaces**, set: + +```bash +export TRACKIO_SPACE_ID="username/space_id" +``` + +Running the following example: + +```python +import os +from trl import SFTConfig, SFTTrainer +from datasets import load_dataset + +os.environ["TRACKIO_SPACE_ID"] = "trl-lib/trackio" +os.environ["TRACKIO_PROJECT"] = "trl-documentation" + +trainer = SFTTrainer( + model="Qwen/Qwen3-0.6B", + train_dataset=load_dataset("trl-lib/Capybara", split="train"), + args=SFTConfig( + report_to="trackio", + run_name="sft_qwen3-0.6b_capybara", + ), +) +trainer.train() +``` + +will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio. + + diff --git a/ICL/RL/trl_source/docs/source/unsloth_integration.md b/ICL/RL/trl_source/docs/source/unsloth_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..0350bc21612580f6d333860872289e8f1b4b201c --- /dev/null +++ b/ICL/RL/trl_source/docs/source/unsloth_integration.md @@ -0,0 +1,125 @@ +# Unsloth Integration + +Unsloth is an open‑source framework for fine‑tuning and reinforcement learning that trains LLMs (like Llama, OpenAI gpt-oss, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 80% less VRAM. Unsloth allows [training](https://huggingface.co/docs/trl/en/unsloth_integration#Training), evaluation, running and [deployment](https://huggingface.co/docs/trl/en/unsloth_integration#Saving-the-model) with other inference engines like llama.cpp, Ollama and vLLM. + +The library provides a streamlined, Hugging Face compatible workflow for training, evaluation, inference and deployment and is fully compatible with [`SFTTrainer`]. + +## Key Features + +- Training support for all transformer compatible models: Text-to-speech (TTS), multimodal, BERT, RL and more +- Supports full fine-tuning, pretraining, LoRA, QLoRA, 8-bit training & more +- Works on Linux, Windows, Colab, Kaggle; NVIDIA GPUs, soon AMD & Intel setups +- Supports most features TRL supports, including RLHF (GSPO, GRPO, DPO etc.) +- Hand-written Triton kernels and a manual backprop engine ensure no accuracy degradation (0% approximation error) + +## Installation + +### pip install + +Local Installation (Linux recommended): + +```sh +pip install unsloth +``` + +You can also install `unsloth` according to the [official documentation](https://docs.unsloth.ai/get-started/installing-+-updating). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading [`~transformers.AutoModelForCausalLM`], you just need to load a `FastLanguageModel` as follows: + +```python +import torch +from trl import SFTConfig, SFTTrainer +from unsloth import FastLanguageModel + +max_length = 2048 # Supports automatic RoPE Scaling, so choose any number + +# Load model +model, tokenizer = FastLanguageModel.from_pretrained( + model_name="unsloth/mistral-7b", + max_seq_length=max_length, + dtype="auto", # For auto-detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ + load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False +) + +# Do model patching and add fast LoRA weights +model = FastLanguageModel.get_peft_model( + model, + r=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + lora_alpha=16, + lora_dropout=0, # Dropout = 0 is currently optimized + bias="none", # Bias = "none" is currently optimized + use_gradient_checkpointing=True, + random_state=3407, +) + +training_args = SFTConfig(output_dir="./output", max_length=max_length) + +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). + +### Docker Install + +```sh +docker run -d -e JUPYTER_PASSWORD="mypassword" \ + -p 8888:8888 -p 2222:22 \ + -v $(pwd)/work:/workspace/work \ + --gpus all \ + unsloth/unsloth +``` + +Access Jupyter Lab at ```http://localhost:8888``` and start fine-tuning! + +## Training + +These are some core settings you can toggle before training: + +- ```max_seq_length = 2048``` – Controls context length. While Llama-3 supports 8192, we recommend 2048 for testing. Unsloth enables 4× longer context fine-tuning. +- ```dtype = "auto"``` – For auto-detection; use torch.float16 or torch.bfloat16 for newer GPUs. +- ```load_in_4bit = True``` – Enables 4-bit quantization, reducing memory use 4× for fine-tuning. Disabling it allows for LoRA 16-bit fine-tuning to be enabled. +- To enable full fine-tuning (FFT), set ```full_finetuning = True```. For 8-bit fine-tuning, set ```load_in_8bit = True```. Note: Only one training method can be set to True at a time. + +For more information on configuring Unsloth's hyperparameters and features, read their [documentation guide here](https://docs.unsloth.ai/get-started/fine-tuning-llms-guide). + +## Saving the model + +Unsloth allows you to directly save the finetuned model as a small file called a LoRA adapter. You can instead push to the Hugging Face hub as well if you want to upload your model! Remember to get a [Hugging Face token](https://huggingface.co/settings/tokens) and add your token! + +### Saving to GGUF + +To save to GGUF, Unsloth uses llama.cpp. To save locally: + +```python +model.save_pretrained_gguf("directory", tokenizer, quantization_method = "q4_k_m") +model.save_pretrained_gguf("directory", tokenizer, quantization_method = "q8_0") +model.save_pretrained_gguf("directory", tokenizer, quantization_method = "f16") +``` + +To push to the hub: + +```python +model.push_to_hub_gguf("hf_username/directory", tokenizer, quantization_method = "q4_k_m") +model.push_to_hub_gguf("hf_username/directory", tokenizer, quantization_method = "q8_0") +``` + +### Saving to vLLM + +To save to 16-bit for vLLM, use: + +```python +model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",) +model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "") +``` diff --git a/ICL/RL/trl_source/docs/source/use_model.md b/ICL/RL/trl_source/docs/source/use_model.md new file mode 100644 index 0000000000000000000000000000000000000000..058f20d18e3c211d0693bc9bd4f22f84898aa496 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/use_model.md @@ -0,0 +1,58 @@ +# Use model after training + +Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference). + +## Load and Generate + +If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored: + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +device = "cpu" # or "cuda" if you have a GPU + +model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + +inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device) +outputs = model.generate(inputs) +print(tokenizer.decode(outputs[0])) +``` + +Alternatively you can also use the pipeline: + +```python +from transformers import pipeline + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +pipe = pipeline("text-generation", model=model_name_or_path) +print(pipe("This movie was really")[0]["generated_text"]) +``` + +## Use Adapters PEFT + +```python +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +adapter_model_name = "path/to/my/adapter" + +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +tokenizer = AutoTokenizer.from_pretrained(base_model_name) +``` + +You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger: + +```python +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +model = model.merge_and_unload() +model.save_pretrained("merged_adapters") +``` + +Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above. diff --git a/ICL/RL/trl_source/docs/source/vllm_integration.md b/ICL/RL/trl_source/docs/source/vllm_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..16ff5b722f42826dd472622256eadc7a8151279d --- /dev/null +++ b/ICL/RL/trl_source/docs/source/vllm_integration.md @@ -0,0 +1,449 @@ +# vLLM Integration + +This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. + +> [!WARNING] +> TRL currently only supports vLLM versions `0.10.2`, `0.11.0`, `0.11.1`, `0.11.2` and `0.12.0`. Please ensure you have one of these versions installed to avoid compatibility issues. + +> [!TIP] +> The following trainers currently support generation with vLLM: +> +> - [`GRPOTrainer`] +> - [`RLOOTrainer`] +> - [`experimental.nash_md.NashMDTrainer`] +> - [`experimental.online_dpo.OnlineDPOTrainer`] +> - [`experimental.xpo.XPOTrainer`] + +## 🚀 How can I use vLLM with TRL to speed up training? + +💡 **Note**: Resources required for this specific example: a single node with 8 GPUs. + +> [!WARNING] +> When using vLLM with TRL, the **vLLM server** and the **trainer** must run on **separate CUDA devices** to prevent conflicts. +> For guidance on configuring this properly, see [Modes of using vLLM during training](#modes-of-using-vllm-during-training). + +First, install vLLM using the following command: + +```bash +pip install "trl[vllm]" +``` + +Then run the server on specific GPUs (e.g., GPUs 0-3): + +```sh +CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2 +``` + +Once the server is running, you can use it to generate completions for training. In the example below, we are using the different supported trainers using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs. + +In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows: + +Sample of a simple `train.py` script: + + + + +```python +from datasets import load_dataset +from trl import GRPOTrainer, GRPOConfig +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-7B", + args=GRPOConfig(use_vllm=True), + reward_funcs=accuracy_reward, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = OnlineDPOTrainer( + model="Qwen/Qwen2.5-7B", + args=OnlineDPOConfig(use_vllm=True), + reward_funcs=accuracy_reward, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = NashMDTrainer( + model="Qwen/Qwen2.5-7B", + args=NashMDConfig(use_vllm=True), + reward_funcs=accuracy_reward, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl.experimental.xpo import XPOTrainer, XPOConfig +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = XPOTrainer( + model="Qwen/Qwen2.5-7B", + args=XPOConfig(use_vllm=True), + reward_funcs=accuracy_reward, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl import RLOOTrainer, RLOOConfig +from trl.rewards import accuracy_reward + +dataset = load_dataset("trl-lib/DeepMath-103K", split="train") + +trainer = RLOOTrainer( + model="Qwen/Qwen2.5-7B", + args=RLOOConfig(use_vllm=True), + reward_funcs=accuracy_reward, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +And the train command on separate GPUs from the server: + +```sh +CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +``` + +## Why using vLLM? + +### 🎬 Flashback: Why do we need to use vLLM in online methods? + +Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods. + +### 🤔 How does vLLM solve the slow generation issue? + +If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details. + +## How vLLM Works (Under the Hood) 🔍 + +### 🤔 What exactly happens when you run `trl vllm-serve --model `? + +When you run for example + +```sh +CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4 +``` + +the following happens: + +![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png) + +1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4). +Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load. + +2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas. + +3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts). +This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself. +Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings. + +### 🥸 More detail on what happens under the hood when running the server + +- The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`. +- Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [these lines](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035). +- The client (trainer) then requests these completions from the server. +- These completions are used to compute the reward signal. +- Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights. +- **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`. + +When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid NCCL communication conflicts. If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to device conflicts. Starting from TRL next release after v0.19.1, the code automatically detects and prevents same-device usage, raising a error at the vllm server process: + +```log +RuntimeError: Attempting to use the same CUDA device for multiple distinct roles/ranks within the same communicator. +Ensure that trainer is using different devices than vLLM server. +``` + +For example, if you want to use GPUs 4–7 for training while the server runs on GPUs 0-3, set: + +```sh +CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +``` + +## Advanced usage + +### 🍷 More customization options with vLLM? + +You can customize the server configuration by passing additional arguments. + +```txt +$ trl vllm-serve --help +usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] + [--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN] + [--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager [ENFORCE_EAGER]] [--kv_cache_dtype KV_CACHE_DTYPE] + [--trust_remote_code [TRUST_REMOTE_CODE]] [--log_level LOG_LEVEL] [--vllm_model_impl VLLM_MODEL_IMPL] + +options: + -h, --help show this help message and exit + --model MODEL Model name or path to load the model from. (default: None) + --revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None) + --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE + Number of tensor parallel workers to use. (default: 1) + --data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE + Number of data parallel workers to use. (default: 1) + --host HOST Host address to run the server on. (default: 0.0.0.0) + --port PORT Port to run the server on. (default: 8000) + --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device dedicated to generation + powered by vLLM. Higher values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, + it may cause out-of-memory (OOM) errors during initialization. (default: 0.9) + --dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on the model configuration. + Find the supported values in the vLLM documentation. (default: auto) + --max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN + If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced `vllm_gpu_memory_utilization`, leading to a + reduced KV cache size. If not set, vLLM will use the model context size, which might be much larger than the KV cache, leading to + inefficiencies. (default: None) + --enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING + Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this feature. (default: None) + --enforce_eager [ENFORCE_EAGER], --enforce-eager [ENFORCE_EAGER] + Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model in eager mode. If `False` + (default behavior), we will use CUDA graph and eager execution in hybrid. (default: False) + --kv_cache_dtype KV_CACHE_DTYPE, --kv-cache-dtype KV_CACHE_DTYPE + Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type. (default: auto) + --trust_remote_code [TRUST_REMOTE_CODE], --trust-remote-code [TRUST_REMOTE_CODE] + Whether to trust remote code when loading models. Set to True to allow executing code from model repositories. This is required for some + custom models but introduces security risks. (default: False) + --log_level LOG_LEVEL, --log-level LOG_LEVEL + Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default: info) + --vllm_model_impl VLLM_MODEL_IMPL, --vllm-model-impl VLLM_MODEL_IMPL + Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: Use the `transformers` backend for model + implementation. `vllm`: Use the `vllm` library for model implementation. (default: vllm) +``` + +### 💆🏻‍♀️ What's the best distributed setup? + +![tp dp throughput 8 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png) +![tp dp throughput 4 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png) + +First and foremost, always remember that the optimal setup depends on: + +- The model size +- The number of GPUs you have +- The GPU memory size +- The batch size you are using +- The number of requests you are sending to the server (prompts) +- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size) +- The number of completions you are generating for each request (`num_generations`) + +Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that: + +- For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results. +- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window. + +### vLLM with Transformers Backend + +vLLM can use the **Transformers backend** for model implementations, which works for both LLMs and VLMs. +To enable this, set `vllm_model_impl="transformers"` in your configuration or pass it via the command-line argument. + +For more details, check out [vLLM Transformers Backend](https://blog.vllm.ai/2025/04/11/transformers-backend.html). + +Example: + +```sh +CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen +2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers +``` + +### Modes of Using vLLM During Training + +TRL supports **two modes** for integrating vLLM during training: **server mode** and **colocate mode**. + +#### Server Mode + +In **server mode**, vLLM runs as a separate process on dedicated GPUs and communicates with the trainer via HTTP. +This setup is ideal if you have GPUs dedicated to inference. + +Example configuration: + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl.experimental.online_dpo import OnlineDPOConfig + +training_args = OnlineDPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl.experimental.nash_md import NashMDConfig + +training_args = NashMDConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl.experimental.xpo import XPOConfig + +training_args = XPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +#### Colocate Mode + +In **colocate mode**, vLLM runs inside the trainer process and shares GPU memory with the training model. +This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. + +Example configuration: + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl.experimental.online_dpo import OnlineDPOConfig + +training_args = OnlineDPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl.experimental.nash_md import NashMDConfig + +training_args = NashMDConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl.experimental.xpo import XPOConfig + +training_args = XPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +> [!WARNING] +> Check the documentation of the trainer you are using for specific details on vLLM usage and parameters. + +> [!WARNING] +> To reduce GPU memory usage when running vLLM, consider [enabling vLLM sleep mode](reducing_memory_usage#vllm-sleep-mode). diff --git a/ICL/RL/trl_source/docs/source/winrate_callback.md b/ICL/RL/trl_source/docs/source/winrate_callback.md new file mode 100644 index 0000000000000000000000000000000000000000..c139174620281cc721d9d41eda52ef1490a4195c --- /dev/null +++ b/ICL/RL/trl_source/docs/source/winrate_callback.md @@ -0,0 +1,3 @@ +# WinRateCallback + +[[autodoc]] experimental.winrate_callback.WinRateCallback diff --git a/ICL/RL/trl_source/docs/source/xpo_trainer.md b/ICL/RL/trl_source/docs/source/xpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..abc22f1538b0fe6dd04e5b1f2c93b8698f9c5505 --- /dev/null +++ b/ICL/RL/trl_source/docs/source/xpo_trainer.md @@ -0,0 +1,164 @@ +# XPO Trainer + +[![model badge](https://img.shields.io/badge/All_models-XPO-blue)](https://huggingface.co/models?other=xpo,trl) + +## Overview + +Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data. + +The abstract from the paper is the following: + +> Reinforcement learning from human feedback (RLHF) has emerged as a central tool for language model alignment. We consider online exploration in RLHF, which exploits interactive access to human or AI feedback by deliberately encouraging the model to produce diverse, maximally informative responses. By allowing RLHF to confidently stray from the pre-trained model, online exploration offers the possibility of novel, potentially super-human capabilities, but its full potential as a paradigm for language model training has yet to be realized, owing to computational and statistical bottlenecks in directly adapting existing reinforcement learning techniques. We propose a new algorithm for online exploration in RLHF, Exploratory Preference Optimization (XPO), which is simple and practical -- a one-line change to (online) Direct Preference Optimization (DPO; Rafailov et al., 2023) -- yet enjoys the strongest known provable guarantees and promising empirical performance. XPO augments the DPO objective with a novel and principled exploration bonus, empowering the algorithm to explore outside the support of the initial model and human feedback data. In theory, we show that XPO is provably sample-efficient and converges to a near-optimal language model policy under natural exploration conditions, irrespective of whether the initial model has good coverage. Our analysis, which builds on the observation that DPO implicitly performs a form of Q*-approximation (or, Bellman error minimization), combines previously disparate techniques from language modeling and theoretical reinforcement learning in a serendipitous fashion through the perspective of KL-regularized Markov decision processes. Empirically, we find that XPO is more sample-efficient than non-exploratory DPO variants in a preliminary evaluation. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Lewis Tunstall](https://huggingface.co/lewtun). + +> [!NOTE] +> XPO is currently experimental. The API may change without notice while the feature is iterated on. + +## Quick start + +This example demonstrates how to train a model using the XPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`experimental.judges.PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here: + + +Below is the script to train the model: + +```python +# train_xpo.py +from datasets import load_dataset +from trl.experimental.judges import PairRMJudge +from trl.experimental.xpo import XPOConfig, XPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +judge = PairRMJudge() +train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + +training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO") +trainer = XPOTrainer( + model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_xpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-XPO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-XPO>:
+The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
+
+ +## Expected dataset type + +XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`experimental.xpo.XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Usage tips + +### Use a reward model + +Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model: + +```diff +- from trl.experimental.judges import PairRMJudge ++ from transformers import AutoModelForSequenceClassification + +- judge = PairRMJudge() ++ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) + + trainer = XPOTrainer( + ... +- judge=judge, ++ reward_funcs=reward_model, + ) +``` + +> [!WARNING] +> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training. + +### Encourage EOS token generation + +When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.xpo.XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.xpo.XPOConfig`]: + +```python +training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) +``` + +### Logging Completions + +To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`]. + +```python +trainer = XPOTrainer(..., eval_dataset=eval_dataset) +completions_callback = LogCompletionsCallback(trainer, num_prompts=8) +trainer.add_callback(completions_callback) +``` + +This callback logs the model's generated completions directly to Weights & Biases. + +![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png) + +## Example script + +We provide an example script to train a model using the XPO method. The script is available in [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) + +To test the XPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command: + +```bash +python examples/scripts/xpo.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --judge pair_rm \ + --dataset_name trl-lib/ultrafeedback-prompt \ + --learning_rate 5.0e-7 \ + --output_dir Qwen2.5-0.5B-XPO-PairRM \ + --warmup_steps 0.1 \ + --push_to_hub +``` + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +* `loss/xpo`: The mean xpo part of the full loss. +* `loss/dpo`: The mean dpo part of the full loss. +* `objective/kl`: The mean KL divergence between the model and reference data. +* `objective/entropy`: The mean entropy of the model and reference data. +* `objective/model_scores`: The mean scores (according to the reward model) of the model completions. +* `objective/ref_scores`: The mean scores (according to the reward model) of the reference completions. +* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions. +* `rewards/chosen`: The mean reward (according to XPO's DPO implicit reward model) of the chosen completions. +* `rewards/rejected`: The mean reward (according to XPO's DPO implicit reward model) of the rejected completions. +* `rewards/accuracies`: The accuracies of the XPO's implicit reward model. +* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. +* `logps/chosen`: The mean log probabilities of the chosen completions. +* `logps/rejected`: The mean log probabilities of the rejected completions. +* `val/model_contain_eos_token`: The amount of times the model's output contains the eos token. +* `val/ref_contain_eos_token`: The amount of times the reference's output contains the eos token. +* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`experimental.xpo.XPOConfig`]. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.xpo.XPOConfig`]. + +## XPOTrainer + +[[autodoc]] experimental.xpo.XPOTrainer + - train + - save_model + - push_to_hub + +## XPOConfig + +[[autodoc]] experimental.xpo.XPOConfig diff --git a/ICL/RL/trl_source/tests/data/template.jinja b/ICL/RL/trl_source/tests/data/template.jinja new file mode 100644 index 0000000000000000000000000000000000000000..01be9b307daa2d425f7c168c9fb145a286e0afb4 --- /dev/null +++ b/ICL/RL/trl_source/tests/data/template.jinja @@ -0,0 +1,89 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/ICL/RL/trl_source/tests/distributed/__init__.py b/ICL/RL/trl_source/tests/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d26f4482fe3969b6de2b51bd1f85fc16dad6a65 --- /dev/null +++ b/ICL/RL/trl_source/tests/distributed/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/ddp.yaml b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/ddp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5363d79686c57ac0e2b1a1e6f3f53a3b1eec102 --- /dev/null +++ b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/ddp.yaml @@ -0,0 +1,2 @@ +distributed_type: MULTI_GPU +num_processes: 2 \ No newline at end of file diff --git a/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/fsdp2.yaml b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/fsdp2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..274141f25833636ba2781193232d9c13d7d846c1 --- /dev/null +++ b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/fsdp2.yaml @@ -0,0 +1,4 @@ +distributed_type: FSDP +fsdp_config: + fsdp_version: 2 +num_processes: 2 \ No newline at end of file diff --git a/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/zero2.yaml b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c51f5754cb3a51a96b9a94bc3f95ab69418ead0 --- /dev/null +++ b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/zero2.yaml @@ -0,0 +1,4 @@ +distributed_type: DEEPSPEED +deepspeed_config: + zero_stage: 2 +num_processes: 2 \ No newline at end of file diff --git a/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/zero3.yaml b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ca9200eadb557b626d238c49f0dc153a54f1709 --- /dev/null +++ b/ICL/RL/trl_source/tests/distributed/data/accelerate_configs/zero3.yaml @@ -0,0 +1,4 @@ +distributed_type: DEEPSPEED +deepspeed_config: + zero_stage: 3 +num_processes: 2 \ No newline at end of file diff --git a/ICL/RL/trl_source/tests/distributed/test_distributed.py b/ICL/RL/trl_source/tests/distributed/test_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..a01928ec77ffc001330c92179c116ca160d43ed1 --- /dev/null +++ b/ICL/RL/trl_source/tests/distributed/test_distributed.py @@ -0,0 +1,283 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +import transformers +from packaging.version import Version + +from ..testing_utils import TrlTestCase, require_torch_multi_accelerator + + +ROOT = Path(__file__).resolve().parents[2] + + +def run_command(command: list[str], env: dict[str, str]) -> None: + result = subprocess.run(command, env=env, cwd=ROOT) + assert result.returncode == 0 + + +@pytest.fixture +def get_config_path(lazy_shared_datadir): + def _get_config_path(config_name): + return lazy_shared_datadir / "accelerate_configs" / f"{config_name}.yaml" + + return _get_config_path + + +@require_torch_multi_accelerator +class TestDistributed( + TrlTestCase +): # pytest.param("zero3", marks=pytest.mark.xfail(reason="ZeRO 3 is currently failing, see #4899")) + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param( + "zero3", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + "fsdp2", + ], + ) + def test_sft(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/sft.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "standard_language_modeling", + ], + os.environ.copy(), + ) + # fmt: on + + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param( + "zero3", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param("fsdp2", marks=pytest.mark.xfail(reason="FSDP2 DPO is currently failing, see #4812")), + ], + ) + def test_dpo(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/dpo.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "standard_preference", + ], + os.environ.copy(), + ) + # fmt: on + + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param( + "zero3", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + "fsdp2", + ], + ) + def test_sft_dataset_streaming(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/sft.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "standard_language_modeling", + "--dataset_streaming", + "--max_steps", "3", + ], + os.environ.copy(), + ) + # fmt: on + + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + condition=Version("2.10") <= Version(torch.__version__), + reason="ZeRO 2 + PEFT is failing on torch 2.10; see #4884", + ), + ), + pytest.param( + "zero3", + marks=pytest.mark.xfail( + condition=Version("2.10") <= Version(torch.__version__), + reason="ZeRO 3 + PEFT is failing on torch 2.10; see #4884", + ), + ), + "fsdp2", + ], + ) + def test_sft_peft(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/sft.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "standard_language_modeling", + "--use_peft", + ], + os.environ.copy(), + ) + # fmt: on + + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param( + "zero3", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + "fsdp2", + ], + ) + def test_reward(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/reward.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "conversational_implicit_prompt_preference", + ], + os.environ.copy(), + ) + # fmt: on + + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param("zero3", marks=pytest.mark.xfail(reason="ZeRO 3 is currently failing, see #4899")), + pytest.param("fsdp2", marks=pytest.mark.xfail(reason="FSDP2 RLOO is currently failing, see #4854")), + ], + ) + def test_rloo(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/rloo.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "conversational_prompt_only", + "--reward_model_name_or_path", "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + ], + os.environ.copy(), + ) + # fmt: on + + @pytest.mark.parametrize( + "config", + [ + "ddp", + pytest.param( + "zero2", + marks=pytest.mark.xfail( + Version(transformers.__version__) == Version("5.1.0"), + reason="Upstream incompatibility: deepspeed and transformers==5.1.0 (see transformers#43780)", + ), + ), + pytest.param("zero3", marks=pytest.mark.xfail(reason="ZeRO 3 is currently failing, see #4899")), + "fsdp2", + ], + ) + def test_grpo(self, config, get_config_path): + # fmt: off + run_command( + [ + "accelerate", "launch", "--config_file", get_config_path(config), "trl/scripts/grpo.py", + "--output_dir", self.tmp_dir, + "--model_name_or_path", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "--dataset_name", "trl-internal-testing/zen", + "--dataset_config", "conversational_prompt_only", + "--reward_model_name_or_path", "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + ], + os.environ.copy(), + ) + # fmt: on diff --git a/ICL/RL/trl_source/tests/experimental/__init__.py b/ICL/RL/trl_source/tests/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2777dd0eb21a0ea67dce8775337a27ea8499da2 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/RL/trl_source/tests/experimental/test_bco_trainer.py b/ICL/RL/trl_source/tests/experimental/test_bco_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..34f42967c3352d1c27329e9d2eedba3863285b33 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_bco_trainer.py @@ -0,0 +1,442 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +import pytest +import torch +from accelerate import Accelerator +from datasets import load_dataset +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers.utils import is_peft_available + +from trl.experimental.bco import BCOConfig, BCOTrainer +from trl.experimental.bco.bco_trainer import _process_tokens, _tokenize + +from ..testing_utils import TrlTestCase, require_no_wandb, require_peft, require_sklearn + + +if is_peft_available(): + from peft import LoraConfig + + +@pytest.mark.low_priority +class TestBCOTrainer(TrlTestCase): + @pytest.mark.parametrize( + "config_name", + [ + "standard_preference", + "standard_implicit_prompt_preference", + "standard_unpaired_preference", + "conversational_preference", + "conversational_implicit_prompt_preference", + "conversational_unpaired_preference", + ], + ) + @require_sklearn + def test_train(self, config_name): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param.cpu(), new_param.cpu()) + + @require_sklearn + def test_train_with_precompute(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + precompute_ref_log_probs=True, + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param.cpu(), new_param.cpu()) + + @require_sklearn + def test_train_eval(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + trainer.train() + + @require_sklearn + def test_init_with_ref_model_is_model(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + report_to="none", + ) + + with pytest.raises(ValueError): + BCOTrainer( + model=model, + ref_model=model, # ref_model can't be the same as model + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + @require_sklearn + def test_tokenize_and_process_tokens(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + tokenized_dataset = dataset.map( + _tokenize, + fn_kwargs={"tokenizer": trainer.processing_class}, + batched=True, + batch_size=2, + ) + assert tokenized_dataset["prompt"][:] == dataset["prompt"][:] + assert tokenized_dataset["completion"][:] == dataset["completion"][:] + assert tokenized_dataset["label"][:] == dataset["label"][:] + assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert tokenized_dataset["answer_input_ids"][0] == [27261, 13] + assert tokenized_dataset["answer_attention_mask"][0] == [1, 1] + + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": trainer.is_encoder_decoder, + "tokenizer": trainer.processing_class, + "max_length": trainer.max_length, + "truncation_mode": trainer.truncation_mode, + } + processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs) + assert processed_dataset["prompt"][:] == dataset["prompt"][:] + assert processed_dataset["completion"][:] == dataset["completion"][:] + assert processed_dataset["label"][:] == dataset["label"][:] + assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645] + assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1] + assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645] + + @require_sklearn + def test_train_without_providing_ref_model(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param.cpu(), new_param.cpu()) + + @require_sklearn + def test_train_udm(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # Get embedding model + embedding_model_id = "trl-internal-testing/tiny-BartModel" + embedding_model = AutoModel.from_pretrained(embedding_model_id) + embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_id) + + def embed_prompt(input_ids, attention_mask, model): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + + return outputs.last_hidden_state.mean(dim=1) + + embedding_model = Accelerator().prepare_model(embedding_model) + embedding_func = partial(embed_prompt, model=embedding_model) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + embedding_func=embedding_func, + embedding_tokenizer=embedding_tokenizer, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param.cpu(), new_param.cpu()) + + @require_sklearn + @require_peft + def test_train_without_providing_ref_model_with_lora(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param.cpu(), new_param.cpu()) + + @require_sklearn + @require_no_wandb + def test_generate_during_eval_no_wandb(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + eval_strategy="steps", + eval_steps=3, + generate_during_eval=True, + report_to="none", + ) + + with pytest.raises( + ValueError, + match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve.", + ): + BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + @require_sklearn + @require_peft + def test_lora_train_and_save(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + # assert that the model is loaded without giving OSError + AutoModelForCausalLM.from_pretrained(self.tmp_dir) + + @require_sklearn + def test_compute_metrics(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + training_args = BCOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/ICL/RL/trl_source/tests/experimental/test_cpo_trainer.py b/ICL/RL/trl_source/tests/experimental/test_cpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4bad5e604f4c54ea28f537f82b65dd5ba050fe6a --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_cpo_trainer.py @@ -0,0 +1,216 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer + +from trl.experimental.cpo import CPOConfig, CPOTrainer + +from ..testing_utils import TrlTestCase, require_peft + + +class TestCPOTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, dtype="float32") + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + @pytest.mark.parametrize( + "name, loss_type, config_name", + [ + ("qwen", "sigmoid", "standard_preference"), + ("t5", "hinge", "standard_implicit_prompt_preference"), + ("qwen", "ipo", "conversational_preference"), + ("qwen", "simpo", "standard_preference"), + ("t5", "simpo", "standard_implicit_prompt_preference"), + ("qwen", "hinge", "conversational_preference"), + ], + ) + def test_cpo_trainer(self, name, loss_type, config_name): + training_args = CPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + loss_type=loss_type, + cpo_alpha=1.0, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + if name == "qwen": + model = self.model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + training_args.is_encoder_decoder = True + + trainer = CPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + @pytest.mark.parametrize( + "config_name", + [ + "standard_preference", + "standard_implicit_prompt_preference", + "conversational_preference", + "conversational_implicit_prompt_preference", + ], + ) + @require_peft + def test_cpo_trainer_with_lora(self, config_name): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + training_args = CPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + cpo_alpha=1.0, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = CPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + def test_compute_metrics(self): + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + training_args = CPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + remove_unused_columns=False, + do_eval=True, + eval_strategy="steps", + eval_steps=1, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = CPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + assert trainer.state.log_history[-2]["eval_test"] == 0.0 + + def test_alphapo_trainer(self): + training_args = CPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + loss_type="alphapo", + alpha=0.5, + simpo_gamma=0.5, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = CPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: + assert not torch.equal(param, new_param) diff --git a/ICL/RL/trl_source/tests/experimental/test_gkd_trainer.py b/ICL/RL/trl_source/tests/experimental/test_gkd_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b4589ef14538de462d1f7f592bd8cc2a9d343386 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_gkd_trainer.py @@ -0,0 +1,282 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch +import torch.nn.functional as F +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from trl.experimental.gkd import GKDConfig, GKDTrainer + +from ..testing_utils import TrlTestCase, require_liger_kernel + + +class TestGKDTrainerGenerateOnPolicy(TrlTestCase): + @classmethod + def setup_class(cls): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + cls.device = "cuda" if torch.cuda.is_available() else "cpu" + cls.tokenizer = AutoTokenizer.from_pretrained(model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + cls.model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32").to(cls.device) + cls.generation_config = GenerationConfig( + max_new_tokens=20, + num_return_sequences=1, + pad_token_id=cls.tokenizer.pad_token_id, + eos_token_id=cls.tokenizer.eos_token_id, + ) + + def test_generate_on_policy_outputs_deterministic(self): + prompts = ["Hello, how are you?", "What's the weather like today?"] + tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True) + + inputs = { + "prompts": tokenized_prompts["input_ids"].to(self.device), + "prompt_attention_mask": tokenized_prompts["attention_mask"].to(self.device), + } + + # Set temperature to 0 for deterministic output + deterministic_generation_config = GenerationConfig( + max_new_tokens=30, + num_return_sequences=1, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + do_sample=False, + temperature=0.0, + ) + + outputs = GKDTrainer.generate_on_policy_outputs( + self.model, inputs, deterministic_generation_config, self.tokenizer.pad_token_id + ) + + new_input_ids, new_attention_mask, new_labels = outputs + + # Decode the generated outputs + generated_texts = self.tokenizer.batch_decode(new_input_ids, skip_special_tokens=True) + + # Check if the generated texts start with the original prompts + for prompt, generated_text in zip(prompts, generated_texts, strict=True): + assert generated_text.startswith(prompt), ( + f"Generated text '{generated_text}' does not start with prompt '{prompt}'" + ) + + # Run the generation twice and check if the outputs are identical + outputs2 = GKDTrainer.generate_on_policy_outputs( + self.model, inputs, deterministic_generation_config, self.tokenizer.pad_token_id + ) + + new_input_ids2, new_attention_mask2, new_labels2 = outputs2 + + # Check if the two generations are identical + assert torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical" + assert torch.all(new_attention_mask.eq(new_attention_mask2)), ( + "Attention masks for deterministic generations are not identical" + ) + assert torch.all(new_labels.eq(new_labels2)), "Labels for deterministic generations are not identical" + + def test_generate_on_policy_outputs(self): + prompts = ["Hello, how are you?", "What's the weather like today?"] + tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True) + + inputs = { + "prompts": tokenized_prompts["input_ids"].to(self.device), + "attention_mask": tokenized_prompts["attention_mask"].to(self.device), + } + + outputs = GKDTrainer.generate_on_policy_outputs( + self.model, inputs, self.generation_config, self.tokenizer.pad_token_id + ) + + # Check that outputs is a tuple of three tensors + assert isinstance(outputs, tuple) + assert len(outputs) == 3 + + new_input_ids, new_attention_mask, new_labels = outputs + + # Check shapes + batch_size = len(prompts) + assert new_input_ids.shape[0] == batch_size + assert new_attention_mask.shape[0] == batch_size + assert new_labels.shape[0] == batch_size + + # Check types + assert isinstance(new_input_ids, torch.Tensor) + assert isinstance(new_attention_mask, torch.Tensor) + assert isinstance(new_labels, torch.Tensor) + + # Check that new_input_ids and new_attention_mask have the same shape + assert new_input_ids.shape == new_attention_mask.shape + assert new_labels.shape == new_attention_mask.shape + + +class TestGeneralizedJSDLoss(TrlTestCase): + def setup_method(self): + self.batch_size = 2 + self.seq_length = 3 + self.vocab_size = 5 + self.student_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) + self.teacher_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) + + def test_uniform_distribution(self): + logits = torch.ones(1, 1, self.vocab_size) + loss = GKDTrainer.generalized_jsd_loss(logits, logits) + assert round(abs(loss.item() - 0), 5) == 0 + + def test_generalized_jsd_loss_edge_cases(self): + # Setup + student_logits = torch.log(torch.tensor([[0.1, 0.9]])).unsqueeze(0) + teacher_logits = torch.log(torch.tensor([[0.9, 0.1]])).unsqueeze(0) + + # Case 1: beta = 1 (should be equivalent to KL(student || teacher)) + loss_beta_1 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=1) + expected_loss_beta_1 = F.kl_div( + F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean" + ) + assert round(abs(loss_beta_1.item() - expected_loss_beta_1.item()), 5) == 0 + + # Case 2: beta = 0 (should be equivalent to KL(teacher || student)) + loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0) + expected_loss_beta_0 = F.kl_div( + F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean" + ) + assert round(abs(loss_beta_0.item() - expected_loss_beta_0.item()), 5) == 0 + + def test_output_shape(self): + loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits) + assert torch.is_tensor(loss) + assert loss.shape == torch.Size([]) + + def test_beta_values(self): + loss_beta_0 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0) + loss_beta_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=1) + assert loss_beta_0 != loss_beta_1 + + def test_temperature_scaling(self): + loss_temp_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=1) + loss_temp_2 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=2) + assert loss_temp_1 != loss_temp_2 + + def test_reduction_methods(self): + loss_batchmean = GKDTrainer.generalized_jsd_loss( + self.student_logits, self.teacher_logits, reduction="batchmean" + ) + loss_sum = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="sum") + loss_mean = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="mean") + loss_none = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="none") + + assert loss_batchmean.shape == torch.Size([]) + assert loss_sum.shape == torch.Size([]) + assert loss_mean.shape == torch.Size([]) + assert loss_none.shape == self.student_logits.shape + + def test_symmetry(self): + student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.1) + teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.1) + assert student_teacher != teacher_student + + student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.5) + teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.5) + assert student_teacher == teacher_student + + def test_zero_loss_for_identical_inputs(self): + identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) + loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits) + assert round(abs(loss.item() - 0), 6) == 0 + + +class TestGKDTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.teacher_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + def test_gkd_trainer(self): + training_args = GKDConfig( + output_dir=self.tmp_dir, + dataloader_drop_last=True, + eval_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = GKDTrainer( + model=self.model_id, + teacher_model=self.model_id, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + + trainer.train() + + assert trainer.state.log_history[(-1)]["train_loss"] is not None + assert trainer.state.log_history[0]["eval_loss"] is not None + assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2") + + @require_liger_kernel + @pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.") + def test_gkd_trainer_with_liger(self): + training_args = GKDConfig( + output_dir=self.tmp_dir, + report_to="none", + use_liger_kernel=True, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = GKDTrainer( + model=self.model_id, + teacher_model=self.model_id, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=self.tokenizer, + ) + + # Ensure liger fused JSD path is enabled; if not, skip (runtime may lack system libs) + if not getattr(trainer, "use_liger_gkd_loss", False): + pytest.skip("Liger fused JSD not enabled at runtime; skipping fused-loss assertion") + + trainer.train() + + # Check we logged a train loss + assert trainer.state.log_history[-1]["train_loss"] is not None + + def test_generation_config_init(self): + training_args = GKDConfig(output_dir=self.tmp_dir) + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = GKDTrainer( + model=self.model_id, + teacher_model=self.model_id, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + + assert trainer.generation_config.pad_token_id == self.tokenizer.eos_token_id + assert trainer.generation_config.eos_token_id == self.model.generation_config.eos_token_id + assert trainer.generation_config.max_new_tokens == training_args.max_new_tokens + assert trainer.generation_config.temperature == training_args.temperature + assert trainer.generation_config.top_k == 0 diff --git a/ICL/RL/trl_source/tests/experimental/test_gold_trainer.py b/ICL/RL/trl_source/tests/experimental/test_gold_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f48f3464654adf9b919dc1c57ef396433c245545 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_gold_trainer.py @@ -0,0 +1,643 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +import pytest +import torch +from datasets import load_dataset +from transformers import AutoTokenizer + +from trl.experimental.gold.gold_trainer import GOLDTrainer, ULDLoss, build_teacher_inputs_from_texts +from trl.experimental.utils import DataCollatorForChatML + + +@pytest.fixture(scope="module") +def openr1_examples(): + try: + dataset = load_dataset( + "HuggingFaceTB/OpenR1-Math-220k-default-verified", + "all", + split="train[:3]", + ) + except Exception as exc: # pragma: no cover - network/environment dependent + pytest.skip(f"OpenR1 dataset unavailable: {exc}") + return [{"messages": row["messages"]} for row in dataset] + + +@pytest.fixture(scope="module") +def countdown_examples(): + try: + dataset = load_dataset( + "HuggingFaceTB/Countdown-Tasks-3to4", + "gkd_verified_Qwen2.5-7B-Instruct", + split="train[:3]", + ) + except Exception as exc: # pragma: no cover - network/environment dependent + pytest.skip(f"Countdown dataset unavailable: {exc}") + return [{"messages": row["messages"]} for row in dataset] + + +def _teacher_inputs_from_collator(student_tok, teacher_tok, batch): + prompt_texts = [] + completion_texts = [] + + pad_token_id = student_tok.pad_token_id + for prompt_ids_tensor, input_ids_tensor, labels_tensor in zip( + batch["prompts"], batch["input_ids"], batch["labels"], strict=True + ): + prompt_ids = prompt_ids_tensor.tolist() + if pad_token_id is not None: + prompt_ids = [tok for tok in prompt_ids if tok != pad_token_id] + prompt_texts.append(student_tok.decode(prompt_ids, skip_special_tokens=False)) + + input_ids = input_ids_tensor.tolist() + labels = labels_tensor.tolist() + completion_token_ids = [tok for tok, label in zip(input_ids, labels, strict=True) if label != -100] + completion_texts.append(student_tok.decode(completion_token_ids, skip_special_tokens=False)) + + teacher_input_ids, teacher_labels, _, _ = build_teacher_inputs_from_texts( + teacher_tok, prompt_texts, completion_texts + ) + return teacher_input_ids, teacher_labels, completion_texts + + +def _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels): + for idx in range(batch["input_ids"].shape[0]): + student_mask = batch["attention_mask"][idx].bool() + student_ids = batch["input_ids"][idx][student_mask] + student_labels = batch["labels"][idx][student_mask] + student_answer_ids = student_ids[student_labels != -100].tolist() + + teacher_answer_mask = teacher_labels[idx] != -100 + teacher_answer_ids = teacher_input_ids[idx][teacher_answer_mask].tolist() + + student_groups, teacher_groups = loss_fn._build_alignment_groups_from_ids( + student_answer_ids, teacher_answer_ids + ) + + assert student_groups, "Student alignment groups must not be empty" + assert teacher_groups, "Teacher alignment groups must not be empty" + assert sorted(idx for group in student_groups for idx in group) == list(range(len(student_answer_ids))) + assert sorted(idx for group in teacher_groups for idx in group) == list(range(len(teacher_answer_ids))) + + +@pytest.mark.slow +def test_chatml_collator_preserves_completion_llama(llama_tokenizer, qwen_tokenizer, openr1_examples): + collator = DataCollatorForChatML(tokenizer=llama_tokenizer, max_length=512) + batch = collator(openr1_examples) + + assistant_texts = [example["messages"][-1]["content"] for example in openr1_examples] + decoded_batch = llama_tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False) + for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True): + assert assistant.strip() in decoded + + teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator( + llama_tokenizer, qwen_tokenizer, batch + ) + for completion, assistant in zip(completion_texts, assistant_texts, strict=True): + assert assistant.strip() in completion + assert completion.strip() + + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.6, + uld_hybrid_unmatched_weight=0.4, + ) + loss_fn = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) + + _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels) + + torch.manual_seed(0) + student_vocab = len(llama_tokenizer) + teacher_vocab = len(qwen_tokenizer) + batch_size, seq_len = batch["input_ids"].shape + student_logits = torch.randn(batch_size, seq_len, student_vocab) + teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=batch["labels"], + teacher_labels=teacher_labels, + student_input_ids=batch["input_ids"], + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + + +@pytest.mark.slow +def test_chatml_collator_preserves_completion_llama_countdown(llama_tokenizer, qwen_tokenizer, countdown_examples): + collator = DataCollatorForChatML(tokenizer=llama_tokenizer, max_length=512) + batch = collator(countdown_examples) + + assistant_texts = [example["messages"][-1]["content"] for example in countdown_examples] + decoded_batch = llama_tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False) + for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True): + assert assistant.strip() in decoded + + teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator( + llama_tokenizer, qwen_tokenizer, batch + ) + for completion, assistant in zip(completion_texts, assistant_texts, strict=True): + assert assistant.strip() in completion + assert completion.strip() + + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.6, + uld_hybrid_unmatched_weight=0.4, + ) + loss_fn = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) + + _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels) + + torch.manual_seed(2) + student_vocab = len(llama_tokenizer) + teacher_vocab = len(qwen_tokenizer) + batch_size, seq_len = batch["input_ids"].shape + student_logits = torch.randn(batch_size, seq_len, student_vocab) + teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=batch["labels"], + teacher_labels=teacher_labels, + student_input_ids=batch["input_ids"], + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + + +@pytest.mark.slow +def test_chatml_collator_preserves_completion_smollm(smollm_tokenizer, qwen_tokenizer, openr1_examples): + collator = DataCollatorForChatML(tokenizer=smollm_tokenizer, max_length=512) + batch = collator(openr1_examples) + + assistant_texts = [example["messages"][-1]["content"] for example in openr1_examples] + decoded_batch = smollm_tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False) + for decoded, assistant in zip(decoded_batch, assistant_texts, strict=True): + assert assistant.strip() in decoded + + teacher_input_ids, teacher_labels, completion_texts = _teacher_inputs_from_collator( + smollm_tokenizer, qwen_tokenizer, batch + ) + for completion, assistant in zip(completion_texts, assistant_texts, strict=True): + assert assistant.strip() in completion + assert completion.strip() + + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.5, + uld_hybrid_unmatched_weight=0.5, + ) + loss_fn = ULDLoss(config, student_tokenizer=smollm_tokenizer, teacher_tokenizer=qwen_tokenizer) + + _assert_alignment_covers_completion(loss_fn, batch, teacher_input_ids, teacher_labels) + + torch.manual_seed(1) + student_vocab = len(smollm_tokenizer) + teacher_vocab = len(qwen_tokenizer) + batch_size, seq_len = batch["input_ids"].shape + student_logits = torch.randn(batch_size, seq_len, student_vocab) + teacher_logits = torch.randn(batch_size, teacher_input_ids.shape[1], teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=batch["labels"], + teacher_labels=teacher_labels, + student_input_ids=batch["input_ids"], + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + + +def build_config(**overrides): + base = dict( + uld_crossentropy_weight=0.0, + uld_distillation_weight=1.0, + uld_student_temperature=1.0, + uld_teacher_temperature=1.0, + uld_skip_student_eos=False, + uld_skip_teacher_eos=False, + use_extended_uld=True, + uld_use_hybrid_loss=False, + uld_hybrid_matched_weight=None, + uld_hybrid_unmatched_weight=None, + beta=0.5, + ) + base.update(overrides) + return SimpleNamespace(**base) + + +@pytest.fixture(scope="session") +def llama_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +@pytest.fixture(scope="session") +def qwen_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +@pytest.fixture(scope="session") +def smollm_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B") + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +def encode_prompt_completion(tokenizer, prompt, completion): + prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] + completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"] + eos_id = tokenizer.eos_token_id + if eos_id is not None: + completion_ids = completion_ids + [eos_id] + input_ids = prompt_ids + completion_ids + labels = [-100] * len(prompt_ids) + completion_ids + return input_ids, labels + + +def pad_tokens(ids, pad_id, target_length): + return ids + [pad_id] * (target_length - len(ids)) + + +def pad_labels(labels, target_length): + return labels + [-100] * (target_length - len(labels)) + + +def test_alignment_groups_cover_all_tokens(llama_tokenizer, qwen_tokenizer): + config = build_config() + loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) + + text = "SmolLM3-3B is smaller than Llama 3.2 but still capable." + student_ids = llama_tokenizer(text, add_special_tokens=False)["input_ids"] + teacher_ids = qwen_tokenizer(text, add_special_tokens=False)["input_ids"] + + student_groups, teacher_groups = loss._build_alignment_groups_from_ids(student_ids, teacher_ids) + + assert len(student_groups) == len(teacher_groups) + assert sorted(idx for group in student_groups for idx in group) == list(range(len(student_ids))) + assert sorted(idx for group in teacher_groups for idx in group) == list(range(len(teacher_ids))) + + +def test_merge_probabilities_multiplies_split_tokens(): + config = build_config() + # Use simple 3-token vocabulary to validate merging behaviour + # probs[0] = P(token | context) at position 0 for all vocab tokens + # probs[1] = P(token | context) at position 1 for all vocab tokens + probs = torch.tensor([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]) + loss = ULDLoss(config, student_tokenizer=None, teacher_tokenizer=None) + + # token_ids[1] = 1 means the actual token at position 1 is token ID 1 + # So we should extract P(token_id=1 | ...) = probs[1, 1] = 0.5 + token_ids = [0, 1] # Actual generated tokens + + merged = loss._merge_probabilities_with_alignment_groups(probs, [[0, 1]], token_ids=token_ids) + + # Expected: P_merged(y) = P(y | context_0) × P(token_1=1 | context_1) + # For each vocab token y, multiply marginal prob at pos 0 by scalar conditional prob of actual token at pos 1 + expected = probs[0] * probs[1, 1] # probs[1, 1] = 0.5 + # Expected unnormalized: [0.6 * 0.5, 0.3 * 0.5, 0.1 * 0.5] = [0.3, 0.15, 0.05] + + torch.testing.assert_close(merged[0], expected) + + +def test_initialize_vocabulary_mapping_contains_common_tokens(llama_tokenizer, qwen_tokenizer): + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=1.0, + uld_hybrid_unmatched_weight=0.0, + ) + loss = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) + + common_tokens = ["Hello", "world", "-", "ol", "LM", "3", "B"] + for token in common_tokens: + student_id = llama_tokenizer.convert_tokens_to_ids(token) + teacher_id = qwen_tokenizer.convert_tokens_to_ids(token) + assert student_id is not None + assert teacher_id is not None + assert teacher_id in loss._vocab_mapping + assert loss._vocab_mapping[teacher_id] == student_id + assert teacher_id in loss._teacher_matched_ids + assert student_id in loss._student_matched_ids + + +def test_get_start_and_size_answers_skips_prompt_tokens(): + trainer = ULDLoss.__new__(ULDLoss) + trainer.ignore_index = -100 + + answers = torch.tensor( + [ + [-100, -100, -100, 10, 20, 30, -100, -100], + [-100, 5, 6, 7, -100, -100, -100, -100], + [-100, -100, -100, -100, -100, -100, -100, -100], + ] + ) + + starts, sizes = trainer._get_start_and_size_answers(answers) + + assert starts == [3, 1, 0] + assert sizes == [3, 3, 0] + + +@pytest.mark.slow +def test_generate_on_policy_outputs_masks_prompt(llama_tokenizer): + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.use_transformers_paged = False + trainer.processing_class = llama_tokenizer + + prompt_text = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\nHello?<|eot_id|>" + completion_text = "<|start_header_id|>assistant<|end_header_id|>\nHi there!" + + prompt_ids = llama_tokenizer(prompt_text, add_special_tokens=False)["input_ids"] + completion_ids = llama_tokenizer(completion_text, add_special_tokens=False)["input_ids"] + + pad_id = llama_tokenizer.pad_token_id + pad_width = 3 + prompt_tensor = torch.full((1, len(prompt_ids) + pad_width), pad_id, dtype=torch.long) + prompt_tensor[0, pad_width:] = torch.tensor(prompt_ids, dtype=torch.long) + prompt_mask = (prompt_tensor != pad_id).long() + + generated_sequence = torch.tensor(prompt_ids + completion_ids, dtype=torch.long).unsqueeze(0) + + class DummyModel: + def generate(self, input_ids, attention_mask, generation_config, return_dict_in_generate): + assert torch.equal(input_ids, prompt_tensor) + assert torch.equal(attention_mask, prompt_mask) + return SimpleNamespace(sequences=generated_sequence) + + generation_config = SimpleNamespace(max_completion_length=None, temperature=None, top_k=None, top_p=None) + new_ids, new_mask, new_labels, prompt_texts, completion_texts = GOLDTrainer.generate_on_policy_outputs( + trainer, + DummyModel(), + {"prompts": prompt_tensor, "prompt_attention_mask": prompt_mask}, + generation_config, + pad_id, + ) + + assert torch.equal(new_ids, generated_sequence) + if pad_id is not None: + expected_mask = (generated_sequence != pad_id).long() + assert torch.equal(new_mask, expected_mask) + else: + assert torch.all(new_mask == 1) + + prompt_len = len(prompt_ids) + assert torch.all(new_labels[0, :prompt_len] == -100) + assert torch.equal(new_labels[0, prompt_len:], torch.tensor(completion_ids, dtype=torch.long)) + + assert prompt_texts[0] == llama_tokenizer.decode(prompt_ids, skip_special_tokens=False) + assert completion_texts[0] == llama_tokenizer.decode(completion_ids, skip_special_tokens=False) + + +@pytest.mark.slow +def test_generate_on_policy_outputs_masks_prompt_smollm(smollm_tokenizer, openr1_examples): + trainer = GOLDTrainer.__new__(GOLDTrainer) + trainer.use_transformers_paged = False + trainer.processing_class = smollm_tokenizer + + collator = DataCollatorForChatML(tokenizer=smollm_tokenizer) + batch = collator([openr1_examples[0]]) + batch = {k: v.cpu() for k, v in batch.items()} + + class DummyModel: + def generate(self, input_ids, attention_mask, generation_config, return_dict_in_generate): + assert torch.equal(input_ids, batch["prompts"]) + assert torch.equal(attention_mask, batch["prompt_attention_mask"]) + return SimpleNamespace(sequences=batch["input_ids"]) + + generation_config = SimpleNamespace(max_completion_length=None, temperature=None, top_k=None, top_p=None) + pad_id = smollm_tokenizer.pad_token_id + new_ids, new_mask, new_labels, prompt_texts, completion_texts = GOLDTrainer.generate_on_policy_outputs( + trainer, + DummyModel(), + {"prompts": batch["prompts"], "prompt_attention_mask": batch["prompt_attention_mask"]}, + generation_config, + pad_id, + ) + + assert torch.equal(new_ids, batch["input_ids"]) + if pad_id is not None: + expected_mask = (batch["input_ids"] != pad_id).long() + assert torch.equal(new_mask, expected_mask) + else: + assert torch.all(new_mask == 1) + + prompt_len = int(batch["prompt_attention_mask"].sum().item()) + tail_labels = new_labels[0, prompt_len:] + expected_tail = batch["input_ids"][0, prompt_len:] + active_mask = tail_labels != -100 + assert torch.all(new_labels[0, :prompt_len] == -100) + assert torch.equal(tail_labels[active_mask], expected_tail[active_mask]) + assert torch.all(tail_labels[~active_mask] == -100) + + prompt_tokens = batch["prompts"][0, batch["prompt_attention_mask"][0].bool()] + decoded_prompt = smollm_tokenizer.decode(prompt_tokens.tolist(), skip_special_tokens=False) + assert prompt_texts[0] == decoded_prompt + + assistant_completion = openr1_examples[0]["messages"][-1]["content"].strip() + assert assistant_completion in completion_texts[0] + + +def test_generalized_jsd_loss_accepts_probability_inputs(): + student_probs = torch.tensor([[[0.6, 0.3, 0.1]]]) + teacher_probs = torch.tensor([[[0.5, 0.4, 0.1]]]) + mixture = 0.5 * (student_probs + teacher_probs) + expected = 0.5 * ( + torch.sum(student_probs * (torch.log(student_probs) - torch.log(mixture))) + + torch.sum(teacher_probs * (torch.log(teacher_probs) - torch.log(mixture))) + ) + + loss = GOLDTrainer.generalized_jsd_loss( + student_probs, + teacher_probs, + beta=0.5, + reduction="batchmean", + logits_are_probs=True, + ) + + torch.testing.assert_close(loss, expected) + + +def test_uldloss_handles_llama_student_qwen_teacher_sequence(llama_tokenizer, qwen_tokenizer): + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.6, + uld_hybrid_unmatched_weight=0.4, + ) + loss_fn = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) + + prompt = "User: Summarize the difference between llamas and alpacas." + completion = "Assistant: Llamas are taller while alpacas have softer wool." + + student_ids, student_labels = encode_prompt_completion(llama_tokenizer, prompt, completion) + teacher_ids, teacher_labels = encode_prompt_completion(qwen_tokenizer, prompt, completion) + + pad_id_student = llama_tokenizer.pad_token_id + pad_id_teacher = qwen_tokenizer.pad_token_id + max_length = max(len(student_ids), len(teacher_ids)) + + student_ids = pad_tokens(student_ids, pad_id_student, max_length) + teacher_ids = pad_tokens(teacher_ids, pad_id_teacher, max_length) + student_labels = pad_labels(student_labels, max_length) + teacher_labels = pad_labels(teacher_labels, max_length) + + student_input_ids = torch.tensor([student_ids]) + teacher_input_ids = torch.tensor([teacher_ids]) + student_labels = torch.tensor([student_labels]) + teacher_labels = torch.tensor([teacher_labels]) + + student_vocab = len(llama_tokenizer) + teacher_vocab = len(qwen_tokenizer) + + student_logits = torch.randn(1, max_length, student_vocab) + teacher_logits = torch.randn(1, max_length, teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + assert loss.dim() == 0 + assert loss_fn.last_matched_loss is not None + assert loss_fn.last_unmatched_loss is not None + + +def test_uldloss_handles_smollm_student_qwen_teacher_sequence(smollm_tokenizer, qwen_tokenizer): + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.5, + uld_hybrid_unmatched_weight=0.5, + ) + loss_fn = ULDLoss(config, student_tokenizer=smollm_tokenizer, teacher_tokenizer=qwen_tokenizer) + + prompt = "User: Describe SmolLM3 in a sentence." + completion = "Assistant: SmolLM3 is a compact yet capable language model." + + student_ids, student_labels = encode_prompt_completion(smollm_tokenizer, prompt, completion) + teacher_ids, teacher_labels = encode_prompt_completion(qwen_tokenizer, prompt, completion) + + pad_id_student = smollm_tokenizer.pad_token_id + pad_id_teacher = qwen_tokenizer.pad_token_id + max_length = max(len(student_ids), len(teacher_ids)) + + student_ids = pad_tokens(student_ids, pad_id_student, max_length) + teacher_ids = pad_tokens(teacher_ids, pad_id_teacher, max_length) + student_labels = pad_labels(student_labels, max_length) + teacher_labels = pad_labels(teacher_labels, max_length) + + student_input_ids = torch.tensor([student_ids]) + teacher_input_ids = torch.tensor([teacher_ids]) + student_labels = torch.tensor([student_labels]) + teacher_labels = torch.tensor([teacher_labels]) + + student_vocab = len(smollm_tokenizer) + teacher_vocab = len(qwen_tokenizer) + + student_logits = torch.randn(1, max_length, student_vocab) + teacher_logits = torch.randn(1, max_length, teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + assert loss.dim() == 0 + assert loss_fn.last_matched_loss is not None + assert loss_fn.last_unmatched_loss is not None + + +def test_uldloss_hybrid_config_beta_zero(llama_tokenizer, qwen_tokenizer): + config = build_config( + uld_use_hybrid_loss=True, + uld_hybrid_matched_weight=0.0, + uld_hybrid_unmatched_weight=1.0, + use_extended_uld=True, + uld_crossentropy_weight=0.0, + uld_distillation_weight=1.0, + uld_student_temperature=1.0, + uld_teacher_temperature=1.0, + temperature=1.0, + top_p=0.95, + top_k=0, + lmbda=1.0, + beta=0.0, + ) + loss_fn = ULDLoss(config, student_tokenizer=llama_tokenizer, teacher_tokenizer=qwen_tokenizer) + + prompt = "User: Explain how GOLD handles tokenizer mismatches." + completion = "Assistant: GOLD merges aligned subwords and applies hybrid ULD loss." + + student_ids, student_labels = encode_prompt_completion(llama_tokenizer, prompt, completion) + teacher_ids, teacher_labels = encode_prompt_completion(qwen_tokenizer, prompt, completion) + + pad_id_student = llama_tokenizer.pad_token_id + pad_id_teacher = qwen_tokenizer.pad_token_id + max_length = max(len(student_ids), len(teacher_ids)) + + student_ids = pad_tokens(student_ids, pad_id_student, max_length) + teacher_ids = pad_tokens(teacher_ids, pad_id_teacher, max_length) + student_labels = pad_labels(student_labels, max_length) + teacher_labels = pad_labels(teacher_labels, max_length) + + student_input_ids = torch.tensor([student_ids]) + teacher_input_ids = torch.tensor([teacher_ids]) + student_labels = torch.tensor([student_labels]) + teacher_labels = torch.tensor([teacher_labels]) + + student_vocab = len(llama_tokenizer) + teacher_vocab = len(qwen_tokenizer) + torch.manual_seed(0) + student_logits = torch.randn(1, max_length, student_vocab) + teacher_logits = torch.randn(1, max_length, teacher_vocab) + + loss = loss_fn( + student_logits=student_logits, + teacher_logits=teacher_logits, + student_labels=student_labels, + teacher_labels=teacher_labels, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids, + ) + + assert torch.isfinite(loss) + assert loss.dim() == 0 + assert loss_fn.last_matched_loss is not None + assert loss_fn.last_unmatched_loss is not None + + expected = config.uld_hybrid_unmatched_weight * loss_fn.last_unmatched_loss + torch.testing.assert_close(loss, expected, atol=1e-6, rtol=1e-5) diff --git a/ICL/RL/trl_source/tests/experimental/test_grpo_with_replay_buffer_trainer.py b/ICL/RL/trl_source/tests/experimental/test_grpo_with_replay_buffer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d1031d2c3a8535f0d730df1ea88deb05a10baf48 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_grpo_with_replay_buffer_trainer.py @@ -0,0 +1,291 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset + +from trl.experimental.grpo_with_replay_buffer import ( + GRPOWithReplayBufferConfig, + GRPOWithReplayBufferTrainer, + ReplayBuffer, +) + +from ..testing_utils import TrlTestCase + + +@pytest.mark.low_priority +class TestReplayBuffer: + def setup_method(self): + self.replay_buffer = ReplayBuffer(max_size=5) + + def test_add(self): + # Add elements to the replay buffer + scores = [0.5, 0.8, 0.3, 0.9, 0.7] + data = [ + {"id": 1}, + {"id": 2}, + {"id": 3}, + {"id": 4}, + {"id": 5}, + ] + self.replay_buffer.add(scores, data) + + # Check if the buffer contains the correct number of elements + assert len(self.replay_buffer.heap) == 5 + + # Check if the buffer maintains the min-heap property + heap_scores = [item[0] for item in self.replay_buffer.heap] + assert heap_scores[0] == min(heap_scores) + assert heap_scores[0] == 0.3 + + def test_add_more_than_maxlen(self): + # Add elements to the replay buffer + scores = [0.5, 0.8, 0.3, 0.9, 0.7, 0.6, 0.4] + data = [ + {"id": 1}, + {"id": 2}, + {"id": 3}, + {"id": 4}, + {"id": 5}, + {"id": 6}, + {"id": 7}, + ] + self.replay_buffer.add(scores, data) + + # Check if the buffer contains the correct number of elements + assert len(self.replay_buffer.heap) == 5 + + # Check if the buffer maintains the min-heap property + heap_scores = [item[0] for item in self.replay_buffer.heap] + assert heap_scores[0] == min(heap_scores) + assert heap_scores[0] == 0.5 # 0.3 and 0.4 should be removed + + def test_sample(self): + # Add elements to the replay buffer + scores = [0.5, 0.8, 0.3, 0.9, 0.7] + data = [ + {"id": 1}, + {"id": 2}, + {"id": 3}, + {"id": 4}, + {"id": 5}, + ] + self.replay_buffer.add(scores, data) + + # Sample elements from the buffer + sampled = self.replay_buffer.sample(num_samples=3) + + # Check if the sampled elements are from the buffer + assert len(sampled) == 3 + for item in sampled: + assert item in [entry[1] for entry in self.replay_buffer.heap] + + +@pytest.mark.low_priority +class TestUpdateWithReplayBuffer: + def setup_method(self): + config = GRPOWithReplayBufferConfig( + replay_buffer_size=5, + ) + self.trainer = GRPOWithReplayBufferTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=config, + train_dataset=None, + ) + self.trainer.replay_buffer = ReplayBuffer(max_size=5) + self.trainer.num_generations = 2 + + def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False): + scores = [0.1, 0.9] + data = [ + { + "prompt_ids": torch.tensor([[100, 101], [102, 103]]), + "prompt_mask": torch.ones(2, 2, dtype=torch.long), + "completion_ids": torch.tensor([[5, 6], [7, 8]]), + "completion_mask": torch.ones(2, 2, dtype=torch.long), + "advantages": torch.tensor([[0.5, 0.6]]), + **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}), + **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}), + }, + { + "prompt_ids": torch.tensor([[104, 105], [106, 107]]), + "prompt_mask": torch.ones(2, 2, dtype=torch.long), + "completion_ids": torch.tensor([[13, 14], [15, 16]]), + "completion_mask": torch.ones(2, 2, dtype=torch.long), + "advantages": torch.tensor([[0.8, 0.85]]), + **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}), + **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}), + }, + ] + self.trainer.replay_buffer.add(scores, data) + + def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False): + inputs = { + "group_advantages": group_advantages, + "prompt_ids": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]), + "prompt_mask": torch.ones(4, 2, dtype=torch.long), + "completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]), + "completion_mask": torch.ones(4, 2, dtype=torch.long), + "forward_kwargs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {}, + "old_per_token_logps": torch.randn(4, 2) if with_logprobs else None, + } + inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages) + return inputs + + def test_update_with_replay_buffer_no_variance(self): + self._prepopulate_buffer(with_pixels=True, with_logprobs=True) + group_advantages = torch.tensor([[0.5, 0.5], [0.8, 0.8]]) # no variance + inputs = self._make_inputs(group_advantages, with_pixels=True, with_logprobs=True) + original_prompt_ids = inputs["prompt_ids"].clone() + + outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) + + assert outputs is not None + assert "pixel_values" in outputs + assert "old_per_token_logps" in outputs + assert len(self.trainer.replay_buffer.heap) == 2 + for pid in outputs["prompt_ids"]: + assert pid.tolist() not in original_prompt_ids.tolist() + + def test_update_with_replay_buffer_with_variance(self): + self._prepopulate_buffer() + group_advantages = torch.tensor([[0.6, 0.4], [0.7, 1.2]]) # has variance + inputs = self._make_inputs(group_advantages) + + sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) + + assert len(self.trainer.replay_buffer.heap) == 4 # grew + assert sampled is None + + def test_update_with_mixed_variance(self): + self._prepopulate_buffer() + group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]]) # one no-variance, one variance + inputs = self._make_inputs(group_advantages) + original_prompt_ids = inputs["prompt_ids"].clone().view(-1, self.trainer.num_generations, 2).tolist() + + outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) + + assert len(self.trainer.replay_buffer.heap) == 3 # grew by 1 + output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist() + + buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap] + found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids) + found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids) + + assert found_from_buffer + assert found_from_original + assert [[1, 2], [3, 4]] not in output_prompt_ids # excluded no-variance group + + def test_update_with_inputs_different_seq_len(self): + """ + Test with inputs where the sequence lengths are different from the prepopulated buffer. + """ + self._prepopulate_buffer() + pad_token_id = self.trainer.processing_class.pad_token_id + group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]]) # one no-variance, one variance + inputs = { + "group_advantages": group_advantages, + "prompt_ids": torch.tensor( + [ + [1, 2, pad_token_id], + [1, 2, pad_token_id], + [3, 4, 5], + [3, 4, 5], + ] + ), + "prompt_mask": torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 1], [1, 1, 1]], dtype=torch.long), + "completion_ids": torch.tensor( + [ + [1009, 1010, pad_token_id], + [1011, 1012, 1013], + [1013, 1014, pad_token_id], + [1015, 1016, 1017], + ] + ), + "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long), + "forward_kwargs": {}, + } + inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages) + + outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) + # Seq length of current batch should be preserved + assert outputs_after_sampling["prompt_ids"].shape[-1] == 3 + assert len(self.trainer.replay_buffer.heap) == 3 + output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist() + + buffered_prompt_completion_ids = [ + (item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist()) + for item in self.trainer.replay_buffer.heap + ] + buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids, strict=True) + + # Check for new entry with seq len 3 in buffer + assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids # excluded no-variance group + assert [ + [1013, 1014, pad_token_id], + [1015, 1016, 1017], + ] in buffered_completion_ids # excluded no-variance group + + # Check that sampled outputs contain one group with prompt_ids starting with a pad token + assert [ + [pad_token_id, 101, 102], + [pad_token_id, 102, 103], + ] in output_prompt_ids or [ + [pad_token_id, 104, 105], + [pad_token_id, 106, 107], + ] in output_prompt_ids + + +@pytest.mark.low_priority +@pytest.mark.parametrize("scale_rewards", ["batch", "group"]) +class TestGRPOWithReplayBufferTrainer(TrlTestCase): + def test_training_with_replay_buffer(self, scale_rewards): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Guarantee that some rewards have 0 std + def custom_reward_func(completions, **kwargs): + if torch.rand(1).item() < 0.25: + return [0] * len(completions) # simulate some None rewards + else: + return torch.rand(len(completions)).tolist() + + training_args = GRPOWithReplayBufferConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=4, # reduce the batch size to reduce memory usage + num_generations=4, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + replay_buffer_size=8, + report_to="none", + scale_rewards=scale_rewards, + ) + trainer = GRPOWithReplayBufferTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[custom_reward_func], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." diff --git a/ICL/RL/trl_source/tests/experimental/test_gspo_token_trainer.py b/ICL/RL/trl_source/tests/experimental/test_gspo_token_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..19391d55a188a90ad071f54f97db74f12e90dfeb --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_gspo_token_trainer.py @@ -0,0 +1,60 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from datasets import load_dataset +from transformers.utils import is_peft_available + +from trl import GRPOConfig +from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer + +from ..testing_utils import TrlTestCase + + +if is_peft_available(): + pass + + +class TestGSPOTokenTrainer(TrlTestCase): + def test_training(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_iterations=2, # the importance sampling weights won't be 0 in this case + importance_sampling_level="sequence_token", + report_to="none", + ) + trainer = GSPOTokenTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." diff --git a/ICL/RL/trl_source/tests/experimental/test_judges.py b/ICL/RL/trl_source/tests/experimental/test_judges.py new file mode 100644 index 0000000000000000000000000000000000000000..7714a0caa1b5e252c5cd9b9c496852a2a9161383 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_judges.py @@ -0,0 +1,106 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +import time + +import pytest +import transformers +from packaging.version import Version + +from trl.experimental.judges import AllTrueJudge, BaseBinaryJudge, HfPairwiseJudge, PairRMJudge + +from ..testing_utils import TrlTestCase, require_llm_blender + + +class RandomBinaryJudge(BaseBinaryJudge): + """ + Random binary judge, for testing purposes. + """ + + def judge(self, prompts, completions, gold_completions=None, shuffle_order=True): + return [random.choice([0, 1, -1]) for _ in range(len(prompts))] + + +class TestJudges(TrlTestCase): + def _get_prompts_and_pairwise_completions(self): + prompts = ["The capital of France is", "The biggest planet in the solar system is"] + completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]] + return prompts, completions + + def _get_prompts_and_single_completions(self): + prompts = ["What's the capital of France?", "What's the color of the sky?"] + completions = ["Marseille", "blue"] + return prompts, completions + + def test_all_true_judge(self): + judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()]) + prompts, completions = self._get_prompts_and_single_completions() + judgements = judge.judge(prompts=prompts, completions=completions) + assert len(judgements) == 2 + assert all(judgement in {0, 1, -1} for judgement in judgements) + + @pytest.mark.skip(reason="This test needs to be run manually since it requires a valid Hugging Face API key.") + def test_hugging_face_judge(self): + judge = HfPairwiseJudge() + prompts, completions = self._get_prompts_and_pairwise_completions() + ranks = judge.judge(prompts=prompts, completions=completions) + assert len(ranks) == 2 + assert all(isinstance(rank, int) for rank in ranks) + assert ranks == [0, 1] + + def load_pair_rm_judge(self): + # When using concurrent tests, PairRM may fail to load the model while another job is still downloading. + # This is a workaround to retry loading the model a few times. + for _ in range(5): + try: + return PairRMJudge() + except ValueError: + time.sleep(5) + raise ValueError("Failed to load PairRMJudge") + + @require_llm_blender + @pytest.mark.skipif( + sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)" + ) + @pytest.mark.xfail( + Version(transformers.__version__) >= Version("5.0.0"), + reason="Known incompatibility between llm-blender and transformers >= 5.0.0 (GH-4918)", + strict=True, + ) + def test_pair_rm_judge(self): + judge = self.load_pair_rm_judge() + prompts, completions = self._get_prompts_and_pairwise_completions() + ranks = judge.judge(prompts=prompts, completions=completions) + assert len(ranks) == 2 + assert all(isinstance(rank, int) for rank in ranks) + assert ranks == [0, 1] + + @require_llm_blender + @pytest.mark.skipif( + sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)" + ) + @pytest.mark.xfail( + Version(transformers.__version__) >= Version("5.0.0"), + reason="Known incompatibility between llm-blender and transformers >= 5.0.0 (GH-4918)", + strict=True, + ) + def test_pair_rm_judge_return_scores(self): + judge = self.load_pair_rm_judge() + prompts, completions = self._get_prompts_and_pairwise_completions() + probs = judge.judge(prompts=prompts, completions=completions, return_scores=True) + assert len(probs) == 2 + assert all(isinstance(prob, float) for prob in probs) + assert all(0 <= prob <= 1 for prob in probs) diff --git a/ICL/RL/trl_source/tests/experimental/test_kto_trainer.py b/ICL/RL/trl_source/tests/experimental/test_kto_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c0fb27da2c7adc785a4d2a7e1221dc1da0295dfe --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_kto_trainer.py @@ -0,0 +1,350 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl.experimental.kto import KTOConfig, KTOTrainer +from trl.experimental.kto.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize + +from ..testing_utils import TrlTestCase, require_liger_kernel, require_no_wandb, require_peft + + +class TestKTOTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @pytest.mark.parametrize( + "config_name, loss_type, pre_compute, eval_dataset", + [ + ("standard_preference", "kto", True, True), + ("standard_unpaired_preference", "kto", False, True), + ("conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True), + ("standard_unpaired_preference", "apo_zero_unpaired", False, True), + ], + ) + def test_kto_trainer(self, config_name, loss_type, pre_compute, eval_dataset): + training_args = KTOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps" if eval_dataset else "no", + beta=0.1, + precompute_ref_log_probs=pre_compute, + loss_type=loss_type, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = KTOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"] if eval_dataset else None, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + def test_kto_trainer_with_ref_model_is_model(self): + training_args = KTOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with pytest.raises(ValueError): + KTOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + + def test_tokenize_and_process_tokens(self): + training_args = KTOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + train_dataset = dummy_dataset["train"] + tokenized_dataset = train_dataset.map( + _tokenize, + fn_kwargs={"tokenizer": trainer.processing_class}, + batched=True, + batch_size=2, + ) + assert tokenized_dataset["prompt"][:] == train_dataset["prompt"][:] + assert tokenized_dataset["completion"][:] == train_dataset["completion"][:] + assert tokenized_dataset["label"][:] == train_dataset["label"][:] + assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert tokenized_dataset["answer_input_ids"][0] == [27261, 13] + assert tokenized_dataset["answer_attention_mask"][0] == [1, 1] + + # Test corruption of (prompt, completion) pairs for KL dataset + for batch_size in [2, 3]: + tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=batch_size) + + # Verify that the "answer_input_ids" have been modified, meaning the new "answer_input_ids" differ + # from the original ones. However, when the length of the dataset modulo batch_size equals 1, + # the last batch remains unaltered. This is a rare scenario that does not impact the training + # process, so we exclude it from testing by iterating only up to len - 1. + for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1): + assert tokenized_dataset["prompt_input_ids"][i] == tokenized_kl_dataset["prompt_input_ids"][i] + assert ( + tokenized_dataset["prompt_attention_mask"][i] == tokenized_kl_dataset["prompt_attention_mask"][i] + ) + assert tokenized_dataset["answer_input_ids"][i] != tokenized_kl_dataset["answer_input_ids"][i] + + fn_kwargs = { + "prefix": "", + "tokenizer": trainer.processing_class, + "max_length": trainer.max_length, + } + processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2) + assert processed_dataset["prompt"][:] == train_dataset["prompt"][:] + assert processed_dataset["completion"][:] == train_dataset["completion"][:] + assert processed_dataset["label"][:] == train_dataset["label"][:] + assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645] + assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1] + assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645] + + def test_kto_trainer_without_providing_ref_model(self): + training_args = KTOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + @require_peft + def test_kto_trainer_without_providing_ref_model_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + training_args = KTOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + @require_no_wandb + def test_kto_trainer_generate_during_eval_no_wandb(self): + training_args = KTOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + generate_during_eval=True, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with pytest.raises( + ValueError, + match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve.", + ): + KTOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + @require_liger_kernel + def test_kto_trainer_with_liger(self): + """Test KTO trainer with Liger kernel enabled.""" + training_args = KTOConfig( + output_dir=self.tmp_dir, + report_to="none", + use_liger_kernel=True, # Enable Liger kernel + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.equal(param, new_param) + + def test_compute_metrics(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32") + ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer.pad_token = tokenizer.eos_token + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + training_args = KTOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, + per_device_train_batch_size=2, + do_eval=True, + eval_strategy="steps", + eval_steps=1, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = KTOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/ICL/RL/trl_source/tests/experimental/test_merge_model_callback.py b/ICL/RL/trl_source/tests/experimental/test_merge_model_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..fb63ac40c448dc560f132933984d38ed77eb9fc3 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_merge_model_callback.py @@ -0,0 +1,84 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.trainer_utils import get_last_checkpoint + +from trl import DPOConfig, DPOTrainer +from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback + +from ..testing_utils import TrlTestCase, require_mergekit + + +@require_mergekit +class TestMergeModelCallback(TrlTestCase): + def setup_method(self): + self.model = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32" + ) + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + + def test_callback(self): + training_args = DPOConfig( + output_dir=self.tmp_dir, + num_train_epochs=1, + report_to="none", + save_strategy="steps", + save_steps=1, + ) + config = MergeConfig() + merge_callback = MergeModelCallback(config) + trainer = DPOTrainer( + model=self.model, + args=training_args, + train_dataset=self.dataset, + processing_class=self.tokenizer, + callbacks=[merge_callback], + ) + trainer.train() + last_checkpoint = get_last_checkpoint(self.tmp_dir) + merged_path = os.path.join(last_checkpoint, "merged") + assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint." + + def test_every_checkpoint(self): + training_args = DPOConfig( + output_dir=self.tmp_dir, + num_train_epochs=1, + report_to="none", + save_strategy="steps", + save_steps=1, + ) + config = MergeConfig() + merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True) + trainer = DPOTrainer( + model=self.model, + args=training_args, + train_dataset=self.dataset, + processing_class=self.tokenizer, + callbacks=[merge_callback], + ) + trainer.train() + + checkpoints = sorted( + [os.path.join(self.tmp_dir, cp) for cp in os.listdir(self.tmp_dir) if cp.startswith("checkpoint-")] + ) + + for checkpoint in checkpoints: + merged_path = os.path.join(checkpoint, "merged") + assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}." diff --git a/ICL/RL/trl_source/tests/experimental/test_minillm_trainer.py b/ICL/RL/trl_source/tests/experimental/test_minillm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9f3726f942ac39e85d641301592a60fec967df --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_minillm_trainer.py @@ -0,0 +1,57 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset + +from trl.experimental.minillm import MiniLLMConfig, MiniLLMTrainer + +from ..testing_utils import TrlTestCase + + +@pytest.mark.low_priority +class TestMiniLLMTrainer(TrlTestCase): + def test_train(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Initialize the trainer + training_args = MiniLLMConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = MiniLLMTrainer( + model="trl-internal-testing/small-Qwen3ForCausalLM", + teacher_model="trl-internal-testing/tiny-Qwen3ForCausalLM", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" diff --git a/ICL/RL/trl_source/tests/experimental/test_modeling_value_head.py b/ICL/RL/trl_source/tests/experimental/test_modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa6d43118471d41289cf7053274320b038d3d51 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_modeling_value_head.py @@ -0,0 +1,99 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from trl import create_reference_model +from trl.experimental.ppo import AutoModelForCausalLMWithValueHead + +from ..testing_utils import TrlTestCase + + +class TestReferenceModel(TrlTestCase): + def setup_method(self): + self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel") + self.test_input = torch.tensor([[0, 1, 2, 3]]) + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1) + self.layer_format = "pretrained_model.transformer.h.{layer}.attn.c_attn.weight" + + def test_independent_reference(self): + layer_0 = self.layer_format.format(layer=0) + layer_1 = self.layer_format.format(layer=1) + + ref_model = create_reference_model(self.model) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + last_layer_before = self.model.get_parameter(layer_1).data.clone() # the model only has 2 layers + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_before = ref_model.get_parameter(layer_1).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + last_layer_after = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() + + # before optimization ref and model are identical + assert (first_layer_before == first_ref_layer_before).all() + assert (last_layer_before == last_ref_layer_before).all() + + # ref model stays identical after optimization + assert (first_ref_layer_before == first_ref_layer_after).all() + assert (last_ref_layer_before == last_ref_layer_after).all() + + # optimized model changes + assert not (first_layer_before == first_layer_after).all() + assert not (last_layer_before == last_layer_after).all() + + def test_shared_layers(self): + layer_0 = self.layer_format.format(layer=0) + layer_1 = self.layer_format.format(layer=1) + + ref_model = create_reference_model(self.model, num_shared_layers=1) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + second_layer_before = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_before = ref_model.get_parameter(layer_1).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + second_layer_after = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() + + # before optimization ref and model are identical + assert (first_layer_before == first_ref_layer_before).all() + assert (second_layer_before == second_ref_layer_before).all() + + # ref model stays identical after optimization + assert (first_ref_layer_before == first_ref_layer_after).all() + assert (second_ref_layer_before == second_ref_layer_after).all() + + # first layer of optimized model stays the same + assert (first_layer_before == first_layer_after).all() + + # other layers in optimized model change + assert not (second_layer_before == second_layer_after).all() diff --git a/ICL/RL/trl_source/tests/experimental/test_nash_md_trainer.py b/ICL/RL/trl_source/tests/experimental/test_nash_md_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca99bc1961b3fb43de32e1a732bd739655c5fc59 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_nash_md_trainer.py @@ -0,0 +1,236 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig +from transformers.utils import is_peft_available + +from trl.experimental.nash_md import NashMDConfig, NashMDTrainer +from trl.experimental.nash_md.nash_md_trainer import GeometricMixtureWrapper +from trl.models.utils import create_reference_model + +from ..testing_utils import TrlTestCase, require_llm_blender, require_peft +from .testing_utils import RandomPairwiseJudge + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestGeometricMixtureWrapper(TrlTestCase): + def setup_method(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32").to(self.device) + self.ref_model = create_reference_model(self.model).to(self.device) + self.generation_config = GenerationConfig.from_pretrained(model_id) + self.mixture_coef = 0.5 + self.wrapper = GeometricMixtureWrapper( + self.model, self.ref_model, self.generation_config, mixture_coef=self.mixture_coef + ) + + def test_forward(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) + attention_mask = torch.ones_like(input_ids) + + output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) + + assert output is not None + assert hasattr(output, "logits") + assert output.logits.shape == (1, 5, self.model.config.vocab_size) + + def test_mixture_coefficient(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) + attention_mask = torch.ones_like(input_ids) + + with torch.no_grad(): + model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + ref_model_output = self.ref_model(input_ids=input_ids, attention_mask=attention_mask) + wrapper_output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) + + expected_logits = torch.nn.functional.log_softmax( + self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 + ) + + torch.testing.assert_close(wrapper_output.logits, expected_logits) + + def test_prepare_inputs_for_generation(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) + attention_mask = torch.ones_like(input_ids) + + inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) + + assert "input_ids" in inputs + assert "attention_mask" in inputs + assert not inputs.get("use_cache", False) + + +class TestNashMDTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id, num_labels=1) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + def test_nash_md_trainer_training(self, config_name): + training_args = NashMDConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + reward_funcs=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + training_args = NashMDConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + training_args = NashMDConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + reward_funcs=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): + lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + # self.model from setUp is a base AutoModelForCausalLM + peft_model_instance = get_peft_model(self.model, lora_config) + + training_args = NashMDConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, # Keep small for quick test + max_steps=2, # Few steps + learning_rate=5.0e-7, + eval_strategy="no", + report_to="none", + remove_unused_columns=False, # Important for the dummy dataset + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"] + + trainer = NashMDTrainer( + model=peft_model_instance, # Pass the already PEFT model + ref_model=None, # Implicit reference from peft_model_instance's base + reward_funcs=self.reward_model, # To trigger GeometricMixtureWrapper path + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + # peft_config is not passed, as model is already PEFT + ) + + trainer.train() + + assert "train_loss" in trainer.state.log_history[-1] + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + @require_llm_blender + def test_nash_md_trainer_judge_training(self, config_name): + training_args = NashMDConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + judge = RandomPairwiseJudge() + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + judge=judge, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] diff --git a/ICL/RL/trl_source/tests/experimental/test_online_dpo_trainer.py b/ICL/RL/trl_source/tests/experimental/test_online_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3bd23cf414c855c7d72be619aab857f4bd6d9c0 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_online_dpo_trainer.py @@ -0,0 +1,520 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import transformers +from datasets import Dataset, features, load_dataset +from packaging.version import Version +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.utils import is_peft_available, is_vision_available + +from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer + +from ..testing_utils import ( + TrlTestCase, + require_llm_blender, + require_peft, + require_torch_accelerator, + require_vision, + require_vllm, +) +from .testing_utils import RandomPairwiseJudge + + +if is_peft_available(): + from peft import LoraConfig + +if is_vision_available(): + import numpy as np + from PIL import Image + from transformers import AutoModelForImageTextToText, AutoProcessor + + +class TestOnlineDPOTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.reward_model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id, num_labels=1) + self.reward_tokenizer = AutoTokenizer.from_pretrained(self.reward_model_id) + self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + def test_training(self, config_name): + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + def test_training_model_str(self): + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + def test_training_with_ref_model(self): + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + def test_ref_model_is_model(self): + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + with pytest.raises(ValueError): + OnlineDPOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + @require_llm_blender + def test_training_with_judge(self, config_name): + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + judge=RandomPairwiseJudge(), + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + @require_torch_accelerator + @require_vllm + @pytest.mark.slow + def test_training_with_vllm(self, config_name): + def cleanup_vllm_communicator(trainer): + """Clean up vLLM communicator to avoid conflicts between test runs""" + try: + if hasattr(trainer, "vllm_client") and trainer.vllm_client is not None: + trainer.vllm_client.close_communicator() + except Exception: + pass # Continue if cleanup fails + + model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5" # We need a bigger model + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + use_vllm=True, + vllm_gpu_memory_utilization=0.2, + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + + # Ensure cleanup of vLLM communicator after the test + try: + trainer.train() + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + finally: + cleanup_vllm_communicator(trainer) + + @require_vllm + def test_training_with_vllm_colocate(self): + """Test vLLM colocate mode with our refactored implementation""" + model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5" # We need a bigger model + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + use_vllm=True, + vllm_mode="colocate", + vllm_gpu_memory_utilization=0.2, + per_device_train_batch_size=1, + max_steps=2, + report_to="none", + # Test generation parameters + temperature=0.9, + top_p=0.95, + top_k=50, + repetition_penalty=1.1, + max_new_tokens=32, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + + # Verify vLLM setup + assert trainer.use_vllm + assert trainer.vllm_mode == "colocate" + assert trainer.llm is not None + # self.assertIsNone(trainer.vllm_client) + # self.assertEqual(trainer.vllm_gpu_memory_utilization, 0.2) + + # Verify generation parameters + assert trainer.temperature == 0.9 + assert trainer.top_p == 0.95 + assert trainer.top_k == 50 + assert trainer.repetition_penalty == 1.1 + + # Verify generation config + assert trainer.generation_config is not None + assert trainer.generation_config.temperature == 0.9 + assert trainer.generation_config.top_p == 0.95 + assert trainer.generation_config.top_k == 50 + assert trainer.generation_config.repetition_penalty == 1.1 + assert trainer.generation_config.max_tokens == 32 + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + def test_vllm_config_validation(self): + """Test vLLM configuration validation""" + # Test valid vllm_mode values + config = OnlineDPOConfig(use_vllm=True, vllm_mode="server") + assert config.vllm_mode == "server" + + config = OnlineDPOConfig(use_vllm=True, vllm_mode="colocate") + assert config.vllm_mode == "colocate" + + # Test default values + config = OnlineDPOConfig() + assert config.vllm_mode == "server" + assert config.vllm_server_base_url is None + assert config.vllm_server_host == "0.0.0.0" + assert config.vllm_server_port == 8000 + assert config.vllm_server_timeout == 240.0 + assert config.vllm_gpu_memory_utilization == 0.55 + + # Test generation parameters + assert config.top_p == 1.0 + assert config.top_k == 0 + assert config.min_p is None + assert config.repetition_penalty == 1.0 + assert not config.use_transformers_paged + assert config.cache_implementation is None + assert config.generation_kwargs is None + + def test_generation_config_setup(self): + """Test that generation configuration is properly set up for both vLLM and transformers""" + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + use_vllm=False, + temperature=0.8, + top_p=0.9, + top_k=40, + repetition_penalty=1.2, + max_new_tokens=64, + generation_kwargs={"do_sample": False}, + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + + # Verify transformers generation config + assert not trainer.use_vllm + # When not using vLLM, these attributes should not be set + assert not (hasattr(trainer, "llm") and trainer.llm is not None) + assert not (hasattr(trainer, "vllm_client") and trainer.vllm_client is not None) + assert trainer.generation_config is not None + assert trainer.generation_config.temperature == 0.8 + assert trainer.generation_config.top_p == 0.9 + assert trainer.generation_config.top_k == 40 + assert trainer.generation_config.repetition_penalty == 1.2 + assert trainer.generation_config.max_new_tokens == 64 + assert not trainer.generation_config.do_sample # From generation_kwargs + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + @require_torch_accelerator + def test_training_with_transformers_paged(self, config_name): + if Version(transformers.__version__) < Version("4.57.0"): + pytest.xfail("Bug in transformers solved in GH#40692, released in 4.57.0.") + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + use_transformers_paged=True, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + def test_training_with_reward_funcs(self, config_name): + def simple_reward_func(prompts, completions, completion_ids, **kwargs): + return [0.5 for _ in prompts] + + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + reward_weights=[0.7, 0.3], + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=[simple_reward_func, simple_reward_func], + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + trainer.train() + + assert "train_loss" in trainer.state.log_history[-1] + assert len(trainer.reward_funcs) == 2 + assert trainer.reward_weights is not None + assert round(abs(trainer.reward_weights[0].item() - 0.7), 5) == 0 + assert round(abs(trainer.reward_weights[1].item() - 0.3), 5) == 0 + + +@require_vision +class TestOnlineDPOVisionTrainer(TrlTestCase): + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", + "trl-internal-testing/tiny-LlavaForConditionalGeneration", + ], + ) + def test_online_dpo_vlm_trainer(self, model_id): + dataset_dict = { + "prompt": [ + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}], + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What do you see?"}]}], + ], + "images": [ + [Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))], + [Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))], + ], + } + dataset = Dataset.from_dict(dataset_dict) + dataset = dataset.cast_column("images", features.Sequence(features.Image())) + + model = AutoModelForImageTextToText.from_pretrained(model_id, dtype="float32") + reward_model = AutoModelForSequenceClassification.from_pretrained( + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", num_labels=1 + ) + processor = AutoProcessor.from_pretrained(model_id) + reward_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2") + reward_tokenizer.pad_token = reward_tokenizer.eos_token + + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_steps=2, + learning_rate=0.01, + report_to="none", + ) + trainer = OnlineDPOTrainer( + model=model, + reward_funcs=reward_model, + args=training_args, + processing_class=processor, + train_dataset=dataset, + eval_dataset=dataset, + reward_processing_classes=reward_tokenizer, + ) + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None diff --git a/ICL/RL/trl_source/tests/experimental/test_orpo_trainer.py b/ICL/RL/trl_source/tests/experimental/test_orpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c6bbb5917bc4092b32ab333770a45eafef5b09 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_orpo_trainer.py @@ -0,0 +1,177 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer + +from trl.experimental.orpo import ORPOConfig, ORPOTrainer + +from ..testing_utils import TrlTestCase, require_peft + + +class TestORPOTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, dtype="float32") + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + @pytest.mark.parametrize( + "name, config_name", + [ + ("qwen", "standard_preference"), + ("t5", "standard_implicit_prompt_preference"), + ("qwen", "conversational_preference"), + ], + ) + def test_orpo_trainer(self, name, config_name): + training_args = ORPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + if name == "qwen": + model = self.model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + training_args.is_encoder_decoder = True + + trainer = ORPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + @pytest.mark.parametrize( + "config_name", + [ + "standard_preference", + "standard_implicit_prompt_preference", + "conversational_preference", + "conversational_implicit_prompt_preference", + ], + ) + @require_peft + def test_orpo_trainer_with_lora(self, config_name): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + training_args = ORPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = ORPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.equal(param, new_param) + + def test_compute_metrics(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer.pad_token = tokenizer.eos_token + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + training_args = ORPOConfig( + output_dir=self.tmp_dir, + remove_unused_columns=False, + per_device_train_batch_size=2, + do_eval=True, + eval_strategy="steps", + eval_steps=1, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = ORPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/ICL/RL/trl_source/tests/experimental/test_ppo_trainer.py b/ICL/RL/trl_source/tests/experimental/test_ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..32d6726ff96228dc7914802e8db5dc7f41dd004b --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_ppo_trainer.py @@ -0,0 +1,828 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os + +import pytest +import torch +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoTokenizer, + GenerationConfig, +) +from transformers.utils import is_peft_available + +from trl.experimental.ppo import ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PPOConfig, + PPOTrainer, +) +from trl.experimental.ppo.ppo_trainer import batch_generation, masked_mean, masked_var, masked_whiten + +from ..testing_utils import ( + TrlTestCase, + require_bitsandbytes, + require_peft, + require_torch_gpu_if_bnb_not_multi_backend_enabled, +) + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +ALL_CAUSAL_LM_MODELS = [ + "trl-internal-testing/tiny-BloomForCausalLM", + "trl-internal-testing/tiny-CohereForCausalLM", + # "trl-internal-testing/tiny-FalconMambaForCausalLM", # FalconMambaForCausalLM modeling seems to be broken for now + "trl-internal-testing/tiny-Gemma2ForCausalLM", + "trl-internal-testing/tiny-GemmaForCausalLM", + "trl-internal-testing/tiny-GPT2LMHeadModel", + "trl-internal-testing/tiny-GPTNeoXForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-LlamaForCausalLM-3", + "trl-internal-testing/tiny-MistralForCausalLM-0.1", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + "trl-internal-testing/tiny-OPTForCausalLM", + "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", +] + +ALL_SEQ2SEQ_MODELS = [ + "trl-internal-testing/tiny-T5ForConditionalGeneration", + "trl-internal-testing/tiny-BartModel", +] + + +class TestBatchGeneration(TrlTestCase): + def setup_method(self): + # Initialize the tokenizer + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32").to(self.device) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + self.generation_config = GenerationConfig( + max_new_tokens=128, + temperature=0.5, + do_sample=True, + top_k=0, + pad_token_id=self.tokenizer.pad_token_id, + ) + + # Example input + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + self.examples = dataset["messages"] + self.mini_batch_size = 3 + + def test_mini_batch_generation(self): + batch = [ + self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False) + for example in self.examples + ] + queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"].to(self.device) + bs, context_length = queries.shape + + query_responses, logits = batch_generation( + self.model, queries, self.mini_batch_size, self.tokenizer.pad_token_id, self.generation_config + ) + + max_length_query = query_responses.shape[1] + max_length_logits = max_length_query - context_length + + assert max_length_query > context_length + assert query_responses.shape == (bs, max_length_query) + assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size) + + def test_single_batch_generation(self): + batch = [ + self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False) + for example in self.examples + ] + queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"].to(self.device) + bs, context_length = queries.shape + + query_responses, logits = batch_generation( + self.model, queries, bs, self.tokenizer.pad_token_id, self.generation_config + ) + + max_length_query = query_responses.shape[1] + max_length_logits = max_length_query - context_length + + assert max_length_query > context_length + assert query_responses.shape == (bs, max_length_query) + assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size) + + +class BaseTester: + class VHeadModelTester(TrlTestCase): + all_model_names = None + trl_model_class = None + transformers_model_class = None + + def setup_method(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + def test_value_head(self): + r""" + Test if the v-head is added to the model successfully + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + assert hasattr(model, "v_head") + + def test_value_head_shape(self): + r""" + Test if the v-head has the correct shape + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + assert model.v_head.summary.weight.shape[0] == 1 + + def test_value_head_init_random(self): + r""" + Test if the v-head has been randomly initialized. We can check that by making sure the bias is different + than zeros by default. + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)) + + def test_value_head_not_str(self): + r""" + Test if the v-head is added to the model successfully, by passing a non `PretrainedModel` as an argument to + `from_pretrained`. + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + model = self.trl_model_class.from_pretrained(pretrained_model) + assert hasattr(model, "v_head") + + def test_from_save_trl(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights, including the + additional modules (e.g. v_head) + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + model.save_pretrained(self.tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(self.tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + torch.testing.assert_close(model_from_save.state_dict()[key], model.state_dict()[key]) + + def test_from_save_trl_sharded(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights - sharded case + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + model.save_pretrained(self.tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(self.tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + torch.testing.assert_close(model_from_save.state_dict()[key], model.state_dict()[key]) + + def test_from_save_transformers_sharded(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights - sharded case + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + trl_model.save_pretrained(self.tmp_dir, max_shard_size="1MB") + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained( + self.tmp_dir + ) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + torch.testing.assert_close( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + + def test_from_save_transformers(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights. We override the test + of the super class to check if the weights are the same. + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + trl_model.save_pretrained(self.tmp_dir) + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained( + self.tmp_dir + ) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + torch.testing.assert_close( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + + # Check if the trl model has the same keys as the transformers model + # except the v_head + for key in trl_model.state_dict(): + if "v_head" not in key: + assert key in transformers_model.state_dict() + # check if the weights are the same + torch.testing.assert_close(trl_model.state_dict()[key], transformers_model.state_dict()[key]) + + # check if they have the same modules + assert set(transformers_model_from_save.state_dict().keys()) == set( + transformers_model.state_dict().keys() + ) + + +class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_CAUSAL_LM_MODELS + trl_model_class = AutoModelForCausalLMWithValueHead + transformers_model_class = AutoModelForCausalLM + + def teardown_method(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name).to(self.device) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device) + outputs = model(input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + assert len(outputs) == EXPECTED_OUTPUT_SIZE + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + assert model.v_head.dropout.p == 0.5 + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + assert model.v_head.dropout.p == 0.5 + + @pytest.mark.parametrize("model_name", ALL_CAUSAL_LM_MODELS) + def test_generate(self, model_name): + r""" + Test if `generate` works for every model + """ + generation_config = GenerationConfig(max_new_tokens=9) + model = self.trl_model_class.from_pretrained(model_name).to(self.device) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device) + + # Just check if the generation works + _ = model.generate(input_ids, generation_config=generation_config) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed. Here we check that loading a model in half precision + works as expected, i.e. the weights of the `pretrained_model` attribute is loaded in half precision and you can + run a dummy forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device) + + lm_head_namings = ["lm_head", "embed_out", "output_layer"] + + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), ( + "Can't test the model because it doesn't have any of the expected lm_head namings" + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16 + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device) + + # check dummy forward pass works in half precision + _ = trl_model(dummy_input) + + @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo") + # check all keys + assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() + + for name, param in model.state_dict().items(): + ( + torch.testing.assert_close(param, model_from_pretrained.state_dict()[name]), + (f"Parameter {name} is not the same after push_to_hub and from_pretrained"), + ) + + +class TestSeq2SeqValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_SEQ2SEQ_MODELS + trl_model_class = AutoModelForSeq2SeqLMWithValueHead + transformers_model_class = AutoModelForSeq2SeqLM + + def teardown_method(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name).to(self.device) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device) + outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + assert len(outputs) == EXPECTED_OUTPUT_SIZE + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + assert model.v_head.dropout.p == 0.5 + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + assert model.v_head.dropout.p == 0.5 + + @pytest.mark.parametrize("model_name", ALL_SEQ2SEQ_MODELS) + def test_generate(self, model_name): + r""" + Test if `generate` works for every model + """ + generation_config = GenerationConfig(max_new_tokens=9) + model = self.trl_model_class.from_pretrained(model_name).to(self.device) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], device=self.device) + + # Just check if the generation works + _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config) + + @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo") + # check all keys + assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() + + for name, param in model.state_dict().items(): + ( + torch.testing.assert_close(param, model_from_pretrained.state_dict()[name]), + (f"Parameter {name} is not the same after push_to_hub and from_pretrained"), + ) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed. Here we check that loading a model in half precision + works as expected, i.e. the weights of the `pretrained_model` attribute is loaded in half precision and you can + run a dummy forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, dtype=torch.bfloat16).to(self.device) + + lm_head_namings = self.trl_model_class.lm_head_namings + + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16 + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device) + + # check dummy forward pass works in half precision + _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input) + + +@require_peft +class TestPeftModel(TrlTestCase): + def setup_method(self): + self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + def test_create_peft_model(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + def test_peft_requires_grad(self): + r""" + Check that the value head of the returned model has requires_grad=True. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the value head has requires_grad=True + assert model.v_head.summary.weight.requires_grad + + def test_check_peft_model_nb_trainable_params(self): + r""" + Check that the number of trainable parameters is correct. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + assert nb_trainable_params == 905 + + # Check that the number of trainable param for the non-peft model is correct + non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) + nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad) + assert nb_trainable_params == 2428641 + + def test_create_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + assert nb_trainable_params == 905 + + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + assert nb_trainable_params == 905 + + @require_bitsandbytes + @require_torch_gpu_if_bnb_not_multi_backend_enabled + def test_create_bnb_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + from bitsandbytes.nn import Linear8bitLt + from transformers import BitsAndBytesConfig + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, + peft_config=self.lora_config, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + assert nb_trainable_params == 905 + assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + + causal_lm_model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto" + ) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + assert nb_trainable_params == 905 + assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + + def test_save_pretrained_peft(self): + r""" + Check that the model can be saved and loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + model.save_pretrained(self.tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), ( + f"{self.tmp_dir}/adapter_model.safetensors does not exist" + ) + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), ( + f"{self.tmp_dir}/adapter_config.json does not exist" + ) + + # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights + assert os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist" + + # check that only keys that starts with `v_head` are in the dict + maybe_v_head = torch.load(f"{self.tmp_dir}/pytorch_model.bin", weights_only=True) + assert all(k.startswith("v_head") for k in maybe_v_head.keys()), ( + f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`" + ) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters(), strict=True): + torch.testing.assert_close(p1[1], p2[1]), f"{p1[0]} != {p2[0]}" + + def test_load_pretrained_peft(self): + r""" + Check that the model saved with peft class interface can be loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + pretrained_model.save_pretrained(self.tmp_dir) + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), ( + f"{self.tmp_dir}/adapter_model.safetensors does not exist" + ) + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), ( + f"{self.tmp_dir}/adapter_config.json does not exist" + ) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters(), strict=True): + if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]: + torch.testing.assert_close(p1[1], p2[1]), f"{p1[0]} != {p2[0]}" + + def test_continue_training_peft_model(self): + r""" + Load peft and checks that it can continue training. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + pretrained_model.save_pretrained(self.tmp_dir) + # set is_trainable to True + model = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir, is_trainable=True) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + assert nb_trainable_params == 905 + + +class TestCore(TrlTestCase): + """ + A wrapper class for testing core utils functions + """ + + def setup_method(self): + self.test_input = torch.Tensor([1, 2, 3, 4]) + self.test_mask = torch.Tensor([0, 1, 1, 0]) + self.test_input_unmasked = self.test_input[1:3] + + def test_masked_mean(self): + assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask) + + def test_masked_var(self): + assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask) + + def test_masked_whiten(self): + def whiten(values: torch.Tensor) -> torch.Tensor: + mean, var = torch.mean(values), torch.var(values) + return (values - mean) * torch.rsqrt(var + 1e-8) + + whiten_unmasked = whiten(self.test_input_unmasked) + whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] + diffs = (whiten_unmasked - whiten_masked).sum() + assert abs(diffs.item()) < 0.00001 + + +class TestPPOTrainer(TrlTestCase): + def setup_method(self): + # Set up the models and tokenizer using the test model + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left") + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + # Add reward and value models as in ppo.py + reward_model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + self.value_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1) + + # Load dataset + raw_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + def tokenize(example, tokenizer): + tokenized = tokenizer(text=example["prompt"]) + if tokenizer.eos_token_id is not None and tokenized["input_ids"][-1] != tokenizer.eos_token_id: + tokenized["input_ids"] = tokenized["input_ids"] + [tokenizer.eos_token_id] + tokenized["attention_mask"] = tokenized["attention_mask"] + [1] + return tokenized + + self.raw_dataset = raw_dataset.map(tokenize, fn_kwargs={"tokenizer": self.tokenizer}, remove_columns="prompt") + + def test_basic_training(self): + """Test basic PPO training configuration and verify model updates.""" + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args similar to example script + training_args = PPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + num_ppo_epochs=2, # Decrease number of PPO epochs to speed up test + report_to="none", + ) + + # Create trainer + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + ) + + # Train + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if not torch.allclose(initial_critic_weights[name], param.to("cpu")): + critic_weights_updated = True + break + + # Check if policy weights have been updated + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if not torch.allclose(initial_policy_weights[name], param.to("cpu")): + policy_weights_updated = True + break + + assert critic_weights_updated, "Critic weights were not updated during training" + assert policy_weights_updated, "Policy weights were not updated during training" + + @require_peft + def test_peft_training(self): + """Test PPO training with PEFT configuration and verify model updates.""" + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args + training_args = PPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + num_ppo_epochs=2, # Decrease number of PPO epochs to speed up test + report_to="none", + ) + + # Configure PEFT + peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # Create trainer with PEFT + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=None, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + peft_config=peft_config, + ) + + # Train + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if name in initial_critic_weights and not torch.allclose(initial_critic_weights[name], param.to("cpu")): + critic_weights_updated = True + break + + # Check if policy weights have been updated - for PEFT we check the LoRA weights + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if "lora" in name.lower() and param.requires_grad: # Only check LoRA weights + # New weights should be non-zero if they've been updated + if not torch.allclose(param, torch.zeros_like(param)): + policy_weights_updated = True + break + + assert critic_weights_updated, "Critic weights were not updated during training" + assert policy_weights_updated, "Policy LoRA weights were not updated during training" diff --git a/ICL/RL/trl_source/tests/experimental/test_prm_trainer.py b/ICL/RL/trl_source/tests/experimental/test_prm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..28c5adcbaa5b9eea28949ea77ecd174ec0905218 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_prm_trainer.py @@ -0,0 +1,376 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +from datasets import Dataset, load_dataset +from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase +from transformers.utils import is_peft_available + +from trl.experimental.prm import PRMConfig, PRMTrainer +from trl.experimental.prm.prm_trainer import compute_accuracy + +from ..testing_utils import TrlTestCase, require_peft + + +if is_peft_available(): + from peft import LoraConfig, TaskType + + +class TestComputeAccuracy(TrlTestCase): + def test_token_classification_task(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[0, 1], [1, 0]]), + ) + expected_accuracy = 0.5 # 2 matches, 2 mismatches + result = compute_accuracy(eval_pred) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 + + def test_token_classification_task_with_ignored_tokens_0(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[1, 0], [1, -100]]), + ) + expected_accuracy = 1.0 # All non-ignored tokens match + result = compute_accuracy(eval_pred) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 + + def test_token_classification_task_with_ignored_tokens_1(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[1, 1], [0, -100]]), + ) + expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored + result = compute_accuracy(eval_pred) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 + + def test_rewards_comparison_task(self, caplog): + eval_pred = ( + np.array( + [ + [0.9, 0.1], # Batch 1 + [0.6, 0.4], # Batch 2 + [0.5, 0.5], # Batch 3 (equal) + ] + ), + np.array([0, 1, 1]), + ) + expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored) + + with caplog.at_level("WARNING", logger="trl.trainer.utils"): + result = compute_accuracy(eval_pred) + + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 + expected_warning = ( + "There are 1 out of 3 instances where the predictions for both options are equal. " + "These instances are ignored in the accuracy computation." + ) + assert expected_warning in caplog.text + + +class TestTokenizeRow(TrlTestCase): + def setup_method(self): + # Set up the mock tokenizer with specific behaviors + self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.tokenizer.bos_token_id = 0 + self.tokenizer.eos_token_id = 2 + + def mock_encode(text, add_special_tokens): + token_map = { + "Which number is larger, 9.8 or 9.11?": [465, 6766, 318, 298], + "11 is greater than 8.": [4, 322, 12], + "Hence, 9.11 > 9.8.": [4995, 11, 22], + "\n": [1030], + "\n\n": [1030, 1030], + } + + return token_map[text] + + def mock_tokenizer_call(text, add_special_tokens): + return {"input_ids": mock_encode(text, add_special_tokens)} + + self.tokenizer.encode.side_effect = mock_encode + self.tokenizer.side_effect = mock_tokenizer_call + + def test_tokenize_row_no_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with no truncation + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + } + + def test_tokenize_row_train_on_last_step_only(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_completion_length=None, + train_on_last_step_only=True, + is_eval=False, + ) + + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + } + + def test_tokenize_row_completion_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_completion_length=6, + train_on_last_step_only=False, + is_eval=False, + ) + + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], + } + + def test_tokenize_row_prompt_completion_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the prompt and completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=9, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], + } + + def test_tokenize_row_multi_token_separator(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method using multiple tokens as step_separator + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n\n", + max_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], + } + + +class TestPRMTrainer(TrlTestCase): + def setup_method(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForTokenClassification.from_pretrained(model_id, dtype="float32") + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + @pytest.mark.parametrize("train_on_last_step_only", [True, False]) + def test_train_full(self, train_on_last_step_only): + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig( + output_dir=self.tmp_dir, + report_to="none", + train_on_last_step_only=train_on_last_step_only, + ) + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + def test_train_full_pretokenized(self): + dummy_dataset = Dataset.from_dict( + { + "labels": [ + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + ], + "input_ids": [ + [46518, 374, 2664, 1091, 11, 1077, 752, 1744, 1112, 198, 27261, 13, 198], + [98923, 374, 2664, 1091, 11, 315, 3308, 11, 198, 17995, 13, 198, 1576, 31273, 12850, 13, 198], + [16374, 374, 2664, 1091, 1112, 1077, 594, 2506, 432, 6770, 11, 198, 6351, 13, 198], + [31137, 374, 2664, 1091, 979, 4362, 11, 198, 16965, 13, 198], + [31019, 374, 2664, 1091, 304, 3793, 315, 5944, 11, 198, 24034, 13, 198], + [98491, 374, 2664, 1091, 1112, 5310, 369, 91494, 13, 198], + [4418, 2897, 14579, 5310, 979, 3800, 1349, 432, 13, 198], + [20366, 5048, 7629, 944, 3281, 3322, 11, 7241, 1112, 198, 807, 1795, 279, 5601, 13, 198], + [15802, 14976, 487, 33327, 1045, 31787, 63443, 11, 198, 52400, 13, 198], + [13877, 1265, 2581, 1494, 49394, 11, 198, 7241, 20975, 91681, 13, 198], + [641, 279, 3579, 315, 71768, 11, 25066, 279, 61361, 311, 7942, 13, 198], + [7039, 374, 2664, 1091, 2937, 13, 198], + [26155, 374, 3545, 2664, 1091, 34933, 26537, 13, 198], + [2679, 279, 8129, 374, 4135, 311, 10339, 11, 432, 2578, 387, 264, 1661, 2884, 13, 198], + ], + } + ) + + training_args = PRMConfig(output_dir=self.tmp_dir, report_to="none") + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + + @require_peft + def test_train_lora(self): + peft_config = LoraConfig( + task_type=TaskType.TOKEN_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") + trainer = PRMTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any(t in n for t in trainable_params_name): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + assert trainer.state.log_history[(-1)]["train_loss"] is not None + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12) + + # Check that the non trainable parameters have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + torch.testing.assert_close(param, new_param, atol=1e-12, rtol=1e-12) + + def test_tags(self): + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig(output_dir=self.tmp_dir, report_to="none") + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + assert trainer.model.model_tags == trainer._tag_names diff --git a/ICL/RL/trl_source/tests/experimental/test_utils.py b/ICL/RL/trl_source/tests/experimental/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85bde671494c9b6e92af9d442954dcff376b6d2a --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_utils.py @@ -0,0 +1,107 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datasets import load_dataset +from transformers import AutoTokenizer + +from trl.experimental.utils import DataCollatorForChatML + +from ..testing_utils import TrlTestCase + + +class TestDataCollatorForChatML(TrlTestCase): + def setup_method(self): + # Initialize the tokenizer + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Define token IDs + self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1 + self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2 + # Token ID for "true", the last assistant's response in the example: + self.ignore_index = -100 + self.max_length = 1024 + self.messages_key = "messages" + + # Example input + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + self.examples = dataset.to_list() + + # Initialize the data collator + self.collator = DataCollatorForChatML( + tokenizer=self.tokenizer, + max_length=self.max_length, + ignore_index=self.ignore_index, + ) + + def test_data_collator_for_chatml(self): + # Process the data + data = self.collator(self.examples) + + # Verify basic shapes and types + assert "input_ids" in data + assert "attention_mask" in data + assert "labels" in data + assert "prompts" in data + assert "prompt_attention_mask" in data + + # Decode input_ids and labels for verification + input_ids = data["input_ids"][0].tolist() + labels = data["labels"][0].tolist() + prompt_only = data["prompts"][0].tolist() + + # Get the last assistant's response for comparison + last_message = self.examples[0][self.messages_key][-1] + assert last_message["role"] == "assistant", "Last message should be from assistant" + last_assistant_response = last_message["content"] + + # Verify that input_ids contain both prompt and response + decoded_input = self.tokenizer.decode(input_ids) + assert last_assistant_response in decoded_input, "Input should contain assistant's response" + + # Verify that prompts only contain the conversation up to the last response + decoded_prompt = self.tokenizer.decode(prompt_only) + assert last_assistant_response not in decoded_prompt, "Prompt should not contain assistant's response" + + # Verify labels are -100 for non-assistant parts + prompt_length = len(prompt_only) + assert all(label == self.ignore_index for label in labels[:prompt_length]), ( + "Labels should be ignore_index for prompt tokens" + ) + + # Verify labels match assistant response after prompt + # Add a filter to remove any trailing tokens after the first <|im_end|> + last_assistant_response_with_end = last_assistant_response + self.tokenizer.eos_token + last_assistant_response_tokens = self.tokenizer.encode( + last_assistant_response_with_end, add_special_tokens=False + ) + + response_labels = [] + for label in labels[prompt_length:]: + if label == self.ignore_index: + continue + response_labels.append(label) + if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"): + break + assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens" + + # Verify there isn't a generation prompt at the end + generation_prompt = "<|im_start|>assistant" + assert not decoded_input.strip().endswith(generation_prompt), ( + f"Input should not end with generation prompt '{generation_prompt}'" + ) + + assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens" diff --git a/ICL/RL/trl_source/tests/experimental/test_winrate_callback.py b/ICL/RL/trl_source/tests/experimental/test_winrate_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..5e137246b4bcb73a11a2b59583afd7c737f82a0b --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_winrate_callback.py @@ -0,0 +1,213 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments +from transformers.utils import is_peft_available + +from trl.experimental.judges import BasePairwiseJudge +from trl.experimental.winrate_callback import WinRateCallback + +from ..testing_utils import TrlTestCase, require_peft + + +if is_peft_available(): + from peft import LoraConfig + + +class HalfPairwiseJudge(BasePairwiseJudge): + """Naive pairwise judge that always returns [1, 0] for two prompts""" + + def judge(self, prompts, completions, shuffle_order=True, return_scores=False): + # just check that the batch size is 2 + assert len(prompts) == 2 + if return_scores: + return [0.3, 0.9] + return [1, 0] + + +class TrainerWithRefModel(Trainer): + # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional + # ref_model attribute + def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class): + super().__init__( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + ) + # Prepare ref_model like TRL trainers do (DPOTrainer, GRPOTrainer, etc.) + self.ref_model = self.accelerator.prepare_model(ref_model, evaluation_mode=True) + + +class TestWinRateCallback(TrlTestCase): + def setup_method(self): + self.model = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32" + ) + self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer.pad_token = self.tokenizer.eos_token + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + dataset["train"] = dataset["train"].select(range(8)) + self.expected_winrates = [ + {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0}, + {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2}, + {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4}, + {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6}, + {"eval_win_rate": 0.5, "epoch": 2.0, "step": 8}, + {"eval_win_rate": 0.5, "epoch": 2.5, "step": 10}, + {"eval_win_rate": 0.5, "epoch": 3.0, "step": 12}, + ] + + def tokenize_function(examples): + out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True) + out["labels"] = out["input_ids"].copy() + return out + + self.dataset = dataset.map(tokenize_function, batched=True) + + self.generation_config = GenerationConfig(max_length=32) + self.judge = HalfPairwiseJudge() + + def test_basic(self): + training_args = TrainingArguments( + output_dir=self.tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = TrainerWithRefModel( + model=self.model, + ref_model=self.ref_model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config + ) + trainer.add_callback(win_rate_callback) + trainer.train() + winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] + for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True): + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) + + def test_without_ref_model(self): + # Same as before, but without the ref_model attribute. It should use the model attribute instead + training_args = TrainingArguments( + output_dir=self.tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config + ) + trainer.add_callback(win_rate_callback) + trainer.train() + winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] + for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True): + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) + + def test_soft_judge(self): + """Test that the soft judge functionality works correctly""" + training_args = TrainingArguments( + output_dir=self.tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = TrainerWithRefModel( + model=self.model, + ref_model=self.ref_model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True + ) + trainer.add_callback(win_rate_callback) + trainer.train() + + # Expected values based on judge returning [0.3, 0.9] for each pair + expected_soft_winrates = [ + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12}, + ] + + winrate_history = [ + {k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]} + for h in trainer.state.log_history + if "eval_avg_win_prob" in h + ] + for history_row, expected_row in zip(winrate_history, expected_soft_winrates, strict=True): + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) + + @require_peft + def test_lora(self): + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + self.model.add_adapter(peft_config) + training_args = TrainingArguments( + output_dir=self.tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config + ) + trainer.add_callback(win_rate_callback) + trainer.train() + winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] + for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True): + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) diff --git a/ICL/RL/trl_source/tests/experimental/test_xpo_trainer.py b/ICL/RL/trl_source/tests/experimental/test_xpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f92ada9c7f8de9ab160a34ef01fae13934da99 --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/test_xpo_trainer.py @@ -0,0 +1,184 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.utils import is_peft_available + +from trl.experimental.xpo import XPOConfig, XPOTrainer + +from ..testing_utils import TrlTestCase, require_llm_blender, require_peft +from .testing_utils import RandomPairwiseJudge + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +@pytest.mark.low_priority +class TestXPOTrainer(TrlTestCase): + def setup_method(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype="float32") + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id, num_labels=1) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + def test_xpo_trainer_training(self, config_name): + training_args = XPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_funcs=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + training_args = XPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + training_args = XPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_funcs=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] + + @require_peft + def test_training_pre_pefted_model_implicit_ref(self): + lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + peft_model_instance = get_peft_model(self.model, lora_config) + + training_args = XPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=1, + max_steps=2, + learning_rate=5.0e-7, + eval_strategy="no", + report_to="none", + remove_unused_columns=False, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"] + + trainer = XPOTrainer( + model=peft_model_instance, + ref_model=None, + reward_funcs=self.reward_model, # Using reward_model to ensure _generate_completions is used as expected + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + ) + + trainer.train() + + assert "train_loss" in trainer.state.log_history[-1] + + @pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"]) + @require_llm_blender + def test_xpo_trainer_judge_training(self, config_name): + training_args = XPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + judge = RandomPairwiseJudge() + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + judge=judge, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + assert "train_loss" in trainer.state.log_history[-1] diff --git a/ICL/RL/trl_source/tests/experimental/testing_utils.py b/ICL/RL/trl_source/tests/experimental/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4caae79737d17c38be79d675c909155788dc6e7c --- /dev/null +++ b/ICL/RL/trl_source/tests/experimental/testing_utils.py @@ -0,0 +1,28 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random + +from trl.experimental.judges import BasePairwiseJudge + + +class RandomPairwiseJudge(BasePairwiseJudge): + """ + Random pairwise judge, for testing purposes. + """ + + def judge(self, prompts, completions, shuffle_order=True, return_scores=False): + if not return_scores: + return [random.randint(0, len(completion) - 1) for completion in completions] + else: + return [random.random() for _ in range(len(prompts))] diff --git a/ICL/RL/trl_source/trl/experimental/bco/bco_config.py b/ICL/RL/trl_source/trl/experimental/bco/bco_config.py new file mode 100644 index 0000000000000000000000000000000000000000..64f0779ef6a274a179eb8a21d3736a605490eb45 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/bco/bco_config.py @@ -0,0 +1,189 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from transformers import TrainingArguments + + +@dataclass +class BCOConfig(TrainingArguments): + r""" + Configuration class for the [`experimental.bco.BCOTrainer`]. + + This class includes only the parameters that are specific to BCO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_completion_length (`int`, *optional*): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from both the model and the reference model to W&B or Comet + during evaluation. + is_encoder_decoder (`bool`, *optional*): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute reference model log probabilities for training and evaluation datasets. This is + useful when training without the reference model to reduce the total GPU memory needed. + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model and + reference model from strings. + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + prompt_sample_size (`int`, *optional*, defaults to `1024`): + Number of prompts that are fed to density ratio classifier. + min_density_ratio (`float`, *optional*, defaults to `0.5`): + Minimum value of the density ratio. The estimated density ratio is clamped to this value. + max_density_ratio (`float`, *optional*, defaults to `10.0`): + Maximum value of the density ratio. The estimated density ratio is clamped to this value. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + logging_steps: float = field( + default=10, + metadata={ + "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, " + "will be interpreted as ratio of total training steps." + }, + ) + gradient_checkpointing: bool = field( + default=True, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) + bf16: bool | None = field( + default=None, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if " + "`fp16` is not set." + }, + ) + # Transformers 4.57.0 introduced a bug that caused the dtype of `lr_scheduler_kwargs` to be unparsable. This issue + # was fixed in https://github.com/huggingface/transformers/pull/41322 and released in 4.57.5. We add a temporary + # workaround here, which can be removed once we drop support for versions older than 4.57.5. + lr_scheduler_kwargs: dict | str | None = field( + default=None, + metadata={ + "help": "Additional parameters for the lr_scheduler, such as {'num_cycles': 1} for cosine with hard " + "restarts." + }, + ) + + max_length: int | None = field( + default=1024, + metadata={ + "help": "Maximum length of the sequences (prompt + completion) in the batch. " + "This argument is required if you want to use the default data collator." + }, + ) + max_completion_length: int | None = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the " + "default data collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. " + "Higher β means less deviation from the reference model." + }, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long. Possible values are " + "`keep_end` or `keep_start`. This argument is required if you want to use the " + "default data collator." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + generate_during_eval: bool = field( + default=False, + metadata={ + "help": "If `True`, generates and logs completions from both the model and the reference model " + "to W&B during evaluation." + }, + ) + is_encoder_decoder: bool | None = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the " + "`model` argument, you need to specify if the model returned by the callable is an " + "encoder-decoder model." + }, + ) + precompute_ref_log_probs: bool = field( + default=False, + metadata={ + "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " + "This is useful when training without the reference model to reduce the total GPU memory " + "needed." + }, + ) + model_init_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "model from a string." + }, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + prompt_sample_size: int = field( + default=1024, + metadata={"help": "Number of prompts that are fed to density ratio classifier."}, + ) + min_density_ratio: float = field( + default=0.5, + metadata={"help": "Minimum value of the density ratio. The estimated density ratio is clamped to this value."}, + ) + max_density_ratio: float = field( + default=10.0, + metadata={"help": "Maximum value of the density ratio. The estimated density ratio is clamped to this value."}, + ) + + def __post_init__(self): + self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 + + super().__post_init__() diff --git a/ICL/RL/trl_source/trl/experimental/bco/bco_trainer.py b/ICL/RL/trl_source/trl/experimental/bco/bco_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a2d3da651a2bc4cc5c87d4c53a5d329420ffc4 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/bco/bco_trainer.py @@ -0,0 +1,1468 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager, nullcontext +from operator import itemgetter +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal, Optional + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from accelerate import PartialState, logging +from accelerate.utils import tqdm +from datasets import Dataset +from packaging.version import Version +from torch import autocast +from torch.utils.data import DataLoader, SequentialSampler +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + TrainingArguments, + is_comet_available, + is_sklearn_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalLoopOutput, has_length +from transformers.utils import is_peft_available + +from ...data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset +from ...import_utils import is_joblib_available +from ...models.utils import create_reference_model, peft_module_casting_to_bf16, prepare_deepspeed +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + RunningMoments, + disable_dropout_in_model, + log_table_to_comet_experiment, + pad_to_length, + selective_log_softmax, +) +from ..utils import DPODataCollatorWithPadding +from .bco_config import BCOConfig + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + +if is_sklearn_available(): + from sklearn.linear_model import LogisticRegression + +if is_joblib_available(): + import joblib + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + +logger = logging.get_logger(__name__) + +RUNNING_NAME = "running.json" +CLF_NAME = "clf.pkl" + + +def _tokenize( + batch: dict[str, list[Any]], + tokenizer: "PreTrainedTokenizer", + embedding_tokenizer: Optional["PreTrainedTokenizer"] = None, +) -> dict[str, list[Any]]: + """Tokenize a batch from a BCO specific dataset.""" + prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) + prompt_input_ids = prompt_tokenized["input_ids"] + prompt_attention_mask = prompt_tokenized["attention_mask"] + prompt_and_completion = [ + prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True) + ] + full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False) + full_input_ids = full_tokenized["input_ids"] + full_attention_mask = full_tokenized["attention_mask"] + + answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)] + answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)] + # Prepare input tokens for token by token comparison + full_input_ids = [np.array(f) for f in full_input_ids] + for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True): + if len(full) != len(concat): + raise ValueError( + "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length." + ) + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = [len(p) for p in prompt_input_ids] + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)): + if not np.array_equal(p, f[:r]): + response_token_ids_start_idx[idx] -= 1 + + prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)] + prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)] + + for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True): + if len(p) != len(m): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)] + answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)] + + output = dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + answer_input_ids=answer_input_ids, + answer_attention_mask=answer_attention_mask, + ) + + if embedding_tokenizer is not None: + embedding_tokenized = embedding_tokenizer(batch["prompt"], add_special_tokens=False) + + output.update( + { + "embedding_input_ids": embedding_tokenized["input_ids"], + "embedding_attention_mask": embedding_tokenized["attention_mask"], + } + ) + + return output + + +def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict: + """Process tokens of a BCO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + + completion responses is/are too long. First we truncate the prompt; if we're still too long, we truncate the + completion. + + We also create the labels for the completion responses, which are of length equal to the sum of the length of the + prompt and the completion response, with `-100` for the prompt tokens. + """ + prompt = example["prompt"] + completion = example["completion"] + + batch = { + f"{kwargs['prefix']}prompt": prompt, + f"{kwargs['prefix']}completion": completion, + f"{kwargs['prefix']}label": example["label"], + } + + if not kwargs["is_encoder_decoder"]: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + + if not isinstance(completion, str): + raise ValueError(f"completion should be an str but got {type(completion)}") + + # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer + all_tokens = { + "prompt_input_ids": example["prompt_input_ids"], + "prompt_attention_mask": example["prompt_attention_mask"], + "answer_input_ids": example["answer_input_ids"], + "answer_attention_mask": example["answer_attention_mask"], + } + + # calculate max length by checking if BOS/EOS is already there + max_length = kwargs["max_length"] + bos_token_id = kwargs["tokenizer"].bos_token_id + eos_token_id = kwargs["tokenizer"].eos_token_id + if bos_token_id != all_tokens["prompt_input_ids"][0]: + max_length -= 1 + if eos_token_id != all_tokens["answer_input_ids"][-1]: + max_length -= 1 + + # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the response + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["answer_input_ids", "answer_attention_mask"]: + all_tokens[k] = all_tokens[k][: max_length - len(all_tokens["prompt_input_ids"])] + + # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens + batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = ( + all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] + ) + batch[f"{kwargs['prefix']}completion_attention_mask"] = ( + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] + ) + + # add BOS, which affects both prompt and the full completion + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + # add EOS, which affects only the full completion + if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: + batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ + eos_token_id + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + [1] + + batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] + batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [ + -100 + ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"]) + else: + completion_tokens = kwargs["tokenizer"]( + completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True + ) + prompt_tokens = kwargs["tokenizer"](prompt, add_special_tokens=True) + + batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"] + + batch[f"{kwargs['prefix']}completion_labels"] = completion_tokens["input_ids"] + batch[f"{kwargs['prefix']}completion_attention_mask"] = completion_tokens["attention_mask"] + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch[f"{kwargs['prefix']}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["completion_labels"]) + ) + + return batch + + +class BCOTrainer(BaseTrainer): + r""" + Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an [`~transformers.AutoModelForSequenceClassification`]. + ref_model ([`~transformers.PreTrainedModel`]): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation + and loss. If no reference model is provided, the trainer will create a reference model with the same + architecture as the model to be optimized. + args ([`experimental.bco.BCOConfig`]): + The arguments to use for training. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + data_collator ([`~transformers.DataCollator`], *optional*): + The data collator to use for training. If None is specified, the default data collator + ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to + metric values. + model_adapter_name (`str`, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + """ + + _tag_names = ["trl", "bco"] + _name = "BCO" + _paper = { + "title": "Binary Classifier Optimization for Large Language Model Alignment", + "id": "2404.04656", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{jung2024binary, + title = {{Binary Classifier Optimization for Large Language Model Alignment}}, + author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On}, + year = 2024, + eprint = {arXiv:2404.04656} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str = None, + ref_model: PreTrainedModel | nn.Module | str | None = None, + args: BCOConfig = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + data_collator: DataCollator | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, + model_adapter_name: str | None = None, + ref_adapter_name: str | None = None, + embedding_func: Callable | None = None, + embedding_tokenizer: PreTrainedTokenizerBase | None = None, + ): + if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()): + raise ImportError( + "BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`." + ) + + if type(args) is TrainingArguments: + raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.") + + if not isinstance(model, str) and model is not None and ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + dtype = model_init_kwargs.get("dtype", "auto") + if dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(dtype, str) and dtype != "auto": + dtype = getattr(torch, dtype) + if dtype != "auto" and not isinstance(dtype, torch.dtype): + raise ValueError( + f"Invalid `dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {dtype}." + ) + model_init_kwargs["dtype"] = dtype + model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto") + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " + "merge and unload the existing adapter, save the resulting base model, and then pass that base " + "model along with the new `peft_config` to the trainer." + ) + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or args.precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if processing_class is None: + raise ValueError( + "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" + ) + if args.max_length is None: + logger.warning( + "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. " + "It will be set to `512` by default, but you should do it yourself in the future.", + ) + max_length = 512 + if args.max_length is not None: + max_length = args.max_length + + max_completion_length = None + if args.max_completion_length is None and self.is_encoder_decoder: + logger.warning( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + ) + max_completion_length = 128 + if args.max_completion_length is not None and self.is_encoder_decoder: + max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + logger.warning( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig" + " we have set it for you, but you should do it yourself in the future.", + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.truncation_mode = args.truncation_mode + self.max_completion_length = max_completion_length + self.precompute_ref_log_probs = args.precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + # metric + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # BCO parameter + self.beta = args.beta + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + logger.warning( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + ) + + # Underlying Distribution Matching argument + self.embedding_func = embedding_func + self.embedding_tokenizer = embedding_tokenizer + + with PartialState().main_process_first(): + # Extract the prompt if needed + train_dataset = train_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" + ) + # Unpair the dataset if needed + train_dataset = maybe_unpair_preference_dataset( + train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" + ) + # Apply the chat template if needed + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + if eval_dataset is not None: + # Extract the prompt if needed + eval_dataset = eval_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" + ) + # Unpair the dataset if needed + eval_dataset = maybe_unpair_preference_dataset( + eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" + ) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + + # Tokenize and prepare the training datasets + train_dataset = train_dataset.map( + _tokenize, + batched=True, + fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}, + num_proc=args.dataset_num_proc, + desc="Tokenizing train dataset", + ) + + # Prepare the datasets + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": self.is_encoder_decoder, + "tokenizer": processing_class, + "max_length": self.max_length, + "truncation_mode": self.truncation_mode, + "max_completion_length": self.max_completion_length, + } + train_dataset = train_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized train dataset", + ) + + if eval_dataset is not None: + # Tokenize + eval_dataset = eval_dataset.map( + _tokenize, + fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}, + batched=True, + num_proc=args.dataset_num_proc, + desc="Tokenizing eval dataset", + ) + + # Process + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": self.is_encoder_decoder, + "tokenizer": processing_class, + "max_length": self.max_length, + "truncation_mode": self.truncation_mode, + "max_completion_length": self.max_completion_length, + } + eval_dataset = eval_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized eval dataset", + ) + + desirable = train_dataset.filter( + lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples" + ) + undesirable = train_dataset.filter( + lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples" + ) + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + self.running = RunningMoments(accelerator=self.accelerator) + + if self.embedding_func is None or args.resume_from_checkpoint: + return + + chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size) + rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size) + + embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0) + labels = torch.cat( + (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0 + ) + + self.clf = LogisticRegression(class_weight="balanced").fit( + embeddings.cpu().float().numpy(), labels.cpu().numpy() + ) + chosen_mean = self.clf.score( + chosen_embeddings.cpu().float().numpy(), torch.ones_like(chosen_embeddings[:, 0]).cpu().numpy() + ) + rejected_mean = self.clf.score( + rejected_embeddings.cpu().float().numpy(), torch.zeros_like(rejected_embeddings[:, 0]).cpu().numpy() + ) + logger.info(f"UDM classifier training scores: chosen: {chosen_mean}, rejected: {rejected_mean}") + + @property + def match_underlying_distribution(self): + return self.embedding_func is not None and self.embedding_tokenizer is not None + + def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor: + """ + Calculates the probability if the given prompt embedding is from desirable dataset. This function calculates + the probability in the process and ensemble across processes. + """ + dtype = prompt_embeddings.dtype + device = prompt_embeddings.device + rank = self.accelerator.process_index + + padded_prompt_embeddings = self.accelerator.pad_across_processes( + prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id + ) + sample_size = padded_prompt_embeddings.shape[0] + nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id + prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings) + + # cannot predict for all empty values + if prompt_embeddings.shape[0] == 0: + return torch.tensor([], device=device, dtype=dtype) + + prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1] + prob = torch.as_tensor(prob, dtype=dtype, device=device) + prob = self.accelerator.reduce(prob, reduction="mean") + + prob = prob[sample_size * rank : sample_size * (rank + 1)] + prob = prob[nonzero] + + return prob + + def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor: + """ + Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id and applies self.embedding_func + """ + input_ids = torch.where( + input_ids == self.processing_class.pad_token_id, + self.embedding_tokenizer.pad_token_id, + input_ids, + ) + + with torch.no_grad(): + embeddings = self.embedding_func( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + return embeddings + + def _get_prompt_embeddings( + self, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor]: + """Extract embeddings from frozen embedding model""" + + if not self.match_underlying_distribution: + return None, None + + embeddings = self._vectorize_prompt( + input_ids=batch["embedding_input_ids"], + attention_mask=batch["embedding_attention_mask"], + ) + + labels = torch.tensor(batch["label"], dtype=torch.bool, device=embeddings.device) + chosen_idx = torch.where(labels)[0] + rejected_idx = torch.where(~labels)[0] + + chosen_embeddings = embeddings[chosen_idx, ...] + rejected_embeddings = embeddings[rejected_idx, ...] + + return (chosen_embeddings, rejected_embeddings) + + def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor: + """ + Sample instances from dataset and get prompt embeddings. Used for density ratio classifier training. + """ + n_samples = min(len(dataset), sample_size) + rand_indices = np.random.choice(len(dataset), size=(n_samples,)) + + embedding_dataset = dataset.select(rand_indices) + + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params)) + + with torch.no_grad(): + all_embeddings = torch.empty(0) + for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"): + embeddings = self._vectorize_prompt( + input_ids=padded_batch["embedding_input_ids"], + attention_mask=padded_batch["embedding_attention_mask"], + ) + embeddings = self.accelerator.gather_for_metrics(embeddings) + all_embeddings = torch.cat((all_embeddings, embeddings.cpu())) + + return all_embeddings + + def _save_optimizer_and_scheduler(self, output_dir): + output_dir = output_dir if output_dir is not None else self.args.output_dir + super()._save_optimizer_and_scheduler(output_dir) + + if self.accelerator.is_main_process: + # When saving optimizer and scheduler to checkpoint, save also the running delta object. + self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME)) + + if self.match_underlying_distribution: + joblib.dump(self.clf, os.path.join(output_dir, CLF_NAME), compress=True) + + def _load_optimizer_and_scheduler(self, checkpoint): + if checkpoint is None: + logger.warning_once(f"Missing Checkpoint {checkpoint}") + return + + super()._load_optimizer_and_scheduler(checkpoint) + + # when loading optimizer and scheduler from checkpoint, also load the running delta object. + running_file = os.path.join(checkpoint, RUNNING_NAME) + if os.path.isfile(running_file): + self.running = RunningMoments.load_from_json(self.accelerator, running_file) + + if self.match_underlying_distribution: + clf_file = os.path.join(checkpoint, CLF_NAME) + if os.path.isfile(clf_file): + self.clf = joblib.load(clf_file) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + reference_completion_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_completion_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + self.train_dataset = self.train_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_eval_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + reference_completion_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_completion_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + eval_dataset = eval_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + + # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + def compute_reference_log_probs(self, padded_batch: dict) -> dict: + """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset.""" + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + if self.is_encoder_decoder: + completion_logits = self.model( + padded_batch["prompt_input_ids"], + attention_mask=padded_batch["prompt_attention_mask"], + decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), + labels=padded_batch["completion_labels"], + ).logits + + else: + completion_logits = self.model( + padded_batch["completion_input_ids"], + attention_mask=padded_batch["completion_attention_mask"], + ).logits + + else: + if self.is_encoder_decoder: + completion_logits = self.ref_model( + padded_batch["prompt_input_ids"], + attention_mask=padded_batch["prompt_attention_mask"], + decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), + labels=padded_batch["completion_labels"], + ).logits + + else: + completion_logits = self.ref_model( + padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] + ).logits + + completion_logps = self.get_batch_logps( + completion_logits, + padded_batch["completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + ) + + return completion_logps + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: + Labels for which to compute the log probabilities. Label tokens with a value of `-100` are ignored. + Shape: (batch_size, sequence_length) + average_log_prob: + If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the + log probabilities of the (non-masked) tokens. + is_encoder_decoder: + Whether the model is an encoder-decoder model. If True, the labels are not shifted, and the logits are + assumed to already be aligned with the labels. If False, the labels are shifted to the right by one + position, and the logits are assumed to be aligned with the shifted labels. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the + given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + else: + # Fixes end-dec RuntimeError + labels = labels.clone() + + loss_mask = labels != -100 + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == -100] = 0 + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def forward( + self, model: nn.Module, batch: dict[str, list | torch.LongTensor] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + model_kwargs = ( + { + "labels": batch["completion_labels"], + "decoder_input_ids": batch.get("completion_decoder_input_ids"), + } + if self.is_encoder_decoder + else {} + ) + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + **model_kwargs, + ) + completion_logits = outputs.logits + + completion_logps = self.get_batch_logps( + completion_logits, + batch["completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if completion_logps.shape[0] != len(batch["label"]): + raise ValueError( + "There is a mismatch between the number of examples in this batch and the number of " + "examples for which an output sequence was predicted." + ) + + chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] + + chosen_logps = completion_logps[chosen_idx, ...] + rejected_logps = completion_logps[rejected_idx, ...] + + chosen_logits = completion_logits[chosen_idx, ...] + rejected_logits = completion_logits[rejected_idx, ...] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss) + else: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor: + prob_desirable = self._get_chosen_prob(rejected_embeddings) + min_ratio = self.args.min_density_ratio + max_ratio = self.args.max_density_ratio + + weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio) + + return weight + + def bco_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + chosen_embeddings: torch.FloatTensor | None, + rejected_embeddings: torch.FloatTensor | None, + do_train: bool = True, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the BCO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: + Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) + policy_rejected_logps: + Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) + reference_chosen_logps: + Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) + reference_rejected_logps: + Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in + batch_size,) + chosen_embeddings: embeddings of desirable prompts + rejected_embeddings: embeddings of undesirable prompts + do_train: whether to update the running delta value. Default is True. + + Returns: + A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta). The losses tensor contains the + BCO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards + for the chosen and rejected responses, respectively. The delta value contains the moving average of all + implicit rewards. + """ + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + chosen_rewards = self.beta * chosen_logratios + + rejected_logratios = policy_rejected_logps - reference_rejected_logps + rejected_rewards = self.beta * rejected_logratios + + if do_train: + self.running.update(torch.cat((chosen_rewards, rejected_rewards), 0).detach()) + delta = torch.as_tensor(self.running.mean, device=chosen_rewards.device) + + chosen_losses = -F.logsigmoid(chosen_rewards - delta) + rejected_losses = -F.logsigmoid(-(rejected_rewards - delta)) + + if self.match_underlying_distribution: + chosen_weight = torch.ones_like(chosen_losses) + rejected_weight = self._get_udm_weight(rejected_embeddings) + + losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0) + else: + losses = torch.cat((chosen_losses, rejected_losses), dim=0) + + return losses, chosen_rewards, rejected_rewards, delta + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, list | torch.LongTensor], + do_train: bool = True, + ): + """Compute the BCO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + + forward_output = self.forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = forward_output[:4] + if self.aux_loss_enabled: + aux_loss = forward_output[4] + + # if reference_logps in batch use them, otherwise use the reference model + if "reference_logps" in batch: + chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] + + reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] + reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.forward(self.model, batch)[:4] + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.forward(self.ref_model, batch)[:4] + + chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch) + + losses, chosen_rewards, rejected_rewards, delta = self.bco_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_embeddings, + rejected_embeddings, + do_train=do_train, + ) + metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item() + + num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) + num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) + + all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() + + if all_num_chosen > 0: + metrics["rewards/chosen_sum"] = ( + self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() + ) + metrics["logps/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() + ) + metrics["count/chosen"] = all_num_chosen + + if all_num_rejected > 0: + metrics["rewards/rejected_sum"] = ( + self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() + ) + metrics["logps/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() + ) + metrics["count/rejected"] = all_num_rejected + + loss = losses.nanmean() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + return_outputs=False, + num_items_in_batch=None, + ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs) + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None: + if dataset is None: + dataset = self.train_dataset + if dataset is None or not has_length(dataset): + return None + return SequentialSampler(dataset) + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) + reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: PreTrainedModel | nn.Module, + inputs: dict[str, torch.Tensor | Any], + prediction_loss_only: bool, + ignore_keys: list[str] | None = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, do_train=False) + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = {} + if "logits/chosen_sum" in metrics: + logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"] + if "logits/rejected_sum" in metrics: + logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"] + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: bool | None = None, + ignore_keys: list[str] | None = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. Prediction/evaluation loop, shared by + `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + target_labels = torch.tensor(random_batch["label"], dtype=torch.bool, device=self.accelerator.device) + target_indices = torch.where(~target_labels)[0] + target_batch = { + "prompt_input_ids": random_batch["prompt_input_ids"][target_indices], + "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indices], + "prompt": itemgetter(*target_indices)(random_batch["prompt"]), + } + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy", "Ref Model"], + data=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True + ) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float`, *optional*): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # train metrics should have no prefix, eval should have 'eval_' + prefix = "eval_" if train_eval == "eval" else "" + # accumulate average metrics from sums and lengths + for split in ["chosen", "rejected"]: + if f"count/{split}" in self._stored_metrics[train_eval]: + count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() + for metric in ["rewards", "logps", "logits"]: + logs[f"{prefix}{metric}/{split}"] = ( + torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() + / count_sum + ) + # delete obsolete metric + del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] + del self._stored_metrics[train_eval][f"count/{split}"] + # calculate reward margin + if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: + logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/callback.py b/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..d030be9fb51c428fe82e6391c2ec64a05345d5d5 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/bema_for_ref_model/callback.py @@ -0,0 +1,221 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from transformers import PreTrainedModel, TrainerControl, TrainerState, TrainingArguments +from transformers.trainer_callback import CallbackHandler + +from ...trainer.callbacks import BEMACallback as _BEMACallback + + +# Logger for module-level logging +logger = logging.getLogger(__name__) + + +class CallbackHandlerWithRefModel(CallbackHandler): + """ + A [`~transformers.CallbackHandler`] that supports passing a reference model to callbacks. + """ + + def __init__(self, callbacks, model, ref_model, processing_class, optimizer, lr_scheduler): + super().__init__(callbacks, model, processing_class, optimizer, lr_scheduler) + self.ref_model = ref_model + + # Copied from CallbackHandler.call_event with the addition of `ref_model` to the callback call. + def call_event(self, event, args, state, control, **kwargs): + for callback in self.callbacks: + result = getattr(callback, event)( + args, + state, + control, + model=self.model, + ref_model=self.ref_model, # <- Added ref_model to the callback call + processing_class=self.processing_class, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + train_dataloader=self.train_dataloader, + eval_dataloader=self.eval_dataloader, + **kwargs, + ) + # A Callback can skip the return of `control` if it doesn't change it. + if result is not None: + control = result + return control + + +class BEMACallback(_BEMACallback): + # docstyle-ignore + r""" + A [`~transformers.TrainerCallback`] that implements [BEMA](https://huggingface.co/papers/2508.00180) + (Bias-Corrected Exponential Moving Average) by [Adam Block](https://huggingface.co/abblock) and [Cyril + Zhang](https://huggingface.co/cyrilzhang). Code from https://github.com/abblock/bema under MIT license. + + BEMA computes model weights that scale like: + + $$ + \theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t + $$ + + where \\( \theta_t \\) is the current model weights, \\( \theta_0 \\) is a snapshot of the model weights at the + first `update_after` step, \\( \text{EMA}_t \\) is the exponential moving average of the model weights, and + \\( \alpha_t \\) is a scaling factor that decays with the number of steps \\( t \\) as + + $$ + \alpha_t = (\rho + \gamma \cdot t)^{-\eta}. + $$ + + The EMA is computed as: + + $$ + \text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t + $$ + + where \\( \beta_t \\) is a decay factor that decays with the number of steps \\( t \\) as + + $$ + \beta_t = (\rho + \gamma \cdot t)^{-\kappa}. + $$ + + Args: + update_freq (`int`, *optional*, defaults to `400`): + Update the BEMA weights every X steps. Denoted this as \\( \phi \\) in the paper. + ema_power (`float`, *optional*, defaults to `0.5`): + Power for the EMA decay factor. Denoted \\( \kappa \\) in the paper. To disable EMA, set this to `0.0`. + bias_power (`float`, *optional*, defaults to `0.2`): + Power for the BEMA scaling factor. Denoted \\( \eta \\) in the paper. To disable BEMA, set this to `0.0`. + lag (`int`, *optional*, defaults to `10`): + Initial offset in the weight decay schedule that controls early-stage smoothness by acting as a virtual + starting age for the updates. Denoted as \\( \rho \\) in the paper. + update_after (`int`, *optional*, defaults to `0`): + Burn-in time before starting to update the BEMA weights. Denoted \\( \tau \\) in the paper. + multiplier (`float`, *optional*, defaults to `1.0`): + Initial value for the EMA decay factor. Denoted as \\( \gamma \\) in the paper. + min_ema_multiplier (`float`, *optional*, defaults to `0.0`): + Minimum value for the EMA decay factor. + device (`str`, *optional*, defaults to `"cpu"`): + Device to use for the BEMA buffers, e.g. `"cpu"` or `"cuda"`. Note that in most cases, this device SHOULD + BE DIFFERENT from the device used for training in order to avoid OOM. + update_ref_model (`bool`, *optional*, defaults to `False`): + Whether to update the reference model with BEMA weights. This creates a lagged, smoothed version of the + main model as the reference model. + ref_model_update_freq (`int`, *optional*, defaults to `400`): + Update the reference model with BEMA weights every this many steps. + ref_model_update_after (`int`, *optional*, defaults to `0`): + Number of steps to wait before starting to update the reference model. + + Example: + + ```python + from trl import BEMACallback + + trainer = Trainer(..., callbacks=[BEMACallback()]) + ``` + """ + + def __init__( + self, + update_freq: int = 400, + ema_power: float = 0.5, + bias_power: float = 0.2, + lag: int = 10, + update_after: int = 0, + multiplier: float = 1.0, + min_ema_multiplier: float = 0.0, + device: str = "cpu", + update_ref_model: bool = False, + ref_model_update_freq: int = 400, + ref_model_update_after: int = 0, + ): + super().__init__( + update_freq, + ema_power, + bias_power, + lag, + update_after, + multiplier, + min_ema_multiplier, + device, + ) + # Reference model update parameters + self.update_ref_model = update_ref_model + self.ref_model_update_freq = ref_model_update_freq + self.ref_model_update_after = ref_model_update_after + + @torch.no_grad() + def on_step_end( + self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs + ): + super().on_step_end(args, state, control, model, **kwargs) + + step = state.global_step + # Update reference model if enabled + if ( + self.update_ref_model + and step >= self.ref_model_update_after + and (step - self.ref_model_update_after) % self.ref_model_update_freq == 0 + ): + if "ref_model" not in kwargs: + raise ValueError("'ref_model' not found in kwargs.") + + ref_model = kwargs["ref_model"] + + # Get the current BEMA state dict + bema_state_dict = self.running_model.state_dict() + + # Handle the case where ref_model is None (PEFT case) + if ref_model is None: + # In PEFT case, ref_model is None and we need to update the base model of the main model + main_model = self._unwrap_model(model) + if hasattr(main_model, "get_base_model"): + # This is a PEFT model, update the base model + base_model = main_model.get_base_model() + self._update_model_with_bema_weights(base_model, bema_state_dict, is_peft_base=True) + else: + # Regular model, update directly + self._update_model_with_bema_weights(main_model, bema_state_dict, is_peft_base=False) + else: + # ref_model is provided, unwrap it and update + ref_model = self._unwrap_model(ref_model) + if hasattr(ref_model, "get_base_model"): + # This is a PEFT model, update the base model + base_model = ref_model.get_base_model() + self._update_model_with_bema_weights(base_model, bema_state_dict, is_peft_base=True) + else: + # Regular model, update directly + self._update_model_with_bema_weights(ref_model, bema_state_dict, is_peft_base=False) + + logger.info("BEMACallback: Updated reference model with BEMA weights") + + def _update_model_with_bema_weights(self, model, bema_state_dict, is_peft_base=False): + """Helper method to update a model with BEMA weights, handling PEFT and distributed scenarios.""" + if is_peft_base: + # For PEFT base models, filter out adapter parameters + filtered_state_dict = {} + for key, value in bema_state_dict.items(): + # Skip adapter parameters + if not key.startswith("lora_") and not key.startswith("adapter_"): + # Remove 'base_model.' prefix if it exists + if key.startswith("base_model."): + base_key = key[len("base_model.") :] + else: + base_key = key + filtered_state_dict[base_key] = value + + # Update the base model + model.load_state_dict(filtered_state_dict, strict=False) + else: + # Regular model, update directly + model.load_state_dict(bema_state_dict, strict=False) diff --git a/ICL/RL/trl_source/trl/experimental/gfpo/__init__.py b/ICL/RL/trl_source/trl/experimental/gfpo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f91462a911daaff39038ef4cb53f18ffc845bea8 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gfpo/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gfpo_config import GFPOConfig +from .gfpo_trainer import GFPOTrainer diff --git a/ICL/RL/trl_source/trl/experimental/gfpo/gfpo_trainer.py b/ICL/RL/trl_source/trl/experimental/gfpo/gfpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb97d7d6c97cc5e59f39d2b010be86d2c8cc64db --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gfpo/gfpo_trainer.py @@ -0,0 +1,397 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections.abc import Callable +from typing import Any + +import torch +from accelerate.utils import gather_object + +from ...data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages +from ...models.utils import disable_gradient_checkpointing +from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer +from ...trainer.utils import nanmax, nanmin, nanstd, pad + + +logger = logging.getLogger(__name__) + +GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]] + + +class GFPOTrainer(_GRPOTrainer): + def __init__( + self, + model, + reward_funcs, + args=None, + train_dataset=None, + eval_dataset=None, + processing_class=None, + reward_processing_classes=None, + group_filter_func=None, + callbacks=None, + optimizers=(None, None), + peft_config=None, + ): + super().__init__( + model=model, + reward_funcs=reward_funcs, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + self.group_filter_func = group_filter_func + self.num_remains_in_group = args.num_remains_in_group + if self.group_filter_func is None and self.num_remains_in_group is not None: + raise ValueError( + f"Group filter function must not be None when num_remains_in_group ({self.num_remains_in_group}) is given." + ) + if self.group_filter_func is not None and self.num_remains_in_group is None: + logger.warning("Group filter function is not activated since num_remains_in_group is not set") + + def _generate_and_score_completions(self, inputs): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # Transformers requires at least one image in the batch, otherwise it throws an error + if images is not None and all(img_list == [] for img_list in images): + images = None + + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [ + prepare_multimodal_messages(prompt, image_list) + for prompt, image_list in zip(prompts, images, strict=True) + ] + + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( + self._generate(prompts) + ) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + + # When gradient checkpointing is enabled with use_reentrant=True (non default), calling the model inside a + # torch.no_grad() block triggers a harmless PyTorch warning ("None of the inputs have requires_grad=True"). + # Temporarily disable checkpointing to avoid this warning during inference. + with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): + # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of + # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the + # samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps + # for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set + # old_per_token_logps to None. + # When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the + # distribution mismatch between vLLM and the training model can be large and harm the training. + generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency + if self.args.gradient_accumulation_steps % generate_every != 0 or ( + self.use_vllm and self.vllm_importance_sampling_correction + ): + old_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + old_per_token_logps = None + + # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch + if self.use_vllm and self.vllm_importance_sampling_correction: + importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) + importance_sampling_ratio = torch.clamp( + importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + + # Compute the per-token log probabilities for the reference model + if self.beta != 0.0: + if self.ref_model is not None: + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size=batch_size, + num_images=num_images, + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes + ) + else: + ref_per_token_logps = None + + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text, strict=True): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}] + assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text" + bootstrap = bootstrap[0]["text"] + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is + # important because rewards will be normalized per group, and completions are distributed. We will later slice + # rewards_per_func to extract each process's subset. + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + num_in_group = self.num_generations + num_inputs_in_device = len(prompts) + + if self.num_remains_in_group is not None and mode == "train": + num_in_group = self.num_remains_in_group + + all_completions = gather_object(completions) + + group_filter_scores = self.group_filter_func( + group_completions=[ + all_completions[i : i + 1 * self.num_generations] + for i in range(len(all_completions) // self.num_generations) + ], + group_rewards=rewards.view(-1, self.num_generations).tolist(), + ) + group_filter_scores = torch.tensor(group_filter_scores, device=device) + + _, group_local_indices = torch.topk(group_filter_scores, self.num_remains_in_group, dim=-1) + group_row_offsets = torch.arange(0, len(all_completions), self.num_generations, device=device).unsqueeze(1) + group_global_indices = group_row_offsets + group_local_indices + group_global_indices = group_global_indices.flatten() + + rewards = rewards[group_global_indices].contiguous() + rewards_per_func = rewards_per_func[group_global_indices, :].contiguous() + + num_inputs_in_device = int(len(prompts) / self.num_generations * self.num_remains_in_group) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, num_in_group).mean(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_in_group, dim=0) + advantages = rewards - mean_grouped_rewards + + if self.scale_rewards in ["group", "none"]: + # If self.scale_rewards = "none", we'll still log group level std + std_rewards = rewards.view(-1, num_in_group).std(dim=1) + std_rewards = std_rewards.repeat_interleave(num_in_group, dim=0) + elif self.scale_rewards == "batch": + # Compute global std + std_rewards = rewards.std().expand_as(rewards) + else: + raise ValueError( + f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'." + ) + + is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards)) + if self.scale_rewards != "none": + advantages = advantages / (std_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * num_inputs_in_device, + (self.accelerator.process_index + 1) * num_inputs_in_device, + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + + if self.num_remains_in_group is not None and mode == "train": + local_input_indices_to_keep = group_global_indices[process_slice] - self.accelerator.process_index * len( + prompts + ) # step is length of prompts + + prompt_ids = prompt_ids[local_input_indices_to_keep].contiguous() + prompt_mask = prompt_mask[local_input_indices_to_keep].contiguous() + completion_ids = completion_ids[local_input_indices_to_keep].contiguous() + completion_mask = completion_mask[local_input_indices_to_keep].contiguous() + attention_mask = attention_mask[local_input_indices_to_keep].contiguous() + completion_lengths = completion_mask.sum(1) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + num_items_in_batch = agg_completion_lengths.sum() + + if sampling_per_token_logps is not None: + sampling_per_token_logps = sampling_per_token_logps[local_input_indices_to_keep].contiguous() + if old_per_token_logps is not None: + old_per_token_logps = old_per_token_logps[local_input_indices_to_keep].contiguous() + if ref_per_token_logps is not None: + ref_per_token_logps = ref_per_token_logps[local_input_indices_to_keep].contiguous() + if self.use_vllm and self.vllm_importance_sampling_correction: + importance_sampling_ratio = importance_sampling_ratio[local_input_indices_to_keep].contiguous() + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_func_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards) + self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_rewards.mean().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + all_prompts_text = gather_object(prompts_text) + all_completions_text = gather_object(completions_text) + all_images = gather_object(images) if images is not None else None + if self.num_remains_in_group is not None and mode == "train": + group_global_indices_list = group_global_indices.tolist() + all_prompts_text = [all_prompts_text[i] for i in group_global_indices_list] + all_completions_text = [all_completions_text[i] for i in group_global_indices_list] + if images is not None: + all_images = [all_images[i] for i in group_global_indices_list] + + self._logs["prompt"].extend(all_prompts_text) + self._logs["completion"].extend(all_completions_text) + for i, name in enumerate(self.reward_func_names): + self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._logs["advantages"].extend(all_process_advantages.tolist()) + + if images is not None: + self._logs["images"].extend(all_images) + + if self.use_vllm and self.vllm_importance_sampling_correction: + delta = torch.abs(old_per_token_logps - sampling_per_token_logps) + delta = delta[completion_mask.bool()] + mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) + self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( + self.accelerator.gather(mean_delta).mean().item() + ) + self._metrics[mode]["sampling/sampling_logp_difference/max"].append( + self.accelerator.gather(max_delta).max().item() + ) + + flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + min_importance_sampling_ratio = ( + torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + mean_importance_sampling_ratio = ( + torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + max_importance_sampling_ratio = ( + torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) + ) + self._metrics[mode]["sampling/importance_sampling_ratio/min"].append( + nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append( + self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item() + ) + self._metrics[mode]["sampling/importance_sampling_ratio/max"].append( + nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item() + ) + + output = { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "num_items_in_batch": num_items_in_batch, + } + if old_per_token_logps is not None: + output["old_per_token_logps"] = old_per_token_logps + if self.use_vllm and self.vllm_importance_sampling_correction: + output["importance_sampling_ratio"] = importance_sampling_ratio + if ref_per_token_logps is not None: + output["ref_per_token_logps"] = ref_per_token_logps + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] + if "token_type_ids" in forward_kwargs: + output["token_type_ids"] = forward_kwargs["token_type_ids"] + if images is not None: + output["num_images"] = num_images + return output diff --git a/ICL/RL/trl_source/trl/experimental/gkd/gkd_trainer.py b/ICL/RL/trl_source/trl/experimental/gkd/gkd_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..94bf520cda5e04054b920f2fc02a7e7c31fd2288 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gkd/gkd_trainer.py @@ -0,0 +1,456 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import textwrap +from collections.abc import Callable +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_liger_kernel_available, is_peft_available + +from ...models import prepare_deepspeed +from ...models.utils import unwrap_model_for_generation +from ...trainer.sft_trainer import SFTTrainer +from ...trainer.utils import disable_dropout_in_model, empty_cache +from ..utils import DataCollatorForChatML +from .gkd_config import GKDConfig + + +if is_peft_available(): + from peft import PeftConfig + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + + +class GKDTrainer(SFTTrainer): + """Trainer for Generalized Knowledge Distillation (GKD) of language models. + + For details on GKD, see the paper: [On-Policy Distillation of Language Models: Learning from Self-Generated + Mistakes](https://huggingface.co/papers/2306.13649). + + Args: + model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*): + Model to be trained, or the string identifier of the model to be instantiated from a pretrained model. + teacher_model ([`~transformers.PreTrainedModel`] or `torch.nn.Module` or `str`, *optional*): + Teacher model for knowledge distillation, or the string identifier of the model to be instantiated from a + pretrained model. + args ([`experimental.gkd.GKDConfig`], *optional*): + Training arguments. + data_collator ([`~transformers.DataCollator`], *optional*): + Data collator to batch samples from the dataset. It defaults to a + [`experimental.utils.DataCollatorForChatML`] using the `processing_class`. + train_dataset ([`~datasets.Dataset`], *optional*): + Dataset for training. + eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*): + Dataset for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Class to process the data. + compute_metrics (`Callable`, *optional*): + Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a + dictionary string to float. + callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): + Callbacks to use during training. + optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): + Tuple containing the optimizer and the learning rate scheduler to use for training. + preprocess_logits_for_metrics (`Callable`, *optional*): + Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and + return the logits to be used for metrics computation. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be + wrapped with the specified PEFT adapter. + formatting_func (`Callable`, *optional*): + Function to format the dataset. Must take in an example and return an example. + """ + + _tag_names = ["trl", "gkd"] + _name = "GKD" + _paper = { + "title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", + "id": "2306.13649", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @inproceedings{agarwal2024on-policy, + title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, + author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=3zKtaqxLhW}, + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + teacher_model: PreTrainedModel | nn.Module | str = None, + args: GKDConfig | None = None, + data_collator: DataCollator | None = None, # type: ignore + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: "PeftConfig | None" = None, + formatting_func: Callable | None = None, + ): + # Ensure Trainer does not drop non-signature columns used by the collator (e.g., "prompts") + args.remove_unused_columns = False + # Respect a user-provided data_collator; otherwise, provide a ChatML collator that + if data_collator is None: + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) + + # Ensure SFTTrainer does not pre-process the dataset when using a ChatML collator, + # so that raw conversational fields (e.g., "messages") remain available to the collator. + if args.dataset_kwargs is None: + args.dataset_kwargs = {"skip_prepare_dataset": True} + else: + args.dataset_kwargs["skip_prepare_dataset"] = True + + # Liger fused GKD loss (JSD) + self.use_liger_gkd_loss = False + if args.use_liger_kernel: + self.liger_jsd_loss = LigerFusedLinearJSDLoss( + beta=args.beta, + ignore_index=-100, + temperature=args.temperature, + compiled=False, + ) + self.use_liger_gkd_loss = True + + super().__init__( + model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + peft_config=peft_config, + formatting_func=formatting_func, + ) + + if args.teacher_model_init_kwargs is None: + teacher_model_init_kwargs = {} + elif not isinstance(teacher_model, str): + raise ValueError( + "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated." + ) + else: + teacher_model_init_kwargs = args.teacher_model_init_kwargs + teacher_model_init_kwargs["dtype"] = ( + teacher_model_init_kwargs["dtype"] + if teacher_model_init_kwargs["dtype"] in ["auto", None] + else getattr(torch, teacher_model_init_kwargs["dtype"]) + ) + + if isinstance(teacher_model, str): + teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(self.model) + + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + + self.lmbda = args.lmbda + self.beta = args.beta + self.temperature = args.temperature + self.seq_kd = args.seq_kd + + generation_kwargs = { + "max_new_tokens": args.max_new_tokens, + "temperature": args.temperature, + "do_sample": True, + "top_k": 0, + "use_cache": False if args.gradient_checkpointing else True, + "pad_token_id": self.processing_class.pad_token_id, + } + self.generation_config = GenerationConfig(**generation_kwargs) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs + # Set custom EOS tokens if they are specified by the model's generation + # config. This is important for models with the Llama 3 chat template, + # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of + # turns or messages. + if ( + hasattr(self.model.generation_config, "eos_token_id") + and self.model.generation_config.eos_token_id is not None + ): + self.generation_config.eos_token_id = self.model.generation_config.eos_token_id + + @staticmethod + def generalized_jsd_loss( + student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" + ): + """ + Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1) + of https://huggingface.co/papers/2306.13649 for the definition. + + Args: + student_logits: + Tensor of shape (batch_size, sequence_length, vocab_size) + teacher_logits: + Tensor of shape (batch_size, sequence_length, vocab_size) + labels: + Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing + loss + beta: + Interpolation coefficient between 0 and 1 (default: 0.5) + temperature: + Softmax temperature (default: 1.0) + reduction: + Specifies the reduction to apply to the output (default: 'batchmean') + + Returns: + loss: Scalar tensor with the generalized JSD loss + """ + + # Apply temperature scaling + student_logits = student_logits / temperature + teacher_logits = teacher_logits / temperature + + # Compute log probabilities for student and probabilities for teacher + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + # Compute the log of the mixture distribution + # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture + beta = torch.tensor(beta, dtype=student_log_probs.dtype) + mixture_log_probs = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]), + dim=0, + ) + + # Compute KL divergences using F.kl_div + # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. + kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) + + # Compute the Generalized Jensen-Shannon Divergence + jsd = beta * kl_teacher + (1 - beta) * kl_student + + # Masking + if labels is not None: + mask = labels != -100 + jsd = jsd[mask] + + # Apply reduction + if reduction == "batchmean": + return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0) + elif reduction == "sum": + return jsd.sum() + elif reduction == "mean": + return jsd.mean() + else: + return jsd + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if self.use_liger_gkd_loss: + # Forward only through the base models (avoid lm_head to save memory) + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None: + base_student = unwrapped_student.get_decoder() + else: + base_student = getattr( + unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student + ) + + student_outputs = base_student( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + self.teacher_model.eval() + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None: + base_teacher = unwrapped_teacher.get_decoder() + else: + base_teacher = getattr( + unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher + ) + with torch.no_grad(): + teacher_outputs = base_teacher( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + # hidden states (shifted) + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] + + # Release full outputs to free memory + del student_outputs, teacher_outputs + + # labels mask and labels (shifted) + labels_mask = inputs["labels"] != -100 + masked_input_ids = torch.where( + labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) + ) + true_labels = masked_input_ids[:, 1:].contiguous() + + # Release intermediate tensors + del labels_mask, masked_input_ids + + # heads + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + # liger fused jsd loss + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, "bias", None), + teacher_bias=getattr(teacher_head, "bias", None), + ) + + # Release hidden states after loss computation + del student_hidden, teacher_hidden, true_labels + else: + # compute student output + student_outputs = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + + # compute teacher output in eval mode + self.teacher_model.eval() + with torch.no_grad(): + teacher_outputs = self.teacher_model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + + # slice the logits for the generated tokens using the inputs["prompts"] lengths + prompt_lengths = inputs["prompts"].shape[1] + shifted_student_logits = student_outputs.logits[:, prompt_lengths - 1 : -1, :] + shifted_teacher_logits = teacher_outputs.logits[:, prompt_lengths - 1 : -1, :] + shifted_labels = inputs["labels"][:, prompt_lengths:] + + # compute loss + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + labels=shifted_labels, + beta=self.beta, + ) + + # empty cache + empty_cache() + + # Return loss + return (loss, student_outputs) if return_outputs else loss + + @staticmethod + def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None): + # Generate output with respect to the prompt-only + generated_outputs = model.generate( + input_ids=inputs["prompts"], + attention_mask=inputs.get("prompt_attention_mask", None), + generation_config=generation_config, + return_dict_in_generate=True, + ) + + # Get the generated token IDs + generated_tokens = generated_outputs.sequences + # Calculate new attention mask + new_attention_mask = torch.ones_like(generated_tokens) + new_labels = generated_tokens.clone() + + # If there's pad_token_id, set attention mask to 0 for padding tokens + if pad_token_id is not None: + new_labels[new_labels == pad_token_id] = -100 + new_attention_mask[generated_tokens == pad_token_id] = 0 + + return generated_tokens, new_attention_mask, new_labels + + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + """ + Perform a training step for the Generalized Knowledge Distillation (GKD) model. + + This method implements the on-policy learning approach described in the GKD paper. With probability + `self.lmbda`, it generates new responses using the student model, which are then used for training instead of + the original inputs. + """ + if self.seq_kd: + with ( + unwrap_model_for_generation( + self.teacher_model, + self.accelerator, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model + ): + new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( + unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id + ) + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_attention_mask + inputs["labels"] = new_labels + if random.random() <= self.lmbda: + with ( + unwrap_model_for_generation( + model, + self.accelerator, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model + ): + new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( + unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id + ) + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_attention_mask + inputs["labels"] = new_labels + + loss = super().training_step(model, inputs, num_items_in_batch) + return loss diff --git a/ICL/RL/trl_source/trl/experimental/gold/gold_trainer.py b/ICL/RL/trl_source/trl/experimental/gold/gold_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe74b0d7f22d6637689e0eb105068a28770ce5f --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/gold/gold_trainer.py @@ -0,0 +1,2108 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import textwrap +import warnings +from collections import defaultdict, deque +from collections.abc import Callable +from contextlib import nullcontext +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from accelerate.utils import DistributedType, broadcast_object_list, gather_object, is_peft_model +from datasets import Dataset, IterableDataset +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from transformers import AutoTokenizer, TrainerCallback, TrainerControl, TrainerState, is_bitsandbytes_available +from transformers.data.data_collator import DataCollator +from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.generation.configuration_utils import GenerationConfig +from transformers.image_processing_utils import BaseImageProcessor +from transformers.integrations.integration_utils import is_wandb_available +from transformers.modeling_utils import PreTrainedModel +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_utils import EvalPrediction +from transformers.utils import ( + is_flash_attn_2_available, + is_liger_kernel_available, + is_peft_available, + is_rich_available, +) + +from ...data_utils import is_conversational, maybe_convert_to_chatml, pack_dataset, truncate_dataset +from ...extras.profiling import profiling_decorator +from ...generation.vllm_client import VLLMClient +from ...import_utils import is_vllm_available +from ...models import prepare_deepspeed +from ...models.utils import unwrap_model_for_generation +from ...trainer.sft_trainer import SFTTrainer +from ...trainer.utils import ( + create_model_from_path, + disable_dropout_in_model, + empty_cache, + ensure_master_addr_port, + pad, +) +from ..utils import DataCollatorForChatML +from .gold_config import GOLDConfig + + +if is_peft_available(): + from peft import PeftConfig + +if is_wandb_available(): + import wandb + +if is_vllm_available(): + from vllm import LLM, SamplingParams + from vllm.sampling_params import StructuredOutputsParams + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + +if is_rich_available(): + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +def print_prompt_completions_sample_uld( + prompts: list[str], + completions: list[str], + step: int, + num_samples: int = None, +) -> None: + """ + Print out a sample of model completions to the console with multiple reward metrics. + + This function creates a nicely formatted table showing prompt-completion pairs, useful for monitoring model outputs + during training. It requires the `rich` library to be installed. + + Args: + prompts (`list[str]`): + List of prompts. + completions (`list[str]`): + List of completions corresponding to the prompts. + rewards (`dict[str, list[float]]`): + Dictionary where keys are reward names and values are lists of rewards. + advantages (`list[float]`): + List of advantages corresponding to the prompts and completions. + step (`int`): + Current training step number, used in the output title. + num_samples (`int` or `None`, *optional*, defaults to `None`): + Number of random samples to display. If `None` (default), all items will be displayed. + + Example: + ```python + >>> from trl.trainer.utils import print_prompt_completions_sample + + >>> prompts = ["The sky is", "The sun is"] + >>> completions = [" blue.", " in the sky."] + >>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} + >>> advantages = [0.987, 0.654] + >>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42) + ╭──────────────────────────── Step 42 ─────────────────────────────╮ + │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │ + │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │ + │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │ + │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │ + │ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │ + │ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │ + │ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │ + ╰──────────────────────────────────────────────────────────────────╯ + ``` + """ + if not is_rich_available(): + raise ImportError( + "The function `print_prompt_completions_sample` requires the `rich` library. Please install it with " + "`pip install rich`." + ) + console = Console() + table = Table(show_header=True, header_style="bold white", expand=True) + + # Add columns + table.add_column("Prompt", style="bright_yellow") + table.add_column("Completion", style="bright_green") + + # Some basic input validation + if num_samples is not None: + if num_samples >= len(prompts): + num_samples = None + elif num_samples <= 0: + return + + # Subsample data if num_samples is specified + if num_samples is not None: + indices = random.sample(range(len(prompts)), num_samples) + prompts = [prompts[i] for i in indices] + completions = [completions[i] for i in indices] + + for i in range(len(prompts)): + table.add_row(Text(prompts[i]), Text(completions[i])) + table.add_section() # Adds a separator between rows + + panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") + console.print(panel) + + +def build_teacher_inputs_from_texts( + tokenizer: PreTrainedTokenizerBase, + prompt_texts: list[str], + completion_texts: list[str], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Tokenize teacher prompts/completions and produce tensors ready for GOLD loss.""" + + pad_token_id = tokenizer.pad_token_id + eos_token_id = tokenizer.eos_token_id + + prompt_token_ids = tokenizer(prompt_texts, add_special_tokens=True)["input_ids"] + completion_token_ids = tokenizer(completion_texts, add_special_tokens=False)["input_ids"] + + sequences: list[torch.Tensor] = [] + attention_masks: list[torch.Tensor] = [] + labels_list: list[torch.Tensor] = [] + prompt_lengths: list[int] = [] + + for prompt_ids, completion_ids in zip(prompt_token_ids, completion_token_ids, strict=True): + # Remove trailing EOS from prompt so completions can extend cleanly + if eos_token_id is not None and prompt_ids and prompt_ids[-1] == eos_token_id: + prompt_ids = prompt_ids[:-1] + + prompt_lengths.append(len(prompt_ids)) + sequence = list(prompt_ids) + sequence.extend(completion_ids) + if eos_token_id is not None: + sequence.append(eos_token_id) + + seq_tensor = torch.tensor(sequence, dtype=torch.long) + sequences.append(seq_tensor) + attention_masks.append(torch.ones_like(seq_tensor)) + + labels = seq_tensor.clone() + labels[: len(prompt_ids)] = -100 + if pad_token_id is not None: + labels[labels == pad_token_id] = -100 + labels_list.append(labels) + + teacher_input_ids = pad( + sequences, + padding_side="right", + padding_value=pad_token_id if pad_token_id is not None else 0, + ) + teacher_attention_mask = pad(attention_masks, padding_side="right", padding_value=0).bool() + teacher_labels = pad(labels_list, padding_side="right", padding_value=-100) + + if eos_token_id is not None: + for row in range(teacher_attention_mask.size(0)): + valid = ( + teacher_input_ids[row] != pad_token_id + if pad_token_id is not None + else teacher_attention_mask[row].bool() + ) + if valid.any(): + last_idx = valid.nonzero(as_tuple=True)[0][-1] + teacher_attention_mask[row, last_idx + 1 :] = False + + teacher_prompt_length = max(prompt_lengths) if prompt_lengths else 0 + + return teacher_input_ids, teacher_labels, teacher_attention_mask, teacher_prompt_length + + +class ULDLoss(nn.Module): + """ + Universal Logit Distillation Loss. + """ + + def __init__(self, config: GOLDConfig, student_tokenizer=None, teacher_tokenizer=None, device=None): + super().__init__() + self.device = device + self.crossentropy_weight = config.uld_crossentropy_weight + self.distillation_weight = config.uld_distillation_weight + self.student_temperature = config.uld_student_temperature + self.teacher_temperature = config.uld_teacher_temperature + self.skip_student_eos = config.uld_skip_student_eos + self.skip_teacher_eos = config.uld_skip_teacher_eos + self.use_extended_uld = config.use_extended_uld + self.ignore_index = -100 + + # Add tokenizers for enhanced alignment + self.student_tokenizer = student_tokenizer + self.teacher_tokenizer = teacher_tokenizer + + # Hybrid ULD configuration + self.use_hybrid_loss = getattr(config, "uld_use_hybrid_loss", False) + self.hybrid_matched_weight = getattr(config, "uld_hybrid_matched_weight", None) + self.hybrid_unmatched_weight = getattr(config, "uld_hybrid_unmatched_weight", None) + self.beta = getattr(config, "beta", 1.0) # For JSD loss in hybrid matched tokens + + # Initialize vocabulary mapping for hybrid loss + self._vocab_mapping = None + self._teacher_matched_ids = None + self._student_matched_ids = None + if self.use_hybrid_loss and student_tokenizer is not None and teacher_tokenizer is not None: + self._initialize_vocabulary_mapping() + + def __call__( + self, student_logits, teacher_logits, student_labels, teacher_labels, student_input_ids, teacher_input_ids + ): + """ + Compute ULD loss with GKD trainer interface. + + Args: + student_logits: Student model logits [batch_size, seq_len, vocab_size] + teacher_logits: Teacher model logits [batch_size, seq_len, vocab_size] + student_labels: Student target labels [batch_size, seq_len] + teacher_labels: Teacher target labels [batch_size, seq_len] + student_input_ids: Student input token IDs [batch_size, seq_len] + teacher_input_ids: Teacher input token IDs [batch_size, seq_len] + + Returns: + Total loss (cross-entropy + distillation) + """ + # Compute cross-entropy loss for student + if self.crossentropy_weight > 0: + shift_logits = student_logits[..., :-1, :].contiguous() + shift_labels = student_labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + crossentropy_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + crossentropy_loss = self.crossentropy_weight * crossentropy_loss + else: + crossentropy_loss = 0.0 + + # Compute distillation loss using ULD approximation + distillation_loss = self._compute_distillation_loss( + student_logits, teacher_logits, student_labels, teacher_labels, student_input_ids, teacher_input_ids + ) + + return crossentropy_loss + distillation_loss + + def _initialize_vocabulary_mapping(self): + """Initialize vocabulary mapping for hybrid ULD loss.""" + # Computing vocabulary mapping for hybrid ULD + + student_vocab = self.student_tokenizer.get_vocab() + teacher_vocab = self.teacher_tokenizer.get_vocab() + + # Create reverse mapping for student + student_token_to_id = dict(student_vocab.items()) + + vocab_mapping = {} + teacher_matched_ids = set() + student_matched_ids = set() + + for token_str, teacher_id in teacher_vocab.items(): + if token_str in student_token_to_id: + student_id = student_token_to_id[token_str] + vocab_mapping[teacher_id] = student_id + teacher_matched_ids.add(teacher_id) + student_matched_ids.add(student_id) + + self._vocab_mapping = vocab_mapping + self._teacher_matched_ids = teacher_matched_ids + self._student_matched_ids = student_matched_ids + + max_matched_teacher_id = max(self._vocab_mapping.keys()) + self.mapping_tensor = torch.full((max_matched_teacher_id + 1,), -1, dtype=torch.long) # -1 for unmapped ids + for k, v in self._vocab_mapping.items(): + self.mapping_tensor[k] = v + if self.device is not None: + self.mapping_tensor = self.mapping_tensor.to(self.device) + + def _compute_distillation_loss( + self, student_logits, teacher_logits, student_labels, teacher_labels, student_input_ids, teacher_input_ids + ): + """ + Compute the Universal Logit Distillation loss with token mapping. + + This version uses actual input_ids for accurate token mapping and multiplies probabilities for split tokens. + Both student_input_ids and teacher_input_ids are required for optimal alignment. + """ + # Get answer regions (same as original) + student_answer_index, student_answer_size = self._get_start_and_size_answers(student_labels) + teacher_answer_index, teacher_answer_size = self._get_start_and_size_answers(teacher_labels) + + if self.skip_student_eos: + student_answer_size = [size - 1 for size in student_answer_size] + if self.skip_teacher_eos: + teacher_answer_size = [size - 1 for size in teacher_answer_size] + + # Handle edge case where all answer sizes are 0 + if ( + not student_answer_size + or not teacher_answer_size + or max(max(student_answer_size), max(teacher_answer_size)) <= 0 + ): + return torch.zeros(1, device=student_logits.device, requires_grad=True) * student_logits.sum() * 1e-8 + + batch_size = student_logits.size(0) + distillation_losses = [] + + for i in range(batch_size): + # Get answer regions for this batch item + student_start = student_answer_index[i] + student_size = student_answer_size[i] + teacher_start = teacher_answer_index[i] + teacher_size = teacher_answer_size[i] + + if student_size <= 0 or teacher_size <= 0: + loss_i = student_logits[i].sum() * 0.0 + distillation_losses.append(loss_i) + continue + + # Extract answer logits + student_answer_logits = student_logits[i, student_start : student_start + student_size] + teacher_answer_logits = teacher_logits[i, teacher_start : teacher_start + teacher_size] + + # Convert to probabilities + student_probs = F.softmax(student_answer_logits / self.student_temperature, dim=-1) + teacher_probs = F.softmax(teacher_answer_logits / self.teacher_temperature, dim=-1) + + # Get token IDs for mapping (always use actual input_ids) + student_token_ids = student_input_ids[i, student_start : student_start + student_size].tolist() + teacher_token_ids = teacher_input_ids[i, teacher_start : teacher_start + teacher_size].tolist() + + if self.use_extended_uld: + # Build alignment groups directly from token ids using greedy text matching + student_alignment_groups, teacher_alignment_groups = self._build_alignment_groups_from_ids( + student_token_ids, teacher_token_ids + ) + + # Merge student probabilities using student alignment groups + # Pass student_token_ids to enable corrected conditional probability merging + student_aligned = self._merge_probabilities_with_alignment_groups( + student_probs, student_alignment_groups, student_token_ids + ) + + # Merge teacher probabilities using teacher alignment groups + # Pass teacher_token_ids to enable corrected conditional probability merging + teacher_aligned = self._merge_probabilities_with_alignment_groups( + teacher_probs, teacher_alignment_groups, teacher_token_ids + ) + else: + min_length = min(len(student_token_ids), len(teacher_token_ids)) + student_aligned = student_probs[:min_length, :] + teacher_aligned = teacher_probs[:min_length, :] + + # Apply ULD loss computation + if self.use_hybrid_loss and self._vocab_mapping is not None: + # Use hybrid approach: direct comparison for matched tokens, sorting for unmatched + aligned_loss = self._compute_hybrid_uld_loss(student_aligned, teacher_aligned) + else: + # Original approach: sort all probabilities + student_sorted = student_aligned.sort(dim=-1, descending=True).values + teacher_sorted = teacher_aligned.sort(dim=-1, descending=True).values + + # Pad vocabularies to same size + student_vocab_size = student_sorted.size(-1) + teacher_vocab_size = teacher_sorted.size(-1) + max_vocab_size = max(student_vocab_size, teacher_vocab_size) + + if student_vocab_size < max_vocab_size: + student_sorted = F.pad(student_sorted, (0, max_vocab_size - student_vocab_size)) + if teacher_vocab_size < max_vocab_size: + teacher_sorted = F.pad(teacher_sorted, (0, max_vocab_size - teacher_vocab_size)) + + # Compute L1 distance (ULD approach) + aligned_loss = F.l1_loss(student_sorted, teacher_sorted, reduction="sum") + aligned_loss /= student_aligned.size(0) # Normalize by sequence length + distillation_losses.append(aligned_loss) + + distillation_loss = torch.stack(distillation_losses).mean() + return self.distillation_weight * distillation_loss + + def _build_alignment_groups_from_ids(self, student_token_ids, teacher_token_ids): + """ + Build alignment groups using a greedy substring-equality algorithm on decoded token pieces. + + Args: + student_token_ids: List[int] + teacher_token_ids: List[int] + + Returns: + Tuple[List[List[int]], List[List[int]]]: student and teacher alignment groups + """ + + def to_canonical_pieces(tok, ids): + pieces = [] + prev = "" + for k in range(len(ids)): + # IMPORTANT: Do NOT skip special tokens - we need to align them too + cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False) + # Extract the incremental addition (may include spaces/ZWJ/etc.) + pieces.append(cur[len(prev) :]) + prev = cur + return pieces + + s_pieces = to_canonical_pieces(self.student_tokenizer, student_token_ids) + t_pieces = to_canonical_pieces(self.teacher_tokenizer, teacher_token_ids) + + i = j = 0 + s_buf = t_buf = "" + s_group = [] + t_group = [] + s_groups = [] + t_groups = [] + + def flush(): + if s_group and t_group: + s_groups.append(s_group.copy()) + t_groups.append(t_group.copy()) + + # Greedily accumulate pieces until substrings match, then flush + while i < len(s_pieces) or j < len(t_pieces): + if s_buf == t_buf and s_buf != "": + flush() + s_buf = t_buf = "" + s_group = [] + t_group = [] + continue + + if s_buf == "" and i < len(s_pieces): + s_buf += s_pieces[i] + s_group.append(i) + i += 1 + continue + if t_buf == "" and j < len(t_pieces): + t_buf += t_pieces[j] + t_group.append(j) + j += 1 + continue + + if len(s_buf) <= len(t_buf): + if i < len(s_pieces): + s_buf += s_pieces[i] + s_group.append(i) + i += 1 + elif j < len(t_pieces): + t_buf += t_pieces[j] + t_group.append(j) + j += 1 + else: + if j < len(t_pieces): + t_buf += t_pieces[j] + t_group.append(j) + j += 1 + elif i < len(s_pieces): + s_buf += s_pieces[i] + s_group.append(i) + i += 1 + + # Flush any remainder if both sides accumulated something + if s_buf == t_buf and s_group and t_group: + flush() + elif s_group or t_group: + # Handle remaining unmatched tokens by forcing a flush + # This ensures both sides have the same number of alignment groups + if s_group or t_group: + # Ensure both groups have content (even if empty list) + if not s_group: + s_group = [] + if not t_group: + t_group = [] + # Force flush even if buffers don't match + if s_group or t_group: + s_groups.append(s_group.copy() if s_group else []) + t_groups.append(t_group.copy() if t_group else []) + + return s_groups, t_groups + + def _merge_probabilities_with_alignment_groups(self, probs, alignment_groups, token_ids=None): + """ + Merge probabilities based on alignment groups with corrected conditional probability handling. + + For a group merging tokens at positions [i, i+1, ..., i+k], we compute: + P_merged(y | x) = P(y | x) × P(token_{i+1} | token_i, x) × ... × P(token_{i+k} | ..., x) + + Where: + - P(y | x) is the marginal probability distribution over all vocabulary tokens at position i + - token_{i+1}, ..., token_{i+k} are the ACTUAL tokens that were generated + - The conditional probabilities P(token_j | ..., x) are extracted as SCALARS + - y ranges over all vocabulary tokens at position i + + This ensures the probability of the actual generated sequence is correct (by the chain rule), while introducing + a known bias for counterfactual tokens (since we don't have P(token_{i+k} | y, x) for y != token_i). The merged + distribution is unnormalized but preserves correct relative probabilities. + + Args: + probs: Probability tensor [seq_len, vocab_size] + alignment_groups: List of alignment groups (each group is a list of positions to merge) + token_ids: Actual token IDs that were generated [seq_len]. REQUIRED when any group has + len(group) > 1. If None when multi-token groups exist, raises ValueError. + + Returns: + Merged probability tensor [num_groups, vocab_size] + + Raises: + ValueError: If token_ids is None when merging multi-token groups + """ + if not alignment_groups: + return probs + + # Create aligned tensor + vocab_size = probs.size(-1) + target_len = len(alignment_groups) + aligned_probs = torch.zeros(target_len, vocab_size, device=probs.device, dtype=probs.dtype) + eps = 1e-8 + + # Process each alignment group + for group_idx, group in enumerate(alignment_groups): + # Handle probability merging + if len(group) > 1: + # Multiple tokens map to this group - merge using corrected conditional probability approach + if token_ids is None: + raise ValueError( + "token_ids must be provided when merging multi-token groups. " + "This is required for mathematically correct probability merging." + ) + + # Start with the marginal distribution at the first position + first_pos = group[0] + marginal_probs = probs[first_pos] # P(y | x₀) for all y + + # For each subsequent token in the group, extract the SCALAR conditional probability + # of the actual token that was generated, and multiply + conditional_prob_product = 1.0 + for idx in group[1:]: + # Get the actual token ID that was generated at this position + actual_token_id = token_ids[idx] + # Extract its probability (scalar) + token_prob = probs[idx, actual_token_id].clamp_min(eps) + conditional_prob_product *= token_prob + + # Merge: multiply the scalar conditional prob product with the entire marginal distribution + # This gives: P(y | x_0) × P(token_1 | token_0, x) × ... × P(token_k | ..., x) + # Note: This is unnormalized, but preserves the correct joint probability for the actual sequence + merged_probs = marginal_probs * conditional_prob_product + aligned_probs[group_idx] = merged_probs + + elif len(group) == 1: + aligned_probs[group_idx] = probs[group[0]] + else: + # No tokens map to this group + aligned_probs[group_idx] = torch.zeros_like(probs[0]) + + return aligned_probs + + def _compute_hybrid_uld_loss(self, student_aligned, teacher_aligned): + """ + Compute hybrid ULD loss on aligned probability distributions. This method: + 1. Directly compares probabilities for tokens with matching vocabulary entries + 2. Uses sorting approach only for tokens with different vocabulary entries + + Args: + student_aligned: Aligned student probabilities [seq_len, student_vocab_size] + teacher_aligned: Aligned teacher probabilities [seq_len, teacher_vocab_size] + Returns: + Combined hybrid loss + """ + device = student_aligned.device + # seq_len = student_aligned.size(0) # Unused variable + student_vocab_size = student_aligned.size(-1) + teacher_vocab_size = teacher_aligned.size(-1) + + # Convert sets to sorted tensors for indexing + if self._teacher_matched_ids: + teacher_matched_indices = torch.tensor(sorted(self._teacher_matched_ids), dtype=torch.long, device=device) + student_matched_indices = self.mapping_tensor[teacher_matched_indices] + else: + teacher_matched_indices = torch.tensor([], dtype=torch.long, device=device) + student_matched_indices = torch.tensor([], dtype=torch.long, device=device) + + # Create masks for unmatched tokens + teacher_matched_mask = torch.zeros(teacher_vocab_size, dtype=torch.bool, device=device) + student_matched_mask = torch.zeros(student_vocab_size, dtype=torch.bool, device=device) + + if len(teacher_matched_indices) > 0: + teacher_matched_mask[teacher_matched_indices] = True + student_matched_mask[student_matched_indices] = True + + # 1. JSD loss for matched vocabulary tokens (direct semantic correspondence) + matched_loss = torch.tensor(0.0, device=device) + matched_token_count = 0 + if len(teacher_matched_indices) > 0: + # Extract probabilities for matched tokens + teacher_matched_probs = teacher_aligned[:, teacher_matched_indices] # [seq_len, num_matched] + student_matched_probs = student_aligned[:, student_matched_indices] # [seq_len, num_matched] + matched_token_count = teacher_matched_probs.size(-1) + + # Use JSD loss for semantically aligned tokens + # Convert probabilities back to logits for JSD computation + + # Apply generalized JSD loss to matched tokens + matched_loss = self._compute_jsd_loss_for_matched_tokens(student_matched_probs, teacher_matched_probs) + + # 2. Sorted comparison loss for unmatched vocabulary tokens + teacher_unmatched_mask = ~teacher_matched_mask + student_unmatched_mask = ~student_matched_mask + + teacher_unmatched_probs = teacher_aligned[:, teacher_unmatched_mask] # [seq_len, num_teacher_unmatched] + student_unmatched_probs = student_aligned[:, student_unmatched_mask] # [seq_len, num_student_unmatched] + + unmatched_loss = torch.tensor(0.0, device=device) + if teacher_unmatched_probs.size(-1) > 0 and student_unmatched_probs.size(-1) > 0: + # Sort unmatched probabilities + teacher_unmatched_sorted = teacher_unmatched_probs.sort(dim=-1, descending=True).values + student_unmatched_sorted = student_unmatched_probs.sort(dim=-1, descending=True).values + + # Pad to same size if needed + teacher_unmatched_size = teacher_unmatched_sorted.size(-1) + student_unmatched_size = student_unmatched_sorted.size(-1) + max_unmatched_size = max(teacher_unmatched_size, student_unmatched_size) + + if teacher_unmatched_size < max_unmatched_size: + teacher_unmatched_sorted = F.pad( + teacher_unmatched_sorted, (0, max_unmatched_size - teacher_unmatched_size) + ) + if student_unmatched_size < max_unmatched_size: + student_unmatched_sorted = F.pad( + student_unmatched_sorted, (0, max_unmatched_size - student_unmatched_size) + ) + + # L1 loss on sorted unmatched tokens + unmatched_loss = F.l1_loss(student_unmatched_sorted, teacher_unmatched_sorted, reduction="sum") + unmatched_loss /= student_aligned.size(0) # Normalize by sequence length + + # 3. Combine losses with weights + if self.hybrid_matched_weight is None: + # Use adaptive weighting based on vocabulary overlap + hybrid_matched_weight = matched_token_count / max(1, teacher_vocab_size) + hybrid_unmatched_weight = 1.0 - hybrid_matched_weight + else: + # Use fixed weights provided in config + hybrid_matched_weight = self.hybrid_matched_weight + hybrid_unmatched_weight = self.hybrid_unmatched_weight + + total_loss = hybrid_matched_weight * matched_loss + hybrid_unmatched_weight * unmatched_loss + + # Store matched/unmatched components for logging + self.last_matched_loss = matched_loss + self.last_unmatched_loss = unmatched_loss + + return total_loss + + def _compute_jsd_loss_for_matched_tokens(self, student_logits, teacher_logits): + """ + Compute JSD loss for matched vocabulary tokens. + + Args: + student_logits: Student logits for matched tokens [seq_len, num_matched] + teacher_logits: Teacher logits for matched tokens [seq_len, num_matched] + Returns: + JSD loss for matched tokens + """ + # Reshape to [batch_size * seq_len, vocab_size] format expected by generalized_jsd_loss + batch_seq_len, num_matched = student_logits.shape + + student_logits_reshaped = student_logits.view(-1, num_matched) + teacher_logits_reshaped = teacher_logits.view(-1, num_matched) + + # Use the GOLD generalized JSD loss implementation that accepts probability inputs + jsd_loss = GOLDTrainer.generalized_jsd_loss( + student_logits_reshaped, + teacher_logits_reshaped, + labels=None, # No masking needed for matched tokens + beta=self.beta, # Standard JSD beta + temperature=1.0, # Already applied in main computation + reduction="batchmean", + logits_are_probs=True, + ) + + return jsd_loss + + def _get_start_and_size_answers(self, answer_tensors): + answers_index = [] + answers_size = [] + + for answer in answer_tensors: + answer_mask = answer.ne(self.ignore_index) + if not answer_mask.any(): + answers_index.append(0) + answers_size.append(0) + continue + + valid_indices = answer_mask.nonzero(as_tuple=True)[0] + answers_index.append(int(valid_indices[0].item())) + answers_size.append(int(answer_mask.sum().item())) + return answers_index, answers_size + + +class GOLDVLLMSyncCallback(TrainerCallback): + """Sync the model weights to vLLM after training steps when it's safe to do so.""" + + def __init__(self, trainer): + self.trainer = trainer + + def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs): + """Sync weights after training step when DeepSpeed is stable.""" + if ( + self.trainer.use_vllm + and state.global_step != self.trainer._last_vllm_sync_step + and state.global_step % self.trainer.vllm_sync_frequency == 0 + ): + # Check if this is a step where gradients are synchronized + # This happens at the end of gradient accumulation cycles + if hasattr(self.trainer.accelerator, "sync_gradients") and self.trainer.accelerator.sync_gradients: + self.trainer._move_model_to_vllm() + self.trainer._last_vllm_sync_step = state.global_step + + +class GOLDTrainer(SFTTrainer): + _tag_names = ["trl", "gold"] + _name = "GOLD" + _paper = { + "title": "Unlocking On-Policy Distillation for Any Model Family", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @misc{patino2025unlocking, + title = {{Unlocking On-Policy Distillation for Any Model Family}}, + author = {Carlos Miguel Patiño and Kashif Rasul and Quentin Gallouédec and Ben Burtenshaw and Sergio Paniego and Vaibhav Srivastav and Thibaud Frere and Ed Beeching and Lewis Tunstall and Leandro von Werra and Thomas Wolf}, + year = 2025, + url = {https://huggingface.co/spaces/HuggingFaceH4/general-on-policy-logit-distillation}, + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | str | None = None, + teacher_model: PreTrainedModel | nn.Module | str = None, + args: GOLDConfig | None = None, + data_collator: DataCollator | None = None, # type: ignore + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: Optional["PeftConfig"] = None, + ): + self.model_name_or_path = model if isinstance(model, str) else model.config._name_or_path + self.model_revision = getattr(args, "student_model_revision", None) + if isinstance(model, str) and self.model_revision is not None: + args.model_init_kwargs = args.model_init_kwargs or {} + args.model_init_kwargs.setdefault("revision", self.model_revision) + + # Respect a user-provided data_collator; otherwise, provide a ChatML collator that + if data_collator is None: + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) + + # Liger fused GKD loss (JSD) + self.use_liger_gkd_loss = False + if args.use_liger_kernel: + self.liger_jsd_loss = LigerFusedLinearJSDLoss( + beta=args.beta, + ignore_index=-100, + temperature=args.temperature, + compiled=False, + ) + self.use_liger_gkd_loss = True + + if args.teacher_model_init_kwargs is None: + teacher_model_init_kwargs = {} + elif not isinstance(teacher_model, str): + raise ValueError( + "You passed teacher_model_init_kwargs to the GOLDConfig, but your teacher_model is already instantiated." + ) + else: + teacher_model_init_kwargs = args.teacher_model_init_kwargs + teacher_model_init_kwargs["torch_dtype"] = ( + teacher_model_init_kwargs["torch_dtype"] + if teacher_model_init_kwargs["torch_dtype"] in ["auto", None] + else getattr(torch, teacher_model_init_kwargs["torch_dtype"]) + ) + + if args.use_uld_loss and args.teacher_tokenizer_name_or_path is None: + if isinstance(teacher_model, str): + args.teacher_tokenizer_name_or_path = teacher_model + else: + raise ValueError( + "`teacher_tokenizer_name_or_path` must be set when using ULD loss with a pre-instantiated teacher model." + ) + + if isinstance(teacher_model, str): + init_kwargs = dict(teacher_model_init_kwargs) + if "torch_dtype" in init_kwargs and "dtype" not in init_kwargs: + init_kwargs["dtype"] = init_kwargs.pop("torch_dtype") + teacher_model = create_model_from_path(teacher_model, **init_kwargs) + self.use_uld_loss = args.use_uld_loss + self.teacher_tokenizer = None + if args.use_uld_loss and args.teacher_tokenizer_name_or_path is not None: + self.teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher_tokenizer_name_or_path) + if not hasattr(self.teacher_tokenizer, "pad_token") or self.teacher_tokenizer.pad_token is None: + self.teacher_tokenizer.pad_token = self.teacher_tokenizer.eos_token + + # Hybrid ULD loss configuration is handled in ULDLoss class + + super().__init__( + model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + peft_config=peft_config, + ) + + if args.disable_dropout: + disable_dropout_in_model(self.model) + if not args.use_uld_loss: + teacher_model.resize_token_embeddings(self.model.config.vocab_size) + + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + + self.lmbda = args.lmbda + self.beta = args.beta + self.temperature = args.temperature + self.top_p = args.top_p + self.seq_kd = args.seq_kd + + # Track per-step loss statistics for on/off-policy batches (used in logging) + self._on_policy_loss_total = 0.0 + self._off_policy_loss_total = 0.0 + self._on_policy_step_equiv = 0.0 + self._off_policy_step_equiv = 0.0 + + # Hybrid ULD matched/unmatched accumulators (logged every step when ULD hybrid is used) + self._matched_sum = 0.0 + self._unmatched_sum = 0.0 + self._matched_step_eq = 0.0 + self._unmatched_step_eq = 0.0 + + self.use_transformers_paged = args.use_transformers_paged or False + + self.uld_loss_fn = None + if self.use_uld_loss: + self.uld_loss_fn = ULDLoss( + config=args, + student_tokenizer=processing_class, + teacher_tokenizer=self.teacher_tokenizer, + device=self.accelerator.device, + ) + + generation_kwargs = { + "max_new_tokens": args.max_completion_length, + "temperature": args.temperature, + "top_p": args.top_p, + "do_sample": True, + "top_k": args.top_k, + "pad_token_id": self.processing_class.pad_token_id, + } + self.generation_config = GenerationConfig(**generation_kwargs) + # Keep training-specific generation kwargs to overwrite model's original generation config + self.generation_kwargs = generation_kwargs + if ( + hasattr(self.model.generation_config, "eos_token_id") + and self.model.generation_config.eos_token_id is not None + ): + self.generation_config.eos_token_id = self.model.generation_config.eos_token_id + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.log_completion_steps = args.log_completions_steps + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the + # final optimization step. + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation + self._textual_logs = { + "prompt": deque(maxlen=maxlen), + "completion": deque(maxlen=maxlen), + "rewards": defaultdict(lambda: deque(maxlen=maxlen)), + "advantages": deque(maxlen=maxlen), + } + + self.use_vllm = args.use_vllm + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and use_vllm is set to True. Please install vLLM with " + "`pip install vllm` to use it." + ) + self.vllm_mode = args.vllm_mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization + self.vllm_enable_sleep_mode = args.vllm_enable_sleep_mode + if self.vllm_mode == "server": + if self.accelerator.is_main_process: + self.vllm_client = VLLMClient( + host=args.vllm_server_host, + server_port=args.vllm_server_port, + connection_timeout=args.vllm_server_timeout, + ) + self.vllm_client.init_communicator() + elif self.vllm_mode == "colocate": + student_model_name_or_path = self.model_name_or_path + + # Make sure tensor_parallel_size divides world size evenly + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " + f"({self.accelerator.num_processes}) evenly." + ) + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP + self.vllm_tp_group, _ = torch.distributed.new_subgroups_by_enumeration( + [ + list( + range( + i * self.vllm_tensor_parallel_size, + (i + 1) * self.vllm_tensor_parallel_size, + ) + ) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ] + ) + + # vLLM requires the environment variables to be set for distributed training. + os.environ["RANK"] = str(self.accelerator.process_index) + os.environ["LOCAL_RANK"] = str(self.accelerator.local_process_index) + os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) + ensure_master_addr_port() + + vllm_quantization = None + if is_bitsandbytes_available(): + for _, module in model.named_modules(): + if isinstance(module, bnb.nn.Linear4bit): + vllm_quantization = "bitsandbytes" + break + elif isinstance(module, bnb.nn.Linear8bitLt): + raise ValueError("vLLM does not support in-flight 8-bit quantization.") + + self.vllm_engine = LLM( + model=student_model_name_or_path, + revision=self.model_revision, + tensor_parallel_size=self.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + max_num_seqs=self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps, + max_model_len=args.max_length, + distributed_executor_backend="external_launcher", + # Feed identical seed for tp groups to ensure sampling results are the same across workers + seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, + enable_sleep_mode=self.vllm_enable_sleep_mode, + quantization=vllm_quantization, + ) + + if self.vllm_enable_sleep_mode: + self.vllm_engine.sleep(level=2) + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") + self.vllm_structured_outputs_regex = args.vllm_structured_outputs_regex + self.vllm_sync_frequency = args.vllm_sync_frequency + self._last_vllm_sync_step = -1 + + self.add_callback(GOLDVLLMSyncCallback(self)) + + def _set_signature_columns_if_needed(self): + super()._set_signature_columns_if_needed() + required_columns = [ + "prompts", + "prompt_attention_mask", + "messages", + "chat_template_kwargs", + "tools", + "original_prompt_text", + "original_completion_text", + ] + if self._signature_columns is None: + self._signature_columns = required_columns + else: + for column in required_columns: + if column not in self._signature_columns: + self._signature_columns.append(column) + + def _prepare_dataset( + self, + dataset: Dataset | IterableDataset, + processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin, + args, + packing: bool, + formatting_func: Callable[[dict], str] | None, + dataset_name: str, + ) -> Dataset | IterableDataset: + """ + Override dataset preparation to preserve original text for cross-tokenizer distillation and ensure + attention_mask is always added for DataCollatorForChatML compatibility. + """ + # Check if dataset is already processed + column_names = list(next(iter(dataset)).keys()) + is_processed = "input_ids" in column_names + + # Use our enhanced dataset preparation for: + # 1. ULD loss with cross-tokenizer (need original text preservation) + # 2. Any unprocessed dataset (need attention_mask for DataCollatorForChatML) + if not is_processed or (self.use_uld_loss and self.teacher_tokenizer is not None): + # For unprocessed datasets, use our enhanced tokenization + return self._prepare_dataset_with_original_text( + dataset, processing_class, args, packing, formatting_func, dataset_name + ) + + # Use parent implementation for all other cases + return super()._prepare_dataset(dataset, processing_class, args, packing, formatting_func, dataset_name) + + def _prepare_dataset_with_original_text( + self, + dataset: Dataset | IterableDataset, + processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin, + args, + packing: bool, + formatting_func: Callable[[dict], str] | None, + dataset_name: str, + ) -> Dataset | IterableDataset: + """ + Prepare dataset while preserving original text for cross-tokenizer distillation. + """ + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + # Apply the formatting function if any + if formatting_func is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" + + def _func(example): + return {"text": formatting_func(example)} + + dataset = dataset.map(_func, batched=False, **map_kwargs) + + # Convert the dataset to ChatML if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + column_names = next(iter(dataset)).keys() + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in column_names else None, + **map_kwargs, + ) + + # Apply the chat template if needed and preserve original text + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if "text" in example and not example["text"].endswith(eos_token): # language modeling case + example["text"] = example["text"] + eos_token + elif "completion" in example and not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example + + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + remove_columns="messages" if "messages" in column_names else None, # renamed to "text" + **map_kwargs, + ) + + # Tokenize the dataset while preserving original text + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset (preserving original text)" + + def tokenize_with_original_text(example, processing_class, dataset_text_field, assistant_only_loss): + """Modified tokenization function that preserves original text.""" + result = {} + + if "prompt" in example: # prompt-completion case + # Store original text + result["original_prompt_text"] = example["prompt"] + result["original_completion_text"] = example["completion"] + + if is_conversational(example): + prompt_ids = processing_class.apply_chat_template( + example["prompt"], return_dict=False, **example.get("chat_template_kwargs", {}) + ) + prompt_completion_ids = processing_class.apply_chat_template( + example["prompt"] + example["completion"], + return_dict=False, + **example.get("chat_template_kwargs", {}), + ) + else: + prompt_ids = processing_class(text=example["prompt"]).input_ids + prompt_completion_ids = processing_class( + text=example["prompt"] + example["completion"] + ).input_ids + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: + warnings.warn( + "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently.", + stacklevel=2, + ) + + # Create a completion mask + completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) + result.update( + { + "input_ids": prompt_completion_ids, + "completion_mask": completion_mask, + "attention_mask": [1] * len(prompt_completion_ids), # Add attention mask + } + ) + + else: # language modeling or conversational case + if is_conversational(example): + # For conversational data (ChatML), extract prompt and completion properly + messages = example["messages"] + + # Extract user and assistant messages separately + user_messages = [msg for msg in messages if msg["role"] != "assistant"] + assistant_messages = [msg for msg in messages if msg["role"] == "assistant"] + + if user_messages and assistant_messages: + # Apply chat template to get the prompt (everything up to assistant) + prompt_text = processing_class.apply_chat_template( + user_messages, + add_generation_prompt=True, # add assistant prompt + tokenize=False, + **example.get("chat_template_kwargs", {}), + ) + + # Get the full conversation with assistant response + full_text = processing_class.apply_chat_template( + messages, + add_generation_prompt=False, + tokenize=False, + **example.get("chat_template_kwargs", {}), + ) + + # Extract completion as everything after the prompt + # This ensures we capture any extra tokens (like tags) that the template adds + if full_text.startswith(prompt_text): + completion_text = full_text[len(prompt_text) :] + else: + # Fallback: use assistant content + EOS + assistant_content = assistant_messages[0]["content"] + completion_text = ( + assistant_content + processing_class.eos_token + if hasattr(processing_class, "eos_token") + else assistant_content + ) + + # Store original text for cross-tokenizer distillation + result["original_prompt_text"] = prompt_text + result["original_completion_text"] = completion_text + else: + # Fallback: use empty prompt and full text as completion + full_text = processing_class.apply_chat_template( + messages, tokenize=False, **example.get("chat_template_kwargs", {}) + ) + result["original_prompt_text"] = "" + result["original_completion_text"] = full_text + + # Process the conversation normally + processed = processing_class.apply_chat_template( + example["messages"], + return_dict=True, + return_assistant_tokens_mask=assistant_only_loss, + **example.get("chat_template_kwargs", {}), + ) + if "assistant_masks" in processed and 1 not in processed["assistant_masks"]: + raise RuntimeError( + "You're using `assistant_only_loss=True`, but at least one example has no " + "assistant tokens. This usually means the tokenizer's chat template doesn't " + "generate assistant masks — it may be missing the `{% generation %}` tag. Please " + "check the template and ensure it's correctly configured to support assistant " + "masking." + ) + result.update({k: processed[k] for k in ("input_ids", "assistant_masks") if k in processed}) + # Add attention_mask if not already present + if "attention_mask" not in result: + result["attention_mask"] = [1] * len(result["input_ids"]) + else: + # For regular language modeling, store the full text as completion and empty prompt + result["original_prompt_text"] = "" + result["original_completion_text"] = example.get(dataset_text_field, example.get("text", "")) + + tokenized = processing_class(text=example[dataset_text_field]) + result.update( + { + "input_ids": tokenized.input_ids, + "attention_mask": getattr(tokenized, "attention_mask", [1] * len(tokenized.input_ids)), + } + ) + + return result + + dataset = dataset.map( + tokenize_with_original_text, + fn_kwargs={ + "processing_class": processing_class, + "dataset_text_field": args.dataset_text_field, + "assistant_only_loss": args.assistant_only_loss, + }, + **map_kwargs, + ) + + # Pack or truncate + if packing: + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + + columns_to_keep = ["input_ids", "original_prompt_text", "original_completion_text"] + existing_columns = set(dataset.column_names) + columns_to_select = [col for col in columns_to_keep if col in existing_columns] + + dataset = dataset.select_columns(columns_to_select) + dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) + elif args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + + if args.use_liger_kernel: + required_columns = { + "input_ids", + "attention_mask", + "position_ids", + "completion_mask", + "assistant_masks", + "original_prompt_text", + "original_completion_text", + } + dataset = dataset.select_columns(required_columns.intersection(dataset.column_names)) + + return dataset + + @staticmethod + def generalized_jsd_loss( + student_logits, + teacher_logits, + labels=None, + beta=0.5, + temperature=1.0, + reduction="batchmean", + logits_are_probs=False, + ): + """ + Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1) + of https://huggingface.co/papers/2306.13649 for the definition. + + Args: + student_logits: + Tensor of shape (batch_size, sequence_length, vocab_size) + teacher_logits: + Tensor of shape (batch_size, sequence_length, vocab_size) + labels: + Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing + loss + beta: + Interpolation coefficient between 0 and 1 (default: 0.5) + temperature: + Softmax temperature (default: 1.0) + reduction: + Specifies the reduction to apply to the output (default: 'batchmean') + + Returns: + loss: Scalar tensor with the generalized JSD loss + """ + + if logits_are_probs: + student_log_probs = torch.log(student_logits.clamp_min(1e-8)) + teacher_log_probs = torch.log(teacher_logits.clamp_min(1e-8)) + else: + # Apply temperature scaling to logits before computing probabilities + student_logits = student_logits / temperature + teacher_logits = teacher_logits / temperature + # Compute log probabilities for student and probabilities for teacher + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + # Compute the log of the mixture distribution + # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture + beta = torch.tensor(beta, dtype=student_log_probs.dtype, device=student_log_probs.device) + mixture_log_probs = torch.logsumexp( + torch.stack([student_log_probs + torch.log1p(-beta), teacher_log_probs + torch.log(beta)]), + dim=0, + ) + + # Compute KL divergences using F.kl_div + # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. + kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) + + # Compute the Generalized Jensen-Shannon Divergence + jsd = beta * kl_teacher + (1 - beta) * kl_student + + # Masking + if labels is not None: + mask = labels != -100 + jsd = jsd[mask] + + # Apply reduction + if reduction == "batchmean": + return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0) + elif reduction == "sum": + return jsd.sum() + elif reduction == "mean": + return jsd.mean() + else: + return jsd + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if self.use_uld_loss and self.teacher_tokenizer is not None: + if "original_prompt_text" in inputs and "original_completion_text" in inputs: + prompt_texts = inputs["original_prompt_text"] + completion_texts = inputs["original_completion_text"] + full_texts = [p + c for p, c in zip(prompt_texts, completion_texts, strict=True)] + else: + # Fallback: decode student input_ids (current approach) + # WARNING: This may not work perfectly for cross-tokenizer distillation + full_sequences = inputs["input_ids"] + full_texts = self.processing_class.batch_decode(full_sequences, skip_special_tokens=False) + + # Try to split prompt/completion using original prompt length + prompt_lengths = inputs["prompts"].shape[1] + prompt_texts = self.processing_class.batch_decode(inputs["prompts"], skip_special_tokens=False) + completion_texts = [ + full.replace(prompt, "", 1) for full, prompt in zip(full_texts, prompt_texts, strict=True) + ] + + ( + teacher_input_ids, + teacher_labels, + teacher_attention_mask, + teacher_prompt_length, + ) = build_teacher_inputs_from_texts( + self.teacher_tokenizer, + prompt_texts, + completion_texts, + ) + + teacher_input_ids = teacher_input_ids.to(self.accelerator.device) + teacher_labels = teacher_labels.to(self.accelerator.device) + teacher_attention_mask = teacher_attention_mask.to(self.accelerator.device) + + outputs_student = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + self.teacher_model.eval() + with torch.no_grad(): + outputs_teacher = self.teacher_model( + input_ids=teacher_input_ids, + attention_mask=teacher_attention_mask, + ) + + # These are not used for ULD loss but are needed if JSD loss were to be used in this branch + student_prompt_length = inputs["prompts"].shape[1] + shifted_student_logits = outputs_student.logits[:, student_prompt_length - 1 : -1, :] + shifted_teacher_logits = outputs_teacher.logits[:, teacher_prompt_length - 1 : -1, :] + shifted_labels = inputs["labels"][:, student_prompt_length:] + else: + if self.use_liger_gkd_loss: + # Forward only through the base models (avoid lm_head to save memory) + unwrapped_student = self.accelerator.unwrap_model(model) + if hasattr(unwrapped_student, "get_decoder") and unwrapped_student.get_decoder() is not None: + base_student = unwrapped_student.get_decoder() + else: + base_student = getattr( + unwrapped_student, getattr(unwrapped_student, "base_model_prefix", "model"), unwrapped_student + ) + + student_outputs = base_student( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + self.teacher_model.eval() + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + if hasattr(unwrapped_teacher, "get_decoder") and unwrapped_teacher.get_decoder() is not None: + base_teacher = unwrapped_teacher.get_decoder() + else: + base_teacher = getattr( + unwrapped_teacher, getattr(unwrapped_teacher, "base_model_prefix", "model"), unwrapped_teacher + ) + with torch.no_grad(): + teacher_outputs = base_teacher( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + use_cache=False, + ) + + # hidden states (shifted) + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] + + # Release full outputs to free memory + del student_outputs, teacher_outputs + + # labels mask and labels (shifted) + labels_mask = inputs["labels"] != -100 + masked_input_ids = torch.where( + labels_mask, inputs["input_ids"], torch.full_like(inputs["input_ids"], -100) + ) + true_labels = masked_input_ids[:, 1:].contiguous() + + # heads + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + # liger fused jsd loss + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, "bias", None), + teacher_bias=getattr(teacher_head, "bias", None), + ) + + # Release hidden states after loss computation + del student_hidden, teacher_hidden, true_labels + else: + # Original behavior for same tokenizer or when teacher_tokenizer is not provided + outputs_student = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + + self.teacher_model.eval() + with torch.no_grad(): + outputs_teacher = self.teacher_model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + + prompt_lengths = inputs["prompts"].shape[1] + shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :] + shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :] + shifted_labels = inputs["labels"][:, prompt_lengths:] + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + labels=shifted_labels, + beta=self.beta, + ) + + if self.use_uld_loss: + student_input_ids = inputs["input_ids"] + + # Use the *teacher* labels created above, not the student's. + teacher_labels_for_loss = teacher_labels if "teacher_labels" in locals() else inputs["labels"] + teacher_input_ids_for_loss = teacher_input_ids if "teacher_input_ids" in locals() else inputs["input_ids"] + + # Create properly masked student labels (fixing batch size > 1 issue) + student_labels = inputs["labels"].clone() + if hasattr(self.processing_class, "pad_token_id") and self.processing_class.pad_token_id is not None: + student_labels[student_labels == self.processing_class.pad_token_id] = -100 + + # Also mask pad tokens in teacher labels for consistency + if ( + hasattr(self, "teacher_tokenizer") + and hasattr(self.teacher_tokenizer, "pad_token_id") + and self.teacher_tokenizer.pad_token_id is not None + ): + teacher_labels[teacher_labels == self.teacher_tokenizer.pad_token_id] = -100 + + loss = self.uld_loss_fn( + student_logits=outputs_student.logits, + teacher_logits=outputs_teacher.logits, + student_labels=student_labels, + teacher_labels=teacher_labels_for_loss, + student_input_ids=student_input_ids, + teacher_input_ids=teacher_input_ids_for_loss, + ) + + # If ULD hybrid mode produced per-step matched/unmatched components, accumulate them for logging. + # Use gradient_accumulation_steps to mirror Trainer's windowing behavior. + if hasattr(self.uld_loss_fn, "last_matched_loss") and hasattr(self.uld_loss_fn, "last_unmatched_loss"): + try: + ga = max(1, int(self.args.gradient_accumulation_steps)) + except Exception: + ga = 1 + step_eq = 1.0 / ga + # read scalar values for logging + matched_val = ( + self.uld_loss_fn.last_matched_loss.item() + if self.uld_loss_fn.last_matched_loss is not None + else 0.0 + ) + unmatched_val = ( + self.uld_loss_fn.last_unmatched_loss.item() + if self.uld_loss_fn.last_unmatched_loss is not None + else 0.0 + ) + + self._matched_sum += matched_val + self._unmatched_sum += unmatched_val + self._matched_step_eq += step_eq + self._unmatched_step_eq += step_eq + + empty_cache() + + return (loss, outputs_student) if return_outputs else loss + + def generate_on_policy_outputs(self, model, inputs, generation_config, pad_token_id=None): + # Generate output with respect to the prompt only + if self.use_transformers_paged: + previous_attn = self.model.config._attn_implementation + if is_flash_attn_2_available(): + model.config._attn_implementation = "paged_attention" + else: + model.config._attn_implementation = "sdpa_paged" + prompt_mask = inputs.get("prompt_attention_mask") + prompts_tensor = inputs["prompts"] + if prompt_mask is not None: + prompt_sequences = [ + row[mask.bool()].detach().cpu().tolist() + for row, mask in zip(prompts_tensor, prompt_mask, strict=True) + ] + else: + prompt_sequences = [row.detach().cpu().tolist() for row in prompts_tensor] + generated_outputs = model.generate_batch(prompt_sequences, generation_config=generation_config) + model.config._attn_implementation = previous_attn + + completion_ids = [output.generated_tokens for output in generated_outputs.values()] + generated_tokens = torch.stack([torch.tensor(ids, device=model.device) for ids in completion_ids]) + else: + generated_outputs = model.generate( + input_ids=inputs["prompts"], + attention_mask=inputs.get("prompt_attention_mask", None), + generation_config=generation_config, + return_dict_in_generate=True, + ) + # Get the generated token IDs + generated_tokens = generated_outputs.sequences + + batch_size = generated_tokens.size(0) + device = generated_tokens.device + + prompt_mask = inputs.get("prompt_attention_mask") + pad_token_id = pad_token_id if pad_token_id is not None else self.processing_class.pad_token_id + + if prompt_mask is not None: + prompt_lengths = prompt_mask.sum(dim=1).to(torch.long) + else: + if pad_token_id is not None: + prompt_lengths = (inputs["prompts"] != pad_token_id).sum(dim=1).to(torch.long) + else: + prompt_lengths = torch.full( + (batch_size,), + inputs["prompts"].shape[1], + dtype=torch.long, + device=device, + ) + + new_input_ids = generated_tokens + new_attention_mask = torch.ones_like(new_input_ids) + if pad_token_id is not None: + new_attention_mask[new_input_ids == pad_token_id] = 0 + + new_labels = torch.full_like(new_input_ids, -100) + for idx in range(batch_size): + length = int(prompt_lengths[idx].item()) + new_labels[idx, length:] = new_input_ids[idx, length:] + + if pad_token_id is not None: + new_labels[new_input_ids == pad_token_id] = -100 + + prompt_texts = [] + completion_texts = [] + for idx in range(batch_size): + length = int(prompt_lengths[idx].item()) + prompt_tokens = inputs["prompts"][idx] + if prompt_mask is not None: + prompt_tokens = prompt_tokens[prompt_mask[idx].bool()] + elif pad_token_id is not None: + prompt_tokens = prompt_tokens[prompt_tokens != pad_token_id] + prompt_texts.append( + self.processing_class.decode( + prompt_tokens.tolist(), + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + ) + completion_tokens = new_input_ids[idx, length:] + completion_texts.append( + self.processing_class.decode( + completion_tokens.tolist(), + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + ) + + return new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts + + @profiling_decorator + def _generate_on_policy_outputs_vllm(self, inputs, generation_config, pad_token_id=None): + device = self.accelerator.device + + # Decode prompts for vLLM (without special tokens - vLLM expects clean text) + prompts_text_for_vllm = self.processing_class.batch_decode( + inputs["prompts"], + skip_special_tokens=True, + # clean_up_tokenization_spaces=False # Keep this commented unless specific issues arise + ) + # Remove padding token text if it appears, as vLLM expects clean prompts + if self.processing_class.pad_token: + prompts_text_for_vllm = [p.replace(self.processing_class.pad_token, "") for p in prompts_text_for_vllm] + + # Also decode prompts WITH special tokens for ULD loss computation + prompts_text_with_special = self.processing_class.batch_decode( + inputs["prompts"], + skip_special_tokens=False, + ) + + # system_prompt = "Please reason step by step, and put your final answer within \\boxed{}." + # target_system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." + # prompts_text = [p.replace(target_system_prompt, system_prompt) for p in prompts_text] + # Add system prompt to prompts + + max_completion_length = generation_config.max_new_tokens + temperature = generation_config.temperature + # vLLM uses top_k=-1 for no top_k, transformers uses 0 or None. + top_k = generation_config.top_k if generation_config.top_k and generation_config.top_k > 0 else -1 + # top_p, repetition_penalty, min_p are not directly in generation_config, get from trainer args + top_p = self.args.top_p if hasattr(self.args, "top_p") else 1.0 + repetition_penalty = self.args.repetition_penalty if hasattr(self.args, "repetition_penalty") else 1.0 + min_p = self.args.min_p if hasattr(self.args, "min_p") else 0.0 + + if self.vllm_mode == "server": + all_prompts_text = gather_object(prompts_text_for_vllm) + if self.accelerator.is_main_process: + completion_ids = self.vllm_client.generate( + prompts=all_prompts_text, + n=1, # In GKD, we generate 1 completion per prompt from student + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_completion_length, + structured_outputs_regex=self.vllm_structured_outputs_regex, + )["completion_ids"] + else: + completion_ids = [None] * len(all_prompts_text) + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts_text_for_vllm), + (self.accelerator.process_index + 1) * len(prompts_text_for_vllm), + ) + completion_ids = completion_ids[process_slice] + elif self.vllm_mode == "colocate": + if self.vllm_structured_outputs_regex: + structured_outputs = StructuredOutputsParams( + backend="outlines", regex=self.vllm_structured_outputs_regex + ) + else: + structured_outputs = None + sampling_params = SamplingParams( + n=1, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_completion_length, + structured_outputs=structured_outputs, + ) + + if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts_text_for_vllm) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts_text_for_vllm, group=self.vllm_tp_group) + all_prompts_text = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts_text = prompts_text_for_vllm + + all_outputs = self.vllm_engine.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False) + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + + if hasattr(self, "vllm_tp_group") and self.vllm_tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + completion_ids = completion_ids[tp_slice] + + if self.vllm_enable_sleep_mode: + self.vllm_engine.sleep(level=2) + else: + raise ValueError(f"Unknown vllm_mode: {self.vllm_mode}") + + # We need to combine prompt and completion for new_input_ids + # Tokenize prompts again to get prompt_ids on the correct device and format + # Use prompts_text_for_vllm (without special tokens) for tokenization since vLLM expects clean text + # Ensure add_special_tokens=False as vLLM typically handles prompts as raw text + # Calculate max_length for prompts, ensuring it's positive + prompt_max_length = max(1, self.args.max_length - max_completion_length) if self.args.max_length else None + prompt_tokenized = self.processing_class( + prompts_text_for_vllm, + return_tensors="pt", + padding="longest", + truncation=True if prompt_max_length else False, + max_length=prompt_max_length, + add_special_tokens=False, + ).to(device) + prompt_ids = prompt_tokenized.input_ids + + completion_ids_tensors = [torch.tensor(ids, device=device) for ids in completion_ids] + # Manually pad/truncate completions to max_completion_length length before using pad function + padded_completion_ids_list = [] + for completion_tensor in completion_ids_tensors: + if len(completion_tensor) > max_completion_length: + # Truncate if longer than max_completion_length + padded_completion_ids_list.append(completion_tensor[:max_completion_length]) + elif len(completion_tensor) < max_completion_length: + # Pad if shorter than max_completion_length + padding_needed = max_completion_length - len(completion_tensor) + padded_tensor = torch.cat( + [ + completion_tensor, + torch.full((padding_needed,), pad_token_id, device=device, dtype=completion_tensor.dtype), + ] + ) + padded_completion_ids_list.append(padded_tensor) + else: + # Already the right length + padded_completion_ids_list.append(completion_tensor) + + # Now all tensors are the same length, so we can stack them + padded_completion_ids = torch.stack(padded_completion_ids_list) + + # Ensure prompt_ids and padded_completion_ids are 2D + if prompt_ids.ndim == 1: + prompt_ids = prompt_ids.unsqueeze(0) + if padded_completion_ids.ndim == 1: + padded_completion_ids = padded_completion_ids.unsqueeze(0) + + new_input_ids = torch.cat([prompt_ids, padded_completion_ids], dim=1) + + new_attention_mask = torch.ones_like(new_input_ids, device=device) + new_labels = new_input_ids.clone() + + if pad_token_id is not None: + new_labels[new_labels == pad_token_id] = -100 + new_attention_mask[new_input_ids == pad_token_id] = 0 + + # Mask prompt tokens in labels + prompt_lengths = prompt_ids.shape[1] + new_labels[:, :prompt_lengths] = -100 + + # IMPORTANT: Preserve original text for cross-tokenizer ULD loss + # Use prompts_text_with_special (with special tokens) for ULD loss computation + # Extract completion texts from the generated completion IDs + completion_texts = [] + for comp_ids in completion_ids: + completion_text = self.processing_class.decode(comp_ids, skip_special_tokens=False) + completion_texts.append(completion_text) + + return new_input_ids, new_attention_mask, new_labels, prompts_text_with_special, completion_texts + + def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): + """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with student vLLM.""" + if visited is None: + visited = set() + + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + # recurse into the child + self._sync_fsdp_params_to_vllm(child_module, prefix=child_prefix, visited=visited) + + if isinstance(module, FSDP): + with FSDP.summon_full_params(module, recurse=False, writeback=False): + for param_name, param in module.named_parameters(): + full_name = f"{prefix}.{param_name}" if prefix else param_name + for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."): + full_name = full_name.replace(extra, "") + + if full_name in visited: + continue # skip FSDP subtrees already traversed + visited.add(full_name) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(full_name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(full_name, param.data)]) + + def _move_model_to_vllm(self): + """Synchronize student model weights to vLLM engine.""" + # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + + if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode: + empty_cache() + self.vllm_engine.wake_up(tags=["weights"]) + # Work around for https://github.com/vllm-project/vllm/issues/29341 + self.vllm_engine.collective_rpc("reload_weights") + + if is_peft_model(self.model): + # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as + # merging adapters in a sharded manner is not supported. + with gather_if_zero3(list(self.model.parameters())): + self.model.merge_adapter() + + # Update vLLM weights while parameters are gathered + if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext + # Update vLLM weights while parameters are gathered + # For PEFT with FSDP we need to use the memory efficient post-order traversal + self._sync_fsdp_params_to_vllm(self.model) + else: + # DeepSpeed ZeRO-3 with PEFT + for name, param in self.model.named_parameters(): + # When using PEFT, we need to recover the original parameter name and discard some parameters + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + if self.model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = name.replace("modules_to_save.default.", "") + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + self.model.unmerge_adapter() + # Parameters will automatically be repartitioned when exiting the context + else: + # For non-PEFT models, simply gather (if needed) and update each parameter individually. + if self.is_fsdp_enabled: + # use memory-efficient post-order traversal for FSDP + self._sync_fsdp_params_to_vllm(self.model) + else: + # For DeepSpeed ZeRO-3, gather each parameter individually like GRPO trainer + for name, param in self.model.named_parameters(): + with gather_if_zero3([param]): + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == "colocate": + self.vllm_engine.reset_prefix_cache() + + def _wake_vllm_if_needed(self): + if self.vllm_mode == "colocate" and self.vllm_enable_sleep_mode: + empty_cache() + self.vllm_engine.wake_up(tags=["kv_cache"]) + + @profiling_decorator + def training_step( + self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None + ) -> torch.Tensor: + """ + Perform a training step for the General Online Logit Distillation (GOLD) model. + + This method implements the on-policy learning approach described in the GOLD blog post. With probability + `self.lmbda`, it generates new responses using the student model, which are then used for training instead of + the offline original inputs. + """ + on_policy = False + if random.random() <= self.lmbda: + on_policy = True + if self.use_vllm: + self._wake_vllm_if_needed() + result = self._generate_on_policy_outputs_vllm( + inputs, self.generation_config, self.processing_class.pad_token_id + ) + new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + else: + with ( + unwrap_model_for_generation( + model, + self.accelerator, + generation_kwargs=self.generation_kwargs, # Override model.generation_config with generation_kwargs to fix transformers#42762 + ) as unwrapped_model + ): + result = self.generate_on_policy_outputs( + unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id + ) + new_input_ids, new_attention_mask, new_labels, prompt_texts, completion_texts = result + + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_attention_mask + inputs["labels"] = new_labels + + # CRITICAL: Preserve original text for cross-tokenizer ULD loss + # This ensures both off-policy (dataset) and on-policy (generated) samples + # can use proper text-based alignment for different tokenizers + inputs["original_prompt_text"] = prompt_texts + inputs["original_completion_text"] = completion_texts + + # Log prompt and completion texts + self._textual_logs["prompt"].extend(gather_object(prompt_texts)) + self._textual_logs["completion"].extend(gather_object(completion_texts)) + + loss = super().training_step(model, inputs, num_items_in_batch) + + loss_scalar = float(loss.detach()) + ga = max(1, int(self.args.gradient_accumulation_steps)) + step_equiv = 1.0 / ga + + if on_policy: + self._on_policy_loss_total += loss_scalar + self._on_policy_step_equiv += step_equiv + else: + self._off_policy_loss_total += loss_scalar + self._off_policy_step_equiv += step_equiv + return loss + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + if mode == "train": + device = self.accelerator.device if hasattr(self.accelerator, "device") else torch.device("cpu") + # include matched/unmatched accumulators for distributed reduction + vec = torch.tensor( + [ + self._on_policy_loss_total, + self._off_policy_loss_total, + self._on_policy_step_equiv, + self._off_policy_step_equiv, + self._matched_sum, + self._unmatched_sum, + self._matched_step_eq, + self._unmatched_step_eq, + ], + dtype=torch.float64, + device=device, + ) + + # Sum across processes so we mirror Trainer's distributed reduction + if ( + getattr(self.accelerator, "distributed_type", DistributedType.NO) != DistributedType.NO + and dist.is_available() + and dist.is_initialized() + ): + dist.all_reduce(vec, op=dist.ReduceOp.SUM) + + ( + on_sum, + off_sum, + on_eq, + off_eq, + matched_sum, + unmatched_sum, + matched_eq, + unmatched_eq, + ) = vec.tolist() + + # Compute category averages over the *same window* as Trainer's logs + # (avoid div-by-zero if, e.g., no on-policy steps in the window) + if on_eq > 0: + logs["on_policy_loss"] = round(on_sum / on_eq, 4) + if off_eq > 0: + logs["off_policy_loss"] = round(off_sum / off_eq, 4) + + # matched/unmatched averaged over same logging window (if present) + if matched_eq > 0: + logs["matched_loss"] = round(matched_sum / matched_eq, 4) + if unmatched_eq > 0: + logs["unmatched_loss"] = round(unmatched_sum / unmatched_eq, 4) + + # Reset window accumulators after logging (just like Trainer resets its window) + self._on_policy_loss_total = self._off_policy_loss_total = 0.0 + self._on_policy_step_equiv = self._off_policy_step_equiv = 0.0 + self._matched_sum = self._unmatched_sum = 0.0 + self._matched_step_eq = self._unmatched_step_eq = 0.0 + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + if ( + self.accelerator.is_main_process + and self.log_completions + and ((self.state.global_step % self.log_completion_steps) == 0) + ): + if is_rich_available(): + print_prompt_completions_sample_uld( + self._textual_logs["prompt"], + self._textual_logs["completion"], + self.state.global_step, + self.num_completions_to_print, + ) + + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + import pandas as pd + + table = { + "step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]), + "prompt": self._textual_logs["prompt"], + "completion": self._textual_logs["completion"], + } + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + if self.num_completions_to_print and len(df) > 0: + df = df.sample(n=self.num_completions_to_print, random_state=42) + wandb.log({"completions": wandb.Table(dataframe=df)}) diff --git a/ICL/RL/trl_source/trl/experimental/prm/prm_trainer.py b/ICL/RL/trl_source/trl/experimental/prm/prm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6c5279f37235d942b1e3b954cbed73191aab51 --- /dev/null +++ b/ICL/RL/trl_source/trl/experimental/prm/prm_trainer.py @@ -0,0 +1,354 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap +from collections.abc import Callable +from itertools import chain +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import transformers +from accelerate import PartialState, logging +from datasets import Dataset, features +from packaging.version import Version +from transformers import ( + BaseImageProcessor, + DataCollator, + DataCollatorForTokenClassification, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, +) +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import disable_dropout_in_model +from ..utils import prepare_peft_model +from .prm_config import PRMConfig + + +if is_peft_available(): + from peft import PeftModel + +logger = logging.get_logger(__name__) + + +def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]: + predictions, labels = eval_pred + if predictions.ndim == 3: + # Token classification task. Shapes are (batch_size, seq_len, num_labels) and (batch_size, seq_len) + # Used to compute the accuracy in the prm_trainer. + predictions = np.argmax(predictions, axis=2) + + # Flatten the predictions and labels to remove the ignored tokens. + predictions = np.array( + [ + p + for prediction, label in zip(predictions, labels, strict=True) + for (p, lbl) in zip(prediction, label, strict=True) + if lbl != -100 + ] + ) + labels = np.array([lbl for label in labels for lbl in label if lbl != -100]) + + else: + # Here, predictions is rewards_chosen and rewards_rejected. Shapes are (batch_size, 2) and (batch_size,) + # We want to see how much of the time rewards_chosen > rewards_rejected. + equal_mask = predictions[:, 0] == predictions[:, 1] + equal_predictions_count = int(equal_mask.sum()) + + if equal_predictions_count > 0: + # Before using the logger, the accelerate state must be initialized. It'susually the case when using this + # function inside a Trainer, but it may not be the case otherwise, in particular when unit testing. + PartialState() + + logger.warning( + f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions " + "for both options are equal. These instances are ignored in the accuracy computation.", + ) + + # Filter out equal predictions + predictions = predictions[~equal_mask] + labels = labels[~equal_mask] + + # Use the remaining predictions for accuracy calculation + predictions = np.argmax(predictions, axis=1) + + accuracy = np.array(predictions == labels, dtype=float).mean().item() + return {"accuracy": accuracy} + + +class PRMTrainer(BaseTrainer): + """ + Initialize PRMTrainer. + + Args: + model ([`~transformers.PreTrainedModel`]): + The model to train, preferably an `AutoModelForTokenClassification`. + args ([`experimental.prm.PRMConfig`]): + The arguments to use for training. + data_collator ([`~transformers.DataCollator`]): + The data collator to use for training. If None is specified, the default data collator + ([`~transformers.DataCollatorForTokenClassification`]) will be used which will pad the sequences to the + maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset ([`~datasets.Dataset`]): + The dataset to use for training. + eval_dataset ([`~datasets.Dataset`]): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be + used. + compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) + will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in + a PEFT model. + """ + + _tag_names = ["trl", "prm"] + _name = "PRM" + _paper = { + "title": "Solving math word problems with process-and outcome-based feedback", + "id": "2211.14275", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{uesato2022solving, + title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}}, + author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina}, + year = 2022, + journal = {arXiv preprint arXiv:2211.14275} + }"""), + } + + def __init__( + self, + model: PreTrainedModel | nn.Module | None = None, + args: PRMConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + processing_class: PreTrainedTokenizerBase + | BaseImageProcessor + | FeatureExtractionMixin + | ProcessorMixin + | None = None, + model_init: Callable[[], PreTrainedModel] | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, + peft_config: dict | None = None, + ): + if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): + model = prepare_peft_model(model, peft_config, args) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if processing_class is None: + raise ValueError( + "A processing_class must be specified when using the default DataCollatorForTokenClassification" + ) + data_collator = DataCollatorForTokenClassification(processing_class) + + if "input_ids" not in train_dataset.column_names: + with PartialState().main_process_first(): + fn_kwargs = { + "tokenizer": processing_class, + "step_separator": args.step_separator, + "max_length": args.max_length, + "max_completion_length": args.max_completion_length, + "train_on_last_step_only": args.train_on_last_step_only, + } + train_fn_kwargs = {**fn_kwargs, "is_eval": False} + train_dataset = train_dataset.map( + self.tokenize_row, + fn_kwargs=train_fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=train_dataset.features, + desc="Tokenizing train dataset", + features=features.Features( # needed to avoid map to cast labels to bool + { + "labels": features.Sequence(features.Value("int64")), + "input_ids": features.Sequence(features.Value("int64")), + } + ), + ) + + eval_fn_kwargs = {**fn_kwargs, "is_eval": True} + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + self.tokenize_row, + fn_kwargs=eval_fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=eval_dataset.features, + desc="Tokenizing eval dataset", + features=features.Features( # needed to avoid map to cast labels to bool + { + "labels": features.Sequence(features.Value("int64")), + "input_ids": features.Sequence(features.Value("int64")), + } + ), + ) + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + @staticmethod + def tokenize_row( + features, + tokenizer, + step_separator, + max_length, + max_completion_length, + train_on_last_step_only, + is_eval, + ): + r""" + Tokenize a row of the dataset. + + Args: + features (`dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`. + tokenizer ([`~transformers.PreTrainedTokenizerBase`]): + Tokenizer used to process the data. + step_separator (`str`): + Separator between steps in the completion. + max_length (`int` or `None`): + Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + train_on_last_step_only (`bool`): + Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last + token of the completion. + is_eval (`bool`): + Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if + `train_on_last_step_only` is set to `True`. + + Returns: + `dict[str, list[int]]`: + Tokenized sequences with the keys `"input_ids"`, and `"labels". + + Example: + ```python + >>> from transformers import AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + >>> features = { + ... "prompt": "Which number is larger, 9.8 or 9.11?", + ... "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + ... "labels": [True, False], + ... } + >>> PRMTrainer.tokenize_row( + ... features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False + ... ) + {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198], + 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]} + ``` + """ + # Tokenize the prompt and completions + prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + completions_ids = [ + tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"] + ] + if train_on_last_step_only and not is_eval: + labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])] + else: + labels = [int(label) for label in features["labels"]] + + # Get the ID of the separator token and add it to the completions + separator_ids = tokenizer.encode(step_separator, add_special_tokens=False) + completions_ids = [completion + separator_ids for completion in completions_ids] + + # Create the label + labels = [ + [-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels, strict=True) + ] + + # Join the completions and labels steps + completion_ids = list(chain(*completions_ids)) + labels = list(chain(*labels)) + + if tokenizer.bos_token_id is not None: + prompt_ids = [tokenizer.bos_token_id] + prompt_ids + + # Truncate completion sequences + if max_completion_length is not None: + completion_ids = completion_ids[:max_completion_length] + labels = labels[:max_completion_length] + + input_ids = prompt_ids + completion_ids + labels = [-100] * len(prompt_ids) + labels + + if max_length is not None: + input_ids = input_ids[:max_length] + labels = labels[:max_length] + + return {"input_ids": input_ids, "labels": labels} + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/ICL/RL/trl_source/trl/extras/__pycache__/profiling.cpython-313.pyc b/ICL/RL/trl_source/trl/extras/__pycache__/profiling.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66a321d65cec4cc7c531e15bc3451d76dc89b92 Binary files /dev/null and b/ICL/RL/trl_source/trl/extras/__pycache__/profiling.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/generation/__pycache__/__init__.cpython-313.pyc b/ICL/RL/trl_source/trl/generation/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884334d8e86996fda9787b48d94c46a74ae5f8bb Binary files /dev/null and b/ICL/RL/trl_source/trl/generation/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/generation/__pycache__/vllm_client.cpython-313.pyc b/ICL/RL/trl_source/trl/generation/__pycache__/vllm_client.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a275493d34a4e2bc8adc51228c2aad9bbe37f7d4 Binary files /dev/null and b/ICL/RL/trl_source/trl/generation/__pycache__/vllm_client.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/generation/__pycache__/vllm_generation.cpython-313.pyc b/ICL/RL/trl_source/trl/generation/__pycache__/vllm_generation.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b21c8a887ee3a466e8633e704b666ab2891a54a Binary files /dev/null and b/ICL/RL/trl_source/trl/generation/__pycache__/vllm_generation.cpython-313.pyc differ diff --git a/ICL/RL/trl_source/trl/models/__pycache__/utils.cpython-313.pyc b/ICL/RL/trl_source/trl/models/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..112e2c054d3547d6c386793858f46a989d3bca5e Binary files /dev/null and b/ICL/RL/trl_source/trl/models/__pycache__/utils.cpython-313.pyc differ diff --git a/ICL/RL_DAPO/config/__init__.py b/ICL/RL_DAPO/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/ICL/RL_DAPO/config/__init__.py @@ -0,0 +1 @@ + diff --git a/ICL/RL_DAPO/config/dapo_trainer.yaml b/ICL/RL_DAPO/config/dapo_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9cd8bfa90c52bb28711ddc7f08ba199dd35e2a4 --- /dev/null +++ b/ICL/RL_DAPO/config/dapo_trainer.yaml @@ -0,0 +1,122 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +# === Data === +data: + train_files: /workspace/xiaobin/RL_DAPO/data/train.parquet + val_files: /workspace/xiaobin/RL_DAPO/data/val.parquet + prompt_key: prompt + reward_fn_key: reward_fn_key + truncation: left + max_prompt_length: 1024 + max_response_length: 2048 + train_batch_size: 64 + gen_batch_size: ${data.train_batch_size} + trust_remote_code: true + +# === Reward Model === +reward_model: + enable: false + # Use custom reward function instead of model-based RM + reward_manager: dapo + # Register our custom compute_score + custom_cls: xiaobin.RL_DAPO.compute_score.compute_score + overlong_buffer: + enable: true + len: 512 + penalty_factor: 1.0 + log: false + +# === Algorithm (DAPO-specific) === +algorithm: + adv_estimator: grpo + # DAPO: No KL penalty + use_kl_in_reward: false + kl_ctrl: + kl_coef: 0.0 + # DAPO: Dynamic sampling with group filtering + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: true + metric: seq_final_reward + max_num_gen_batches: 5 + +# === Actor / Rollout / Ref === +actor_rollout_ref: + model: + path: /workspace/xiaobin/SFT_model/hf_qwen3vl_siglip_vqa_iter_0000881 + enable_gradient_checkpointing: true + use_remove_padding: true + actor: + strategy: fsdp + # DAPO: Decoupled clip ratios (Clip-Higher) + clip_ratio_low: 0.2 + clip_ratio_high: 0.28 + clip_ratio_c: 10.0 + # DAPO: No KL loss + use_kl_loss: false + kl_loss_coef: 0.0 + # DAPO: Token-level loss aggregation + loss_agg_mode: token-mean + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 6144 + ppo_mini_batch_size: 8 + entropy_coeff: 0 + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + fsdp_config: + param_offload: true + optimizer_offload: true + fsdp_size: -1 + optim: + lr: 1e-6 + lr_warmup_steps: 10 + weight_decay: 0.1 + rollout: + name: vllm + n: 8 + temperature: 1.0 + top_p: 1.0 + top_k: -1 + gpu_memory_utilization: 0.80 + tensor_model_parallel_size: 1 + enable_chunked_prefill: true + max_num_batched_tokens: 3072 + use_dynamic_bsz: true + log_prob_use_dynamic_bsz: true + log_prob_max_token_len_per_gpu: 6144 + val_kwargs: + temperature: 1.0 + top_p: 0.7 + top_k: -1 + do_sample: true + n: 1 + ref: + log_prob_use_dynamic_bsz: true + log_prob_max_token_len_per_gpu: 6144 + fsdp_config: + param_offload: true + ulysses_sequence_parallel_size: 1 + +# === Trainer === +trainer: + project_name: RL-DAPO-VQA + experiment_name: dapo-qwen3vl-retrieval-vqa + logger: '["console","wandb"]' + n_gpus_per_node: 8 + nnodes: 1 + val_before_train: true + test_freq: 10 + save_freq: 10 + total_epochs: 3 + total_training_steps: 200 + default_local_dir: /workspace/xiaobin/RL_DAPO/checkpoints + resume_mode: auto + balance_batch: true + critic_warmup: 0 + log_val_generations: 5