Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- ICL/DAPO/verl-recipe/char_count/README.md +59 -0
- ICL/DAPO/verl-recipe/char_count/create_dataset.py +198 -0
- ICL/DAPO/verl-recipe/char_count/reward_function.py +34 -0
- ICL/DAPO/verl-recipe/char_count/train_grpo.sh +45 -0
- ICL/DAPO/verl-recipe/char_count/train_sft.sh +97 -0
- ICL/DAPO/verl-recipe/collabllm/README.md +74 -0
- ICL/DAPO/verl-recipe/collabllm/utils.py +280 -0
- ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh +138 -0
- ICL/DAPO/verl-recipe/deepeyes/deepeyes.py +408 -0
- ICL/DAPO/verl-recipe/fault_recover/async_llm.py +84 -0
- ICL/DAPO/verl-recipe/flash_rl_ascend/README.md +121 -0
- ICL/DAPO/verl-recipe/flowrl/README.md +182 -0
- ICL/DAPO/verl-recipe/flowrl/__init__.py +17 -0
- ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py +495 -0
- ICL/DAPO/verl-recipe/flowrl/main_flowrl.py +185 -0
- ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh +134 -0
- ICL/DAPO/verl-recipe/infigui-g1/README.md +56 -0
- ICL/DAPO/verl-recipe/langgraph_agent/__init__.py +13 -0
- ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py +393 -0
- ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py +188 -0
- ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py +202 -0
- ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py +571 -0
- ICL/DAPO/verl-recipe/prime/__init__.py +13 -0
- ICL/DAPO/verl-recipe/prime/prime_core_algos.py +147 -0
- ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh +61 -0
- ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh +33 -0
- ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3 +82 -0
- ICL/DAPO/verl-recipe/r1_ascend/README.md +119 -0
- ICL/DAPO/verl-recipe/r1_ascend/README_zh.md +119 -0
- ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh +82 -0
- ICL/DAPO/verl-recipe/r1_ascend/vllm_rollout_spmd.py +347 -0
- ICL/DAPO/verl-recipe/rep_exp/README.md +71 -0
- ICL/DAPO/verl-recipe/rep_exp/eval.sh +83 -0
- ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py +483 -0
- ICL/DAPO/verl-recipe/rep_exp/metric_utils.py +382 -0
- ICL/DAPO/verl-recipe/rep_exp/model_merge.sh +6 -0
- ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py +241 -0
- ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py +739 -0
- ICL/DAPO/verl-recipe/spin/core_algos.py +206 -0
- ICL/DAPO/verl-recipe/spin/main_spin.py +168 -0
- ICL/DAPO/verl-recipe/spin/spin_trainer.py +1312 -0
- ICL/LV/code/README.md +66 -0
- ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc +0 -0
- ICL/LV/code/SFT/build_icl_eval_sharegpt.py +437 -0
- ICL/LV/code/SFT/check_kshot_ret_ans.py +319 -0
- ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb +0 -0
- ICL/LV/code/SFT/prepare_dataset.py +56 -0
- ICL/LV/code/adapters/gemma3_adapter.py +27 -0
- ICL/LV/code/adapters/qwen3vl_adapter.py +27 -0
- ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc +0 -0
ICL/DAPO/verl-recipe/char_count/README.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Char Count
|
| 2 |
+
## Introduction
|
| 3 |
+
Char count is a simple NLP task. We create it for beginners to grasp the idea of RLVR. The task can be trained using a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB.
|
| 4 |
+
|
| 5 |
+
## Problem formulation
|
| 6 |
+
The prompt is: "How many {char} are there in {word}?". In order for LLM to better answer this question, we create SFT dataset with intermediate steps. For example,
|
| 7 |
+
|
| 8 |
+
```text
|
| 9 |
+
Question: How many n are there in n-i-n-e?
|
| 10 |
+
Answer:
|
| 11 |
+
n = n
|
| 12 |
+
i != n
|
| 13 |
+
n = n
|
| 14 |
+
e != n
|
| 15 |
+
\boxed{2}
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Note that
|
| 19 |
+
- We add a dash between each individual char to make the task easier because each individual char will be tokenized to the same token by most tokenizer.
|
| 20 |
+
- In the SFT dataset, we create a CoT by listing all the individual chars and whether it equals to the target. In the end, it outputs the final answer inside the box.
|
| 21 |
+
- The task can be verified.
|
| 22 |
+
- The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range.
|
| 23 |
+
|
| 24 |
+
## Scripts
|
| 25 |
+
Installation
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
pip install verl==0.6.1
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
To create the dataset, run
|
| 33 |
+
```bash
|
| 34 |
+
python3 create_dataset.py
|
| 35 |
+
```
|
| 36 |
+
We create a train set and a val set. Both of them are used of SFT and RL. You can specify the total number of data, min/max length and data path.
|
| 37 |
+
|
| 38 |
+
To run the SFT
|
| 39 |
+
```bash
|
| 40 |
+
BACKEND=fsdp bash train_sft.sh # use fsdp
|
| 41 |
+
BACKEND=megatron bash train_sft.sh # use megatron
|
| 42 |
+
```
|
| 43 |
+
We train SFT for 1 epoch. After 1 epoch, the validation score is around 0.435.
|
| 44 |
+
|
| 45 |
+
Merge checkpoint trained from SFT
|
| 46 |
+
```bash
|
| 47 |
+
# sft
|
| 48 |
+
export CKPT_PATH=$HOME/experiments/char_count/models/sft/fsdp/global_step_140
|
| 49 |
+
python3 -m verl.model_merger merge --backend fsdp --local_dir $CKPT_PATH --target_dir $CKPT_PATH/huggingface/
|
| 50 |
+
# megatron
|
| 51 |
+
export CKPT_PATH=$HOME/experiments/char_count/models/sft/megatron/global_step_140
|
| 52 |
+
python3 -m verl.model_merger merge --backend megatron --local_dir $CKPT_PATH --target_dir $CKPT_PATH/huggingface/
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
To run GRPO
|
| 56 |
+
```bash
|
| 57 |
+
bash train_grpo.sh
|
| 58 |
+
```
|
| 59 |
+
We train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.6.
|
ICL/DAPO/verl-recipe/char_count/create_dataset.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Task description:
|
| 17 |
+
Given a random word and a random char, count the number of occurrence of char in the word.
|
| 18 |
+
|
| 19 |
+
Create CoT dataset that split the word into separate char. Then list the char and count the occurrence.
|
| 20 |
+
|
| 21 |
+
The word set comes from shakespeare
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os.path
|
| 25 |
+
import random
|
| 26 |
+
|
| 27 |
+
prompt_template = "How many {} are there in word {}?"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def generate_random_char():
|
| 31 |
+
return chr(97 + random.randint(0, 25))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def create_prompt_response(min_length=3, max_length=5):
|
| 35 |
+
# randomly generate a length
|
| 36 |
+
word_length = random.randint(min_length, max_length)
|
| 37 |
+
# randomly generate a target count number. This makes the target number
|
| 38 |
+
target_count_number = random.randint(1, word_length)
|
| 39 |
+
|
| 40 |
+
char_lst = []
|
| 41 |
+
# generate the word
|
| 42 |
+
# step 1: generate the target word
|
| 43 |
+
target_char = generate_random_char()
|
| 44 |
+
|
| 45 |
+
for _ in range(target_count_number):
|
| 46 |
+
char_lst.append(target_char)
|
| 47 |
+
|
| 48 |
+
# step 2: generate other words
|
| 49 |
+
for _ in range(word_length - target_count_number):
|
| 50 |
+
while True:
|
| 51 |
+
char = generate_random_char()
|
| 52 |
+
if char != target_char:
|
| 53 |
+
char_lst.append(char)
|
| 54 |
+
break
|
| 55 |
+
|
| 56 |
+
# step 3: random permute char_lst
|
| 57 |
+
random.shuffle(char_lst)
|
| 58 |
+
|
| 59 |
+
word = "-".join(char_lst)
|
| 60 |
+
|
| 61 |
+
prompt = prompt_template.format(target_char, word)
|
| 62 |
+
final_answer = []
|
| 63 |
+
|
| 64 |
+
# cot
|
| 65 |
+
number = 0
|
| 66 |
+
for i, char in enumerate(char_lst):
|
| 67 |
+
cot = f"{char}"
|
| 68 |
+
if char != target_char:
|
| 69 |
+
cot += " != "
|
| 70 |
+
else:
|
| 71 |
+
cot += " = "
|
| 72 |
+
number += 1
|
| 73 |
+
cot += f"{target_char}."
|
| 74 |
+
|
| 75 |
+
final_answer.append(cot)
|
| 76 |
+
|
| 77 |
+
conclusion = f"\\boxed{{{number}}} {target_char} in {word}."
|
| 78 |
+
|
| 79 |
+
final_answer.append(conclusion)
|
| 80 |
+
|
| 81 |
+
final_answer = "\n".join(final_answer)
|
| 82 |
+
|
| 83 |
+
return prompt, final_answer
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
import argparse
|
| 88 |
+
|
| 89 |
+
parser = argparse.ArgumentParser()
|
| 90 |
+
parser.add_argument("--total_number", type=int, default=10000)
|
| 91 |
+
parser.add_argument("--min_length", type=int, default=5)
|
| 92 |
+
parser.add_argument("--max_length", type=int, default=20)
|
| 93 |
+
parser.add_argument("--data_path", type=str, default="~/data/char_count")
|
| 94 |
+
|
| 95 |
+
args = vars(parser.parse_args())
|
| 96 |
+
|
| 97 |
+
total_number = args["total_number"]
|
| 98 |
+
min_length = args["min_length"]
|
| 99 |
+
max_length = args["max_length"]
|
| 100 |
+
data_path = args["data_path"]
|
| 101 |
+
data_path = os.path.expanduser(data_path)
|
| 102 |
+
|
| 103 |
+
full_output = []
|
| 104 |
+
for _ in range(total_number):
|
| 105 |
+
output = create_prompt_response(min_length=min_length, max_length=max_length)
|
| 106 |
+
full_output.append(output)
|
| 107 |
+
|
| 108 |
+
# random reorder
|
| 109 |
+
random.shuffle(full_output)
|
| 110 |
+
|
| 111 |
+
# split for train and test
|
| 112 |
+
train_split_len = int(0.9 * len(full_output))
|
| 113 |
+
train_outputs = full_output[:train_split_len]
|
| 114 |
+
test_output = full_output[train_split_len:]
|
| 115 |
+
|
| 116 |
+
sft_train_dataset = {"messages": []}
|
| 117 |
+
|
| 118 |
+
for o in train_outputs:
|
| 119 |
+
messages = [
|
| 120 |
+
{"role": "user", "content": o[0]},
|
| 121 |
+
{"role": "assistant", "content": o[1]},
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
sft_train_dataset["messages"].append(messages)
|
| 125 |
+
|
| 126 |
+
sft_test_dataset = {"messages": []}
|
| 127 |
+
|
| 128 |
+
for o in test_output:
|
| 129 |
+
messages = [
|
| 130 |
+
{"role": "user", "content": o[0]},
|
| 131 |
+
{"role": "assistant", "content": o[1]},
|
| 132 |
+
]
|
| 133 |
+
sft_test_dataset["messages"].append(messages)
|
| 134 |
+
|
| 135 |
+
import pandas as pd
|
| 136 |
+
|
| 137 |
+
sft_train_dataset = pd.DataFrame(data=sft_train_dataset)
|
| 138 |
+
sft_test_dataset = pd.DataFrame(data=sft_test_dataset)
|
| 139 |
+
|
| 140 |
+
folder = os.path.join(data_path, "sft")
|
| 141 |
+
|
| 142 |
+
os.makedirs(folder, exist_ok=True)
|
| 143 |
+
|
| 144 |
+
sft_train_dataset.to_parquet(os.path.join(folder, "train.parquet"))
|
| 145 |
+
sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet"))
|
| 146 |
+
|
| 147 |
+
# build RL dataset
|
| 148 |
+
rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []}
|
| 149 |
+
|
| 150 |
+
rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []}
|
| 151 |
+
|
| 152 |
+
from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed
|
| 153 |
+
|
| 154 |
+
for o in train_outputs:
|
| 155 |
+
prompt = o[0]
|
| 156 |
+
response = o[1]
|
| 157 |
+
prompt_with_template = [
|
| 158 |
+
{
|
| 159 |
+
"role": "user",
|
| 160 |
+
"content": prompt,
|
| 161 |
+
}
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
rl_train_dataset["prompt"].append(prompt_with_template)
|
| 165 |
+
rl_train_dataset["data_source"].append("char_count")
|
| 166 |
+
rl_train_dataset["ability"].append("other")
|
| 167 |
+
rl_train_dataset["reward_model"].append(
|
| 168 |
+
{"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))}
|
| 169 |
+
)
|
| 170 |
+
rl_train_dataset["extra_info"].append({"response": response})
|
| 171 |
+
|
| 172 |
+
for o in test_output:
|
| 173 |
+
prompt = o[0]
|
| 174 |
+
response = o[1]
|
| 175 |
+
prompt_with_template = [
|
| 176 |
+
{
|
| 177 |
+
"role": "user",
|
| 178 |
+
"content": prompt,
|
| 179 |
+
}
|
| 180 |
+
]
|
| 181 |
+
|
| 182 |
+
rl_test_dataset["prompt"].append(prompt_with_template)
|
| 183 |
+
rl_test_dataset["data_source"].append("char_count")
|
| 184 |
+
rl_test_dataset["ability"].append("other")
|
| 185 |
+
rl_test_dataset["reward_model"].append(
|
| 186 |
+
{"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))}
|
| 187 |
+
)
|
| 188 |
+
rl_test_dataset["extra_info"].append({"response": response})
|
| 189 |
+
|
| 190 |
+
rl_train_dataset = pd.DataFrame(data=rl_train_dataset)
|
| 191 |
+
rl_test_dataset = pd.DataFrame(data=rl_test_dataset)
|
| 192 |
+
|
| 193 |
+
folder = os.path.join(data_path, "rl")
|
| 194 |
+
|
| 195 |
+
os.makedirs(folder, exist_ok=True)
|
| 196 |
+
|
| 197 |
+
rl_train_dataset.to_parquet(os.path.join(folder, "train.parquet"))
|
| 198 |
+
rl_test_dataset.to_parquet(os.path.join(folder, "test.parquet"))
|
ICL/DAPO/verl-recipe/char_count/reward_function.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Reward function
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from verl.utils.reward_score import math_reward
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None):
|
| 23 |
+
try:
|
| 24 |
+
last_boxed_string = math_reward.last_boxed_only_string(solution_str)
|
| 25 |
+
if last_boxed_string is None:
|
| 26 |
+
return 0
|
| 27 |
+
solution = math_reward.remove_boxed(last_boxed_string)
|
| 28 |
+
if solution == ground_truth:
|
| 29 |
+
return 1
|
| 30 |
+
else:
|
| 31 |
+
return 0
|
| 32 |
+
except Exception:
|
| 33 |
+
print(ground_truth, solution_str)
|
| 34 |
+
return 0
|
ICL/DAPO/verl-recipe/char_count/train_grpo.sh
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
set -x
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
python3 -m verl.trainer.main_ppo \
|
| 5 |
+
algorithm.adv_estimator=grpo \
|
| 6 |
+
data.train_files=$HOME/data/char_count/rl/train.parquet \
|
| 7 |
+
data.val_files=$HOME/data/char_count/rl/test.parquet \
|
| 8 |
+
data.train_batch_size=128 \
|
| 9 |
+
data.max_prompt_length=128 \
|
| 10 |
+
data.max_response_length=128 \
|
| 11 |
+
data.filter_overlong_prompts=False \
|
| 12 |
+
data.truncation='error' \
|
| 13 |
+
actor_rollout_ref.model.path=$HOME/experiments/char_count/models/sft/megatron/global_step_140/huggingface \
|
| 14 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 15 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 16 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
|
| 17 |
+
actor_rollout_ref.actor.use_dynamic_bsz=True \
|
| 18 |
+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \
|
| 19 |
+
actor_rollout_ref.actor.use_kl_loss=False \
|
| 20 |
+
actor_rollout_ref.actor.kl_loss_coef=0.0 \
|
| 21 |
+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
| 22 |
+
actor_rollout_ref.actor.entropy_coeff=0 \
|
| 23 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 24 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=True \
|
| 25 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
|
| 26 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 27 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 28 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
|
| 29 |
+
actor_rollout_ref.rollout.n=8 \
|
| 30 |
+
actor_rollout_ref.rollout.enforce_eager=True \
|
| 31 |
+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
| 32 |
+
algorithm.use_kl_in_reward=False \
|
| 33 |
+
trainer.critic_warmup=0 \
|
| 34 |
+
trainer.logger='["console","tensorboard"]' \
|
| 35 |
+
trainer.project_name='verl_example' \
|
| 36 |
+
trainer.experiment_name='smol135m_grpo-1128a1' \
|
| 37 |
+
trainer.val_before_train=True \
|
| 38 |
+
trainer.n_gpus_per_node=1 \
|
| 39 |
+
trainer.nnodes=1 \
|
| 40 |
+
trainer.save_freq=-1 \
|
| 41 |
+
trainer.test_freq=5 \
|
| 42 |
+
trainer.total_epochs=5 \
|
| 43 |
+
trainer.use_legacy_worker_impl=disable \
|
| 44 |
+
custom_reward_function.path=./reward_function.py \
|
| 45 |
+
custom_reward_function.name=char_count_reward_function
|
ICL/DAPO/verl-recipe/char_count/train_sft.sh
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
|
| 5 |
+
|
| 6 |
+
TRAIN_FILES=${TRAIN_FILES:-$HOME/data/char_count/sft/train.parquet}
|
| 7 |
+
TEST_FILES=${TEST_FILES:-$HOME/data/char_count/sft/test.parquet}
|
| 8 |
+
|
| 9 |
+
backend=${BACKEND:-fsdp}
|
| 10 |
+
|
| 11 |
+
project_name=char_count-sft
|
| 12 |
+
|
| 13 |
+
RESUME_MODE=auto
|
| 14 |
+
MODEL_ID=${MODEL_ID:-HuggingFaceTB/SmolLM2-135M-Instruct}
|
| 15 |
+
|
| 16 |
+
SP_SIZE=${SP_SIZE:-1}
|
| 17 |
+
FSDP_SIZE=${FSDP_SIZE:-1}
|
| 18 |
+
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"}
|
| 19 |
+
|
| 20 |
+
TP_SIZE=${TP_SIZE:-1}
|
| 21 |
+
PP_SIZE=${PP_SIZE:-1}
|
| 22 |
+
VPP_SIZE=${VPP_SIZE:-null}
|
| 23 |
+
CP_SIZE=${CP_SIZE:-1}
|
| 24 |
+
|
| 25 |
+
PAD_MODE=${PAD_MODE:-no_padding}
|
| 26 |
+
|
| 27 |
+
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
|
| 28 |
+
|
| 29 |
+
FSDP_ENGINE_CONFIG="\
|
| 30 |
+
engine=${backend} \
|
| 31 |
+
optim=${backend} \
|
| 32 |
+
optim.lr=2e-5 \
|
| 33 |
+
optim.lr_warmup_steps_ratio=0.01 \
|
| 34 |
+
optim.weight_decay=0.1 \
|
| 35 |
+
optim.betas="[0.9,0.95]" \
|
| 36 |
+
optim.clip_grad=1.0 \
|
| 37 |
+
optim.min_lr_ratio=0.1 \
|
| 38 |
+
optim.warmup_style=cosine \
|
| 39 |
+
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
|
| 40 |
+
engine.strategy=${FSDP_STRATEGY} \
|
| 41 |
+
engine.fsdp_size=${FSDP_SIZE}"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
MEGATRON_ENGINE_CONFIG="\
|
| 45 |
+
engine=${backend} \
|
| 46 |
+
optim=${backend} \
|
| 47 |
+
optim.lr=2e-5 \
|
| 48 |
+
optim.lr_warmup_steps_ratio=0.01 \
|
| 49 |
+
optim.weight_decay=0.1 \
|
| 50 |
+
optim.betas="[0.9,0.95]" \
|
| 51 |
+
optim.clip_grad=1.0 \
|
| 52 |
+
optim.lr_warmup_init=0 \
|
| 53 |
+
optim.lr_decay_style=cosine \
|
| 54 |
+
optim.min_lr=2e-6 \
|
| 55 |
+
engine.tensor_model_parallel_size=${TP_SIZE} \
|
| 56 |
+
engine.pipeline_model_parallel_size=${PP_SIZE} \
|
| 57 |
+
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
|
| 58 |
+
engine.context_parallel_size=${CP_SIZE} \
|
| 59 |
+
engine.use_mbridge=False"
|
| 60 |
+
|
| 61 |
+
if [ "$backend" = "fsdp" ]; then
|
| 62 |
+
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
|
| 63 |
+
echo "Using fsdp engine"
|
| 64 |
+
exp_name=char_count-sft-SmolLM2-135M-Instruct-fsdp
|
| 65 |
+
else
|
| 66 |
+
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
|
| 67 |
+
echo "Using megatron engine"
|
| 68 |
+
exp_name=char_count-sft-SmolLM2-135M-Instruct-megatron
|
| 69 |
+
fi
|
| 70 |
+
|
| 71 |
+
CKPT_HOME=${CKPT_HOME:-$HOME/experiments/char_count/models/sft/$backend}
|
| 72 |
+
mkdir -p "${CKPT_HOME}"
|
| 73 |
+
|
| 74 |
+
torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-1} \
|
| 75 |
+
${ENTRYPOINT} \
|
| 76 |
+
data.train_files="${TRAIN_FILES}" \
|
| 77 |
+
data.train_batch_size=64 \
|
| 78 |
+
data.val_files="${TEST_FILES}" \
|
| 79 |
+
data.max_length=256 \
|
| 80 |
+
data.pad_mode=${PAD_MODE} \
|
| 81 |
+
data.truncation=error \
|
| 82 |
+
data.use_dynamic_bsz=True \
|
| 83 |
+
data.max_token_len_per_gpu=1792 \
|
| 84 |
+
data.messages_key=messages \
|
| 85 |
+
model.path=$MODEL_ID \
|
| 86 |
+
model.use_remove_padding=${USE_REMOVE_PADDING} \
|
| 87 |
+
${ENGINE_CONFIG} \
|
| 88 |
+
trainer.test_freq=-1 \
|
| 89 |
+
trainer.save_freq=70 \
|
| 90 |
+
trainer.logger=['console'] \
|
| 91 |
+
trainer.project_name="${project_name}" \
|
| 92 |
+
trainer.experiment_name="${exp_name}" \
|
| 93 |
+
trainer.total_epochs=1 \
|
| 94 |
+
trainer.default_local_dir="${CKPT_HOME}" \
|
| 95 |
+
trainer.resume_mode=${RESUME_MODE} \
|
| 96 |
+
trainer.max_ckpt_to_keep=5 \
|
| 97 |
+
checkpoint.save_contents=[model,optimizer,extra]
|
ICL/DAPO/verl-recipe/collabllm/README.md
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CollabLLM
|
| 2 |
+
|
| 3 |
+
This repository implements [CollabLLM](https://arxiv.org/pdf/2502.00640) (ICML 2025) using the verl framework. For the original implementation, see the [CollabLLM repository](https://github.com/Wuyxin/collabllm).
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
CollabLLM is a method for training language models to collaborate effectively in multi-turn conversations. This implementation adapts the original imlpementation to work with the Verl training framework.
|
| 7 |
+
|
| 8 |
+
## Quick start
|
| 9 |
+
|
| 10 |
+
### 0. Environment
|
| 11 |
+
Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM Judges (see the Configuration section below).
|
| 12 |
+
|
| 13 |
+
### 1. Prepare Your Dataset
|
| 14 |
+
|
| 15 |
+
First, process your dataset using the provided script:
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
python process_dataset.py --dataset <> ... --dataset_type <sft or rl>
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
**Requirements:**
|
| 23 |
+
- Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` in one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (*-large are the datasets used in the CollabLLM paper)
|
| 24 |
+
- Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)
|
| 25 |
+
- To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository
|
| 26 |
+
|
| 27 |
+
*Note: Check `process_dataset.py` for example commands and usage.*
|
| 28 |
+
|
| 29 |
+
### 2. Train Your Model
|
| 30 |
+
|
| 31 |
+
**(Optional) For Supervised Fine-Tuning (SFT):**
|
| 32 |
+
```bash
|
| 33 |
+
bash train_sft_collabllm.sh
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
**For Reinforcement Learning (RL):**
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
bash train_rl_collabllm.sh
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
The RL script shows an example to train CollabLLM on `math-hard-large`.
|
| 43 |
+
|
| 44 |
+
- The config to sample future conversations are in `recipe/collabllm/config/collabllm_interaction_config.yaml`.
|
| 45 |
+
- The Multiturn-aware Reward is aggregated from these three conversational-level rewards:
|
| 46 |
+
|
| 47 |
+
```
|
| 48 |
+
+reward_model.reward_kwargs.metric_weights.accuracy=1 \
|
| 49 |
+
+reward_model.reward_kwargs.metric_weights.interactivity=1 \
|
| 50 |
+
+reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
You can remove, add, or modify the weights depending on your task. A list of implemented metrics you can already add are under `recipe/collabllm/metrics`. For example, on `medium-large`, you can replace `accuracy` with `bleu_score` via
|
| 54 |
+
```
|
| 55 |
+
+reward_model.reward_kwargs.metric_weights.bleu_score=1
|
| 56 |
+
```
|
| 57 |
+
which will instead apply bleu score on the sampled future conversations.
|
| 58 |
+
|
| 59 |
+
## Configuration
|
| 60 |
+
Read [doc](https://verl.readthedocs.io/en/latest/) for detailed configurations.
|
| 61 |
+
|
| 62 |
+
## Citation
|
| 63 |
+
If you find CollabLLM useful in your research, please cite the following:
|
| 64 |
+
|
| 65 |
+
```bibtex
|
| 66 |
+
@inproceedings{collabllm2025,
|
| 67 |
+
title={CollabLLM: From Passive Responders to Active Collaborators},
|
| 68 |
+
author={Shirley Wu and Michel Galley and Baolin Peng and Hao Cheng and
|
| 69 |
+
Gavin Li and Yao Dou and Weixin Cai and James Zou and
|
| 70 |
+
Jure Leskovec and Jianfeng Gao},
|
| 71 |
+
booktitle={International Conference on Machine Learning (ICML)},
|
| 72 |
+
year={2025}
|
| 73 |
+
}
|
| 74 |
+
```
|
ICL/DAPO/verl-recipe/collabllm/utils.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 CollabLLM team and/or its affiliates
|
| 2 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__file__)
|
| 20 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def parse_messages(messages, strip_sys_prompt=True):
|
| 24 |
+
"""
|
| 25 |
+
Args:
|
| 26 |
+
messages: List[dict]
|
| 27 |
+
List of dictionaries with keys 'role' and 'content'
|
| 28 |
+
Example: messages = [{'role': 'user', 'content': 'Hello!'},
|
| 29 |
+
{'role': 'assistant', 'content': 'Hi!'}, ...]
|
| 30 |
+
"""
|
| 31 |
+
if messages is None:
|
| 32 |
+
return ""
|
| 33 |
+
|
| 34 |
+
if strip_sys_prompt:
|
| 35 |
+
messages = strip_system_prompt(messages)
|
| 36 |
+
|
| 37 |
+
chat = "\n".join(f"**{m.role.capitalize()}**: {m.content}" for m in messages)
|
| 38 |
+
|
| 39 |
+
return chat
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def strip_system_prompt(messages):
|
| 43 |
+
"""
|
| 44 |
+
Args:
|
| 45 |
+
messages: List[dict]
|
| 46 |
+
List of dictionaries with keys 'role' and 'content'
|
| 47 |
+
Example: messages = [{'role': 'user', 'content': 'Hello!'},
|
| 48 |
+
{'role': 'assistant', 'content': 'Hi!'}, ...]
|
| 49 |
+
"""
|
| 50 |
+
return [msg for msg in messages if msg.role != "system"]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def extract_json(s):
|
| 54 |
+
def convert_value(value):
|
| 55 |
+
true_values = {"true": True, "false": False, "null": None}
|
| 56 |
+
value_lower = value.lower()
|
| 57 |
+
if value_lower in true_values:
|
| 58 |
+
return true_values[value_lower]
|
| 59 |
+
try:
|
| 60 |
+
if "." in value or "e" in value.lower():
|
| 61 |
+
return float(value)
|
| 62 |
+
else:
|
| 63 |
+
return int(value)
|
| 64 |
+
except ValueError:
|
| 65 |
+
return value # Return as string if not a number
|
| 66 |
+
|
| 67 |
+
def parse_number(s, pos):
|
| 68 |
+
start = pos
|
| 69 |
+
while pos < len(s) and s[pos] in "-+0123456789.eE":
|
| 70 |
+
pos += 1
|
| 71 |
+
num_str = s[start:pos]
|
| 72 |
+
try:
|
| 73 |
+
if "." in num_str or "e" in num_str.lower():
|
| 74 |
+
return float(num_str), pos
|
| 75 |
+
else:
|
| 76 |
+
return int(num_str), pos
|
| 77 |
+
except ValueError:
|
| 78 |
+
logger.error(f"Invalid number at position {start}: {num_str}")
|
| 79 |
+
raise
|
| 80 |
+
|
| 81 |
+
def skip_whitespace(s, pos):
|
| 82 |
+
while pos < len(s) and s[pos] in " \t\n\r":
|
| 83 |
+
pos += 1
|
| 84 |
+
return pos
|
| 85 |
+
|
| 86 |
+
def parse_string(s, pos):
|
| 87 |
+
quote_char = s[pos]
|
| 88 |
+
assert quote_char in ('"', "'")
|
| 89 |
+
pos += 1
|
| 90 |
+
result = ""
|
| 91 |
+
while pos < len(s):
|
| 92 |
+
c = s[pos]
|
| 93 |
+
if c == "\\":
|
| 94 |
+
pos += 1
|
| 95 |
+
if pos >= len(s):
|
| 96 |
+
raise ValueError("Invalid escape sequence")
|
| 97 |
+
c = s[pos]
|
| 98 |
+
escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char}
|
| 99 |
+
result += escape_sequences.get(c, c)
|
| 100 |
+
elif c == quote_char:
|
| 101 |
+
pos += 1
|
| 102 |
+
# Attempt to convert to a number if possible
|
| 103 |
+
converted_value = convert_value(result)
|
| 104 |
+
return converted_value, pos
|
| 105 |
+
else:
|
| 106 |
+
result += c
|
| 107 |
+
pos += 1
|
| 108 |
+
raise ValueError("Unterminated string")
|
| 109 |
+
|
| 110 |
+
def parse_key(s, pos):
|
| 111 |
+
pos = skip_whitespace(s, pos)
|
| 112 |
+
if s[pos] in ('"', "'"):
|
| 113 |
+
key, pos = parse_string(s, pos)
|
| 114 |
+
return key, pos
|
| 115 |
+
else:
|
| 116 |
+
raise ValueError(f"Expected string for key at position {pos}")
|
| 117 |
+
|
| 118 |
+
def parse_object(s, pos):
|
| 119 |
+
obj = {}
|
| 120 |
+
assert s[pos] == "{"
|
| 121 |
+
pos += 1
|
| 122 |
+
pos = skip_whitespace(s, pos)
|
| 123 |
+
while pos < len(s) and s[pos] != "}":
|
| 124 |
+
pos = skip_whitespace(s, pos)
|
| 125 |
+
key, pos = parse_key(s, pos)
|
| 126 |
+
pos = skip_whitespace(s, pos)
|
| 127 |
+
if pos >= len(s) or s[pos] != ":":
|
| 128 |
+
raise ValueError(f'Expected ":" at position {pos}')
|
| 129 |
+
pos += 1
|
| 130 |
+
pos = skip_whitespace(s, pos)
|
| 131 |
+
value, pos = parse_value(s, pos)
|
| 132 |
+
obj[key] = value
|
| 133 |
+
pos = skip_whitespace(s, pos)
|
| 134 |
+
if pos < len(s) and s[pos] == ",":
|
| 135 |
+
pos += 1
|
| 136 |
+
pos = skip_whitespace(s, pos)
|
| 137 |
+
elif pos < len(s) and s[pos] == "}":
|
| 138 |
+
break
|
| 139 |
+
elif pos < len(s) and s[pos] != "}":
|
| 140 |
+
raise ValueError(f'Expected "," or "}}" at position {pos}')
|
| 141 |
+
if pos >= len(s) or s[pos] != "}":
|
| 142 |
+
raise ValueError(f'Expected "}}" at position {pos}')
|
| 143 |
+
pos += 1
|
| 144 |
+
return obj, pos
|
| 145 |
+
|
| 146 |
+
def parse_array(s, pos):
|
| 147 |
+
lst = []
|
| 148 |
+
assert s[pos] == "["
|
| 149 |
+
pos += 1
|
| 150 |
+
pos = skip_whitespace(s, pos)
|
| 151 |
+
while pos < len(s) and s[pos] != "]":
|
| 152 |
+
value, pos = parse_value(s, pos)
|
| 153 |
+
lst.append(value)
|
| 154 |
+
pos = skip_whitespace(s, pos)
|
| 155 |
+
if pos < len(s) and s[pos] == ",":
|
| 156 |
+
pos += 1
|
| 157 |
+
pos = skip_whitespace(s, pos)
|
| 158 |
+
elif pos < len(s) and s[pos] == "]":
|
| 159 |
+
break
|
| 160 |
+
elif pos < len(s) and s[pos] != "]":
|
| 161 |
+
raise ValueError(f'Expected "," or "]" at position {pos}')
|
| 162 |
+
if pos >= len(s) or s[pos] != "]":
|
| 163 |
+
raise ValueError(f'Expected "]" at position {pos}')
|
| 164 |
+
pos += 1
|
| 165 |
+
return lst, pos
|
| 166 |
+
|
| 167 |
+
def parse_triple_quoted_string(s, pos):
|
| 168 |
+
if s[pos : pos + 3] == "'''":
|
| 169 |
+
quote_str = "'''"
|
| 170 |
+
elif s[pos : pos + 3] == '"""':
|
| 171 |
+
quote_str = '"""'
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError(f"Expected triple quotes at position {pos}")
|
| 174 |
+
pos += 3
|
| 175 |
+
result = ""
|
| 176 |
+
while pos < len(s):
|
| 177 |
+
if s[pos : pos + 3] == quote_str:
|
| 178 |
+
pos += 3
|
| 179 |
+
# Attempt to convert to a number if possible
|
| 180 |
+
converted_value = convert_value(result)
|
| 181 |
+
return converted_value, pos
|
| 182 |
+
else:
|
| 183 |
+
result += s[pos]
|
| 184 |
+
pos += 1
|
| 185 |
+
raise ValueError("Unterminated triple-quoted string")
|
| 186 |
+
|
| 187 |
+
def parse_value(s, pos):
|
| 188 |
+
pos = skip_whitespace(s, pos)
|
| 189 |
+
if pos >= len(s):
|
| 190 |
+
raise ValueError("Unexpected end of input")
|
| 191 |
+
if s[pos] == "{":
|
| 192 |
+
return parse_object(s, pos)
|
| 193 |
+
elif s[pos] == "[":
|
| 194 |
+
return parse_array(s, pos)
|
| 195 |
+
elif s[pos : pos + 3] in ("'''", '"""'):
|
| 196 |
+
return parse_triple_quoted_string(s, pos)
|
| 197 |
+
elif s[pos] in ('"', "'"):
|
| 198 |
+
return parse_string(s, pos)
|
| 199 |
+
elif s[pos : pos + 4].lower() == "true":
|
| 200 |
+
return True, pos + 4
|
| 201 |
+
elif s[pos : pos + 5].lower() == "false":
|
| 202 |
+
return False, pos + 5
|
| 203 |
+
elif s[pos : pos + 4].lower() == "null":
|
| 204 |
+
return None, pos + 4
|
| 205 |
+
elif s[pos] in "-+0123456789.":
|
| 206 |
+
return parse_number(s, pos)
|
| 207 |
+
else:
|
| 208 |
+
raise ValueError(f"Unexpected character at position {pos}: {s[pos]}")
|
| 209 |
+
|
| 210 |
+
json_start = s.index("{")
|
| 211 |
+
json_end = s.rfind("}")
|
| 212 |
+
s = s[json_start : json_end + 1]
|
| 213 |
+
|
| 214 |
+
s = s.strip()
|
| 215 |
+
result, pos = parse_value(s, 0)
|
| 216 |
+
pos = skip_whitespace(s, pos)
|
| 217 |
+
if pos != len(s):
|
| 218 |
+
raise ValueError(f"Unexpected content at position {pos}")
|
| 219 |
+
return result
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def remove_think_block(msg: dict):
|
| 223 |
+
"""
|
| 224 |
+
remove <think>.*?</think> from content
|
| 225 |
+
"""
|
| 226 |
+
if "content" in msg and isinstance(msg["content"], str):
|
| 227 |
+
msg["content"] = re.sub(r"<think>.*?</think>", "", msg["content"], flags=re.DOTALL).strip()
|
| 228 |
+
return msg
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def is_valid_messages(msg: dict) -> bool:
|
| 232 |
+
"""
|
| 233 |
+
check if is valid messages, including:
|
| 234 |
+
1. <think> is paried with </think>
|
| 235 |
+
2. is not empty inside and outside <think>
|
| 236 |
+
3. is not nested, and at most one <think> block is allowed.
|
| 237 |
+
4. can not be empty if remove ending "<|im_end|>"
|
| 238 |
+
"""
|
| 239 |
+
content = msg.get("content")
|
| 240 |
+
if not isinstance(content, str):
|
| 241 |
+
return True
|
| 242 |
+
|
| 243 |
+
# Base case: empty or whitespace-only content is invalid.
|
| 244 |
+
if not content.strip():
|
| 245 |
+
return False
|
| 246 |
+
|
| 247 |
+
num_think_open = content.count("<think>")
|
| 248 |
+
num_think_close = content.count("</think>")
|
| 249 |
+
|
| 250 |
+
# Rule 1: Check for paired tags.
|
| 251 |
+
if num_think_open != num_think_close:
|
| 252 |
+
return False
|
| 253 |
+
|
| 254 |
+
# Rule 3: Allow at most one think block.
|
| 255 |
+
if num_think_open > 1:
|
| 256 |
+
return False
|
| 257 |
+
|
| 258 |
+
# Case 1: No <think> blocks.
|
| 259 |
+
if num_think_open == 0:
|
| 260 |
+
visible_content = content
|
| 261 |
+
# Case 2: Exactly one <think> block.
|
| 262 |
+
else:
|
| 263 |
+
# Rule 2: Check for empty content inside the think block.
|
| 264 |
+
match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
| 265 |
+
if not match or not match.group(1).strip():
|
| 266 |
+
return False
|
| 267 |
+
|
| 268 |
+
# The "visible" content is what's outside the think block.
|
| 269 |
+
visible_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
|
| 270 |
+
|
| 271 |
+
visible_content = visible_content.strip()
|
| 272 |
+
|
| 273 |
+
# Rule 4 & 2 (outside): Check if visible content is empty after handling <|im_end|>.
|
| 274 |
+
if visible_content.endswith("<|im_end|>"):
|
| 275 |
+
visible_content = visible_content[: -len("<|im_end|>")]
|
| 276 |
+
|
| 277 |
+
if not visible_content.strip():
|
| 278 |
+
return False
|
| 279 |
+
|
| 280 |
+
return True
|
ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
project_name='DAPO'
|
| 3 |
+
exp_name='DAPO-Qwen3-8B-Base'
|
| 4 |
+
|
| 5 |
+
adv_estimator=grpo
|
| 6 |
+
|
| 7 |
+
use_kl_in_reward=False
|
| 8 |
+
kl_coef=0.0
|
| 9 |
+
use_kl_loss=False
|
| 10 |
+
kl_loss_coef=0.0
|
| 11 |
+
|
| 12 |
+
clip_ratio_low=0.2
|
| 13 |
+
clip_ratio_high=0.28
|
| 14 |
+
|
| 15 |
+
max_prompt_length=$((1024 * 2))
|
| 16 |
+
max_response_length=$((1024 * 20))
|
| 17 |
+
enable_overlong_buffer=True
|
| 18 |
+
overlong_buffer_len=$((1024 * 4))
|
| 19 |
+
overlong_penalty_factor=1.0
|
| 20 |
+
|
| 21 |
+
loss_agg_mode="token-mean"
|
| 22 |
+
|
| 23 |
+
enable_filter_groups=False
|
| 24 |
+
filter_groups_metric=acc
|
| 25 |
+
max_num_gen_batches=10
|
| 26 |
+
train_prompt_bsz=16
|
| 27 |
+
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
| 28 |
+
n_resp_per_prompt=16
|
| 29 |
+
train_prompt_mini_bsz=1
|
| 30 |
+
|
| 31 |
+
# Ray
|
| 32 |
+
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
| 33 |
+
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
| 34 |
+
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
| 35 |
+
NNODES=${NNODES:-1}
|
| 36 |
+
# Paths
|
| 37 |
+
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
| 38 |
+
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B-Base"}
|
| 39 |
+
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
| 40 |
+
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
| 41 |
+
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
| 42 |
+
|
| 43 |
+
# Algorithm
|
| 44 |
+
temperature=1.0
|
| 45 |
+
top_p=1.0
|
| 46 |
+
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
| 47 |
+
|
| 48 |
+
# Performance Related Parameter
|
| 49 |
+
sp_size=2
|
| 50 |
+
use_dynamic_bsz=True
|
| 51 |
+
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
|
| 52 |
+
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
|
| 53 |
+
offload=True
|
| 54 |
+
gen_tp=2
|
| 55 |
+
|
| 56 |
+
ray job submit --runtime-env="${RUNTIME_ENV}" \
|
| 57 |
+
-- python3 -m recipe.dapo.main_dapo \
|
| 58 |
+
data.train_files="${TRAIN_FILE}" \
|
| 59 |
+
data.val_files="${TEST_FILE}" \
|
| 60 |
+
data.prompt_key=prompt \
|
| 61 |
+
data.truncation='left' \
|
| 62 |
+
data.max_prompt_length=${max_prompt_length} \
|
| 63 |
+
data.max_response_length=${max_response_length} \
|
| 64 |
+
data.gen_batch_size=${gen_prompt_bsz} \
|
| 65 |
+
data.train_batch_size=${train_prompt_bsz} \
|
| 66 |
+
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
| 67 |
+
algorithm.adv_estimator=${adv_estimator} \
|
| 68 |
+
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
| 69 |
+
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
| 70 |
+
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
| 71 |
+
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
| 72 |
+
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
| 73 |
+
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
| 74 |
+
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
| 75 |
+
algorithm.filter_groups.enable=${enable_filter_groups} \
|
| 76 |
+
algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
|
| 77 |
+
algorithm.filter_groups.metric=${filter_groups_metric} \
|
| 78 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 79 |
+
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
| 80 |
+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
| 81 |
+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
| 82 |
+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
| 83 |
+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
| 84 |
+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
| 85 |
+
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
| 86 |
+
+actor_rollout_ref.model.override_config.attention_dropout=0. \
|
| 87 |
+
+actor_rollout_ref.model.override_config.embd_pdrop=0. \
|
| 88 |
+
+actor_rollout_ref.model.override_config.resid_pdrop=0. \
|
| 89 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 90 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 91 |
+
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
| 92 |
+
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
| 93 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
| 94 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
| 95 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
| 96 |
+
actor_rollout_ref.actor.entropy_coeff=0 \
|
| 97 |
+
actor_rollout_ref.actor.grad_clip=1.0 \
|
| 98 |
+
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
| 99 |
+
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
| 100 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \
|
| 101 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
| 102 |
+
actor_rollout_ref.rollout.enable_chunked_prefill=False \
|
| 103 |
+
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
| 104 |
+
actor_rollout_ref.rollout.temperature=${temperature} \
|
| 105 |
+
actor_rollout_ref.rollout.top_p=${top_p} \
|
| 106 |
+
actor_rollout_ref.rollout.top_k="${top_k}" \
|
| 107 |
+
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
| 108 |
+
actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
|
| 109 |
+
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
| 110 |
+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
| 111 |
+
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
| 112 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 113 |
+
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
| 114 |
+
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
| 115 |
+
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
| 116 |
+
reward_model.reward_manager=dapo \
|
| 117 |
+
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
| 118 |
+
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
| 119 |
+
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
| 120 |
+
trainer.logger=['console'] \
|
| 121 |
+
trainer.project_name="${project_name}" \
|
| 122 |
+
trainer.experiment_name="${exp_name}" \
|
| 123 |
+
trainer.n_gpus_per_node=8 \
|
| 124 |
+
trainer.nnodes="${NNODES}" \
|
| 125 |
+
trainer.val_before_train=False \
|
| 126 |
+
trainer.test_freq=10 \
|
| 127 |
+
trainer.save_freq=20 \
|
| 128 |
+
trainer.total_epochs=1 \
|
| 129 |
+
trainer.total_training_steps=100 \
|
| 130 |
+
trainer.default_local_dir="${CKPTS_DIR}" \
|
| 131 |
+
trainer.resume_mode=auto \
|
| 132 |
+
data.shuffle=False \
|
| 133 |
+
actor_rollout_ref.actor.use_torch_compile=False \
|
| 134 |
+
actor_rollout_ref.ref.use_torch_compile=False \
|
| 135 |
+
actor_rollout_ref.actor.entropy_checkpointing=True \
|
| 136 |
+
actor_rollout_ref.ref.entropy_checkpointing=True \
|
| 137 |
+
actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
|
| 138 |
+
actor_rollout_ref.ref.fsdp_config.forward_prefetch=True
|
ICL/DAPO/verl-recipe/deepeyes/deepeyes.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import io
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import random
|
| 18 |
+
import re
|
| 19 |
+
|
| 20 |
+
import requests
|
| 21 |
+
from openai import OpenAI
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
import verl.utils.torch_functional as verl_F
|
| 25 |
+
from verl.utils.dataset.rl_dataset import RLHFDataset
|
| 26 |
+
from verl.utils.model import compute_position_id_with_mask
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
openai_api_key = "EMPTY"
|
| 31 |
+
openai_api_base = os.environ.get("LLM_AS_A_JUDGE_BASE", "http://10.1.100.71:18901/v1")
|
| 32 |
+
|
| 33 |
+
client = OpenAI(
|
| 34 |
+
api_key=openai_api_key,
|
| 35 |
+
base_url=openai_api_base,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
model_name = ""
|
| 39 |
+
if openai_api_base:
|
| 40 |
+
try:
|
| 41 |
+
response = requests.get(f"{openai_api_base}/models")
|
| 42 |
+
response.raise_for_status()
|
| 43 |
+
models = response.json()
|
| 44 |
+
if models.get("data"):
|
| 45 |
+
model_name = models["data"][0]["id"]
|
| 46 |
+
else:
|
| 47 |
+
logger.warning("No models found at the specified API base for reward scoring.")
|
| 48 |
+
except (requests.exceptions.RequestException, KeyError, IndexError) as e:
|
| 49 |
+
logger.warning(f"Failed to get model from {openai_api_base}: {e}. Reward scoring will be disabled.")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class CustomRLHFDataset(RLHFDataset):
|
| 53 |
+
def __getitem__(self, item):
|
| 54 |
+
"""
|
| 55 |
+
Note that we also return the raw_input_ids so that it can be combined with other chat template
|
| 56 |
+
"""
|
| 57 |
+
row_dict: dict = self.dataframe[item]
|
| 58 |
+
row_dict[self.prompt_key] = [
|
| 59 |
+
{
|
| 60 |
+
"role": "system",
|
| 61 |
+
# We don't need tool description, because custom_chat_template will add it.
|
| 62 |
+
"content": (
|
| 63 |
+
"You are a helpful assistant. You can call functions to assist with the user query. "
|
| 64 |
+
"Important: You must call only one function at a time. After each function call, "
|
| 65 |
+
"wait for the execution result before making the next function call if needed."
|
| 66 |
+
),
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"role": "user",
|
| 70 |
+
"content": row_dict[self.prompt_key][1]["content"],
|
| 71 |
+
},
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
images = []
|
| 75 |
+
row_dict_images = row_dict.get(self.image_key, None)
|
| 76 |
+
if row_dict_images:
|
| 77 |
+
images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict_images]
|
| 78 |
+
messages = self._build_messages(row_dict)
|
| 79 |
+
|
| 80 |
+
if self.processor is not None:
|
| 81 |
+
raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
| 82 |
+
model_inputs = self.processor(text=[raw_prompt], images=images, return_tensors="pt")
|
| 83 |
+
|
| 84 |
+
input_ids = model_inputs.pop("input_ids")
|
| 85 |
+
attention_mask = model_inputs.pop("attention_mask")
|
| 86 |
+
|
| 87 |
+
if "second_per_grid_ts" in model_inputs:
|
| 88 |
+
model_inputs.pop("second_per_grid_ts")
|
| 89 |
+
|
| 90 |
+
else:
|
| 91 |
+
raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
| 92 |
+
model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
|
| 93 |
+
input_ids = model_inputs.pop("input_ids")
|
| 94 |
+
attention_mask = model_inputs.pop("attention_mask")
|
| 95 |
+
|
| 96 |
+
input_ids, attention_mask = verl_F.postprocess_data(
|
| 97 |
+
input_ids=input_ids,
|
| 98 |
+
attention_mask=attention_mask,
|
| 99 |
+
max_length=self.max_prompt_length,
|
| 100 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 101 |
+
left_pad=True,
|
| 102 |
+
truncation=self.truncation,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
|
| 106 |
+
from verl.models.transformers.qwen2_vl import get_rope_index
|
| 107 |
+
|
| 108 |
+
position_ids = [
|
| 109 |
+
get_rope_index(
|
| 110 |
+
self.processor,
|
| 111 |
+
input_ids=input_ids[0],
|
| 112 |
+
image_grid_thw=model_inputs.get("image_grid_thw"),
|
| 113 |
+
video_grid_thw=model_inputs.get("video_grid_thw"),
|
| 114 |
+
second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
|
| 115 |
+
attention_mask=attention_mask[0],
|
| 116 |
+
)
|
| 117 |
+
] # (1, 3, seq_len)
|
| 118 |
+
|
| 119 |
+
else:
|
| 120 |
+
position_ids = compute_position_id_with_mask(attention_mask)
|
| 121 |
+
|
| 122 |
+
row_dict["input_ids"] = input_ids[0]
|
| 123 |
+
row_dict["attention_mask"] = attention_mask[0]
|
| 124 |
+
row_dict["position_ids"] = position_ids[0]
|
| 125 |
+
|
| 126 |
+
raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
|
| 127 |
+
if len(raw_prompt_ids) > self.max_prompt_length:
|
| 128 |
+
if self.truncation == "left":
|
| 129 |
+
raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
|
| 130 |
+
elif self.truncation == "right":
|
| 131 |
+
raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
|
| 132 |
+
elif self.truncation == "middle":
|
| 133 |
+
left_half = self.max_prompt_length // 2
|
| 134 |
+
right_half = self.max_prompt_length - left_half
|
| 135 |
+
raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
|
| 136 |
+
elif self.truncation == "error":
|
| 137 |
+
raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
|
| 138 |
+
|
| 139 |
+
row_dict["raw_prompt_ids"] = raw_prompt_ids
|
| 140 |
+
# encode prompts without chat template
|
| 141 |
+
if self.return_raw_chat:
|
| 142 |
+
row_dict["raw_prompt"] = messages
|
| 143 |
+
|
| 144 |
+
# get prompts with chat template
|
| 145 |
+
if self.return_full_prompt:
|
| 146 |
+
row_dict["full_prompts"] = raw_prompt # array of strings
|
| 147 |
+
|
| 148 |
+
# add index for each prompt
|
| 149 |
+
index = row_dict.get("extra_info", {}).get("index", 0)
|
| 150 |
+
tools_kwargs = {
|
| 151 |
+
"image_zoom_in_tool": {
|
| 152 |
+
"create_kwargs": {"image": images[0]},
|
| 153 |
+
# "execute_kwargs": {},
|
| 154 |
+
# "calc_reward_kwargs": {},
|
| 155 |
+
# "release_kwargs": {},
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
row_dict["index"] = index
|
| 159 |
+
row_dict["tools_kwargs"] = tools_kwargs
|
| 160 |
+
row_dict["agent_name"] = "tool_agent"
|
| 161 |
+
return row_dict
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float:
|
| 165 |
+
"""
|
| 166 |
+
Compute reward score for model solutions with robust handling of various formats.
|
| 167 |
+
|
| 168 |
+
Returns a weighted combination of:
|
| 169 |
+
- Accuracy reward (0.8 weight): Whether the answer is semantically correct
|
| 170 |
+
- Format reward (0.2 weight): Whether the output follows expected format
|
| 171 |
+
- Tool reward (1.2 weight): Whether tools were used when answer is correct
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
# Initialize tracking variables
|
| 175 |
+
is_format_error = False
|
| 176 |
+
|
| 177 |
+
# 1. Check <think> tag format
|
| 178 |
+
count_think_1 = solution_str.count("<think>")
|
| 179 |
+
count_think_2 = solution_str.count("</think>")
|
| 180 |
+
if count_think_1 != count_think_2:
|
| 181 |
+
is_format_error = True
|
| 182 |
+
|
| 183 |
+
# 2. Check vision tokens (skip this since tokenizer removes special tokens)
|
| 184 |
+
# We'll use <tool_call> and <tool_response> instead to detect tool usage
|
| 185 |
+
|
| 186 |
+
# 3. Extract answer text with multiple fallback strategies
|
| 187 |
+
answer_text = ""
|
| 188 |
+
|
| 189 |
+
# Strategy 1: Try to extract from <answer> tags first
|
| 190 |
+
predict_no_think = (
|
| 191 |
+
solution_str.split("</think>")[-1].strip() if "</think>" in solution_str else solution_str.strip()
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# Check <answer> tag format
|
| 195 |
+
count_answer_1 = predict_no_think.count("<answer>")
|
| 196 |
+
count_answer_2 = predict_no_think.count("</answer>")
|
| 197 |
+
if count_answer_1 != count_answer_2:
|
| 198 |
+
is_format_error = True
|
| 199 |
+
|
| 200 |
+
# Try to extract from <answer> tags
|
| 201 |
+
answer_match = re.search(r"<answer>(.*?)</answer>", predict_no_think, re.DOTALL)
|
| 202 |
+
if answer_match:
|
| 203 |
+
answer_text = answer_match.group(1).strip()
|
| 204 |
+
else:
|
| 205 |
+
# No proper <answer> tags found - this is a format error
|
| 206 |
+
is_format_error = True
|
| 207 |
+
|
| 208 |
+
# Strategy 2: If no <answer> tags, extract content after tool responses
|
| 209 |
+
# Look for pattern: <tool_response>...</tool_response>assistant\n[actual_answer]
|
| 210 |
+
tool_response_match = re.search(
|
| 211 |
+
r"</tool_response>\s*assistant\s*\n(.*?)$", predict_no_think, re.DOTALL | re.MULTILINE
|
| 212 |
+
)
|
| 213 |
+
if tool_response_match:
|
| 214 |
+
answer_text = tool_response_match.group(1).strip()
|
| 215 |
+
else:
|
| 216 |
+
# Strategy 3: If no tool responses, look for content after </think>
|
| 217 |
+
if "</think>" in solution_str:
|
| 218 |
+
# Remove any remaining tool-related tags and extract meaningful content
|
| 219 |
+
remaining_content = predict_no_think
|
| 220 |
+
# Remove tool calls and responses
|
| 221 |
+
remaining_content = re.sub(r"<tool_call>.*?</tool_call>", "", remaining_content, flags=re.DOTALL)
|
| 222 |
+
remaining_content = re.sub(
|
| 223 |
+
r"<tool_response>.*?</tool_response>", "", remaining_content, flags=re.DOTALL
|
| 224 |
+
)
|
| 225 |
+
# Remove user/assistant markers
|
| 226 |
+
remaining_content = re.sub(r"\b(user|assistant)\b", "", remaining_content)
|
| 227 |
+
answer_text = remaining_content.strip()
|
| 228 |
+
else:
|
| 229 |
+
# Strategy 4: Use the entire solution_str as fallback
|
| 230 |
+
answer_text = solution_str.strip()
|
| 231 |
+
|
| 232 |
+
# Clean up answer text
|
| 233 |
+
answer_text = answer_text.strip()
|
| 234 |
+
|
| 235 |
+
# If answer is still empty after all strategies, mark as format error
|
| 236 |
+
if not answer_text:
|
| 237 |
+
is_format_error = True
|
| 238 |
+
answer_text = solution_str.strip() # Use full text as last resort
|
| 239 |
+
|
| 240 |
+
# 4. Evaluate correctness using LLM judge
|
| 241 |
+
question_text = extra_info.get("question", "") if extra_info else ""
|
| 242 |
+
|
| 243 |
+
if not client or not model_name:
|
| 244 |
+
logger.warning("Reward function client not initialized or model name not found.")
|
| 245 |
+
return 0.0
|
| 246 |
+
|
| 247 |
+
system_prompt = (
|
| 248 |
+
"You are an expert evaluator. Your task is to determine if a model's answer is semantically equivalent to a "
|
| 249 |
+
"provided standard answer, given a specific question.\n"
|
| 250 |
+
"Your evaluation must be strict. The model's answer is only correct if it fully matches the meaning of the "
|
| 251 |
+
"standard answer.\n"
|
| 252 |
+
'You must provide your final judgement as a single word: either "CORRECT" or "INCORRECT". Do not provide '
|
| 253 |
+
"any explanation or other text."
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
user_prompt = (
|
| 257 |
+
f"I will provide a question, a standard answer, and a model's answer. You must evaluate if the model's "
|
| 258 |
+
f"answer is correct.\n\n"
|
| 259 |
+
f"---\n"
|
| 260 |
+
f"**Example 1:**\n"
|
| 261 |
+
f"[Question]: Is the countertop tan or blue?\n"
|
| 262 |
+
f"[Standard Answer]: The countertop is tan.\n"
|
| 263 |
+
f"[Model's Answer]: tan\n"
|
| 264 |
+
f"[Your Judgement]: CORRECT\n"
|
| 265 |
+
f"---\n"
|
| 266 |
+
f"**Example 2:**\n"
|
| 267 |
+
f"[Question]: Is the man phone both blue and closed?\n"
|
| 268 |
+
f"[Standard Answer]: Yes, the man phone is both blue and closed.\n"
|
| 269 |
+
f"[Model's Answer]: No.\n"
|
| 270 |
+
f"[Your Judgement]: INCORRECT\n"
|
| 271 |
+
f"---\n"
|
| 272 |
+
f"**Task:**\n"
|
| 273 |
+
f"[Question]: {question_text}\n"
|
| 274 |
+
f"[Standard Answer]: {ground_truth}\n"
|
| 275 |
+
f"[Model's Answer]: {answer_text}\n"
|
| 276 |
+
f"[Your Judgement]:"
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
chat_response = client.chat.completions.create(
|
| 281 |
+
model=model_name,
|
| 282 |
+
messages=[
|
| 283 |
+
{"role": "system", "content": system_prompt},
|
| 284 |
+
{"role": "user", "content": user_prompt},
|
| 285 |
+
],
|
| 286 |
+
seed=random.randint(0, 1000000),
|
| 287 |
+
temperature=0.1, # Lower temperature for more deterministic judgement
|
| 288 |
+
extra_body={
|
| 289 |
+
"chat_template_kwargs": {"enable_thinking": False},
|
| 290 |
+
},
|
| 291 |
+
)
|
| 292 |
+
response = chat_response.choices[0].message.content.strip()
|
| 293 |
+
except Exception as e:
|
| 294 |
+
logger.warning(f" [WARNING] Chat completion request failed: {e}")
|
| 295 |
+
return 0.0
|
| 296 |
+
|
| 297 |
+
# Parse LLM judge response
|
| 298 |
+
if re.search(r"\bCORRECT\b", response, re.IGNORECASE):
|
| 299 |
+
acc_reward = 1.0
|
| 300 |
+
elif re.search(r"\bINCORRECT\b", response, re.IGNORECASE):
|
| 301 |
+
acc_reward = 0.0
|
| 302 |
+
else:
|
| 303 |
+
logger.warning(
|
| 304 |
+
f" [WARNING] Judgement format error. Expected 'CORRECT' or 'INCORRECT'.\n"
|
| 305 |
+
f"Response: '{response}'\n"
|
| 306 |
+
f"Model Answer: '{answer_text}'\n"
|
| 307 |
+
f"Ground Truth: '{ground_truth}'"
|
| 308 |
+
)
|
| 309 |
+
acc_reward = 0.0
|
| 310 |
+
|
| 311 |
+
# Penalize excessively long answers (potential judge hacking)
|
| 312 |
+
if len(answer_text) >= 1000:
|
| 313 |
+
acc_reward = 0.0
|
| 314 |
+
is_format_error = True
|
| 315 |
+
|
| 316 |
+
# 5. Check tool usage - look for tool_call/tool_response patterns instead of vision tokens
|
| 317 |
+
has_tool_usage = bool(
|
| 318 |
+
re.search(r"<tool_call>.*?</tool_call>", solution_str, re.DOTALL)
|
| 319 |
+
or re.search(r"<tool_response>.*?</tool_response>", solution_str, re.DOTALL)
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Tool reward: only give if tools were used AND answer is correct
|
| 323 |
+
tool_reward = 1.0 if has_tool_usage and acc_reward > 0.5 else 0.0
|
| 324 |
+
|
| 325 |
+
# Format reward: penalty for format errors
|
| 326 |
+
format_reward = -1.0 if is_format_error else 0.0
|
| 327 |
+
|
| 328 |
+
# Log debug information for problematic cases
|
| 329 |
+
if is_format_error or not answer_text:
|
| 330 |
+
logger.debug(
|
| 331 |
+
f"Format issue detected:\n"
|
| 332 |
+
f"Solution: {solution_str[:200]}...\n"
|
| 333 |
+
f"Extracted answer: '{answer_text}'\n"
|
| 334 |
+
f"Format error: {is_format_error}\n"
|
| 335 |
+
f"Tool usage: {has_tool_usage}"
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Final weighted score
|
| 339 |
+
final_score = 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward
|
| 340 |
+
|
| 341 |
+
return final_score
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
if __name__ == "__main__":
|
| 345 |
+
# Test case 1: Original test case
|
| 346 |
+
predict_str = "The answer is 2 + 2 = 4 </think> <answer> right </answer> <answer> left </answer>"
|
| 347 |
+
ground_truth = "left"
|
| 348 |
+
extra_info = {
|
| 349 |
+
"answer": "The woman is to the left of the man who is holding the camera.",
|
| 350 |
+
"id": 0,
|
| 351 |
+
"image": "/cpfs/user/honglingyi/DATA/LLM/Vstar/gqa/images/713270.jpg",
|
| 352 |
+
"pred_ans": "The woman is to the right of the man who is holding the camera.",
|
| 353 |
+
"question": "Is the woman to the left or to the right of the man who is holding the camera?",
|
| 354 |
+
}
|
| 355 |
+
print("=== Test Case 1: Original test ===")
|
| 356 |
+
import time
|
| 357 |
+
|
| 358 |
+
time_start = time.time()
|
| 359 |
+
score = compute_score("common_reasoning", predict_str, ground_truth, extra_info)
|
| 360 |
+
print(f"Score: {score}")
|
| 361 |
+
time_end = time.time()
|
| 362 |
+
print(f"Time: {time_end - time_start}")
|
| 363 |
+
|
| 364 |
+
# Test case 2: Problematic case mentioned by user
|
| 365 |
+
problematic_solution = """<tool_call>
|
| 366 |
+
{"name": "image_zoom_in_tool", "arguments": {"bbox_2d": [226, 399, 265, 464], "label": "white van"}}
|
| 367 |
+
</tool_call>user
|
| 368 |
+
<tool_response>
|
| 369 |
+
Zoomed in on the image to the region [226, 399, 265, 464] with label white van.
|
| 370 |
+
</tool_response>
|
| 371 |
+
assistant
|
| 372 |
+
The white van is visible in the lower section of the image, near the diagonal road."""
|
| 373 |
+
|
| 374 |
+
problematic_ground_truth = "Yes, the white van is indeed situated in the bottom part of the picture."
|
| 375 |
+
problematic_extra_info = {
|
| 376 |
+
"question": "Is the white van in the bottom part of the picture?",
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
print("\n=== Test Case 2: Problematic case (no answer tags) ===")
|
| 380 |
+
print(f"Solution: {problematic_solution}")
|
| 381 |
+
print(f"Ground truth: {problematic_ground_truth}")
|
| 382 |
+
|
| 383 |
+
time_start = time.time()
|
| 384 |
+
score2 = compute_score("common_reasoning", problematic_solution, problematic_ground_truth, problematic_extra_info)
|
| 385 |
+
print(f"Score: {score2}")
|
| 386 |
+
time_end = time.time()
|
| 387 |
+
print(f"Time: {time_end - time_start}")
|
| 388 |
+
|
| 389 |
+
# Test case 3: Well-formatted case with tools
|
| 390 |
+
well_formatted_solution = """<think>
|
| 391 |
+
I need to use the image zoom tool to get a better look at the specific area.
|
| 392 |
+
</think>
|
| 393 |
+
<tool_call>
|
| 394 |
+
{"name": "image_zoom_in_tool", "arguments": {"bbox_2d": [226, 399, 265, 464], "label": "white van"}}
|
| 395 |
+
</tool_call>
|
| 396 |
+
<tool_response>
|
| 397 |
+
Zoomed in on the image to the region [226, 399, 265, 464] with label white van.
|
| 398 |
+
</tool_response>
|
| 399 |
+
<answer>Yes, the white van is indeed situated in the bottom part of the picture.</answer>"""
|
| 400 |
+
|
| 401 |
+
print("\n=== Test Case 3: Well-formatted case ===")
|
| 402 |
+
time_start = time.time()
|
| 403 |
+
score3 = compute_score(
|
| 404 |
+
"common_reasoning", well_formatted_solution, problematic_ground_truth, problematic_extra_info
|
| 405 |
+
)
|
| 406 |
+
print(f"Score: {score3}")
|
| 407 |
+
time_end = time.time()
|
| 408 |
+
print(f"Time: {time_end - time_start}")
|
ICL/DAPO/verl-recipe/fault_recover/async_llm.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
import asyncio
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
|
| 7 |
+
from vllm.utils import cdiv
|
| 8 |
+
from vllm.v1.engine.async_llm import AsyncLLM, logger
|
| 9 |
+
from vllm.v1.metrics.stats import IterationStats
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class AsyncFaultRecoverLLM(AsyncLLM):
|
| 13 |
+
def _run_output_handler(self):
|
| 14 |
+
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
|
| 15 |
+
|
| 16 |
+
if self.output_handler is not None:
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
+
# Ensure that the task doesn't have a circular ref back to the AsyncLLM
|
| 20 |
+
# object, or else it won't be garbage collected and cleaned up properly.
|
| 21 |
+
engine_core = self.engine_core
|
| 22 |
+
output_processor = self.output_processor
|
| 23 |
+
log_stats = self.log_stats
|
| 24 |
+
logger_manager = self.logger_manager
|
| 25 |
+
|
| 26 |
+
async def output_handler(q):
|
| 27 |
+
try:
|
| 28 |
+
while True:
|
| 29 |
+
# 1) Pull EngineCoreOutputs from the EngineCore.
|
| 30 |
+
outputs = await engine_core.get_output_async()
|
| 31 |
+
|
| 32 |
+
if q is not None:
|
| 33 |
+
req_info = {}
|
| 34 |
+
for output in outputs.outputs:
|
| 35 |
+
req_info[output.request_id] = {}
|
| 36 |
+
req_info[output.request_id]["new_token_ids"] = output.new_token_ids
|
| 37 |
+
req_info[output.request_id]["finished"] = output.finished
|
| 38 |
+
await q.put.remote(req_info)
|
| 39 |
+
|
| 40 |
+
num_outputs = len(outputs.outputs)
|
| 41 |
+
|
| 42 |
+
iteration_stats = IterationStats() if (log_stats and num_outputs) else None
|
| 43 |
+
|
| 44 |
+
# Split outputs into chunks of at most
|
| 45 |
+
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
|
| 46 |
+
# event loop for too long.
|
| 47 |
+
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
|
| 48 |
+
slices = (outputs.outputs,)
|
| 49 |
+
else:
|
| 50 |
+
slices = np.array_split(outputs.outputs, cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
|
| 51 |
+
|
| 52 |
+
for i, outputs_slice in enumerate(slices):
|
| 53 |
+
# 2) Process EngineCoreOutputs.
|
| 54 |
+
processed_outputs = output_processor.process_outputs(
|
| 55 |
+
outputs_slice, outputs.timestamp, iteration_stats
|
| 56 |
+
)
|
| 57 |
+
# NOTE: RequestOutputs are pushed to their queues.
|
| 58 |
+
assert not processed_outputs.request_outputs
|
| 59 |
+
|
| 60 |
+
# Allow other asyncio tasks to run between chunks
|
| 61 |
+
if i + 1 < len(slices):
|
| 62 |
+
await asyncio.sleep(0)
|
| 63 |
+
|
| 64 |
+
# 3) Abort any reqs that finished due to stop strings.
|
| 65 |
+
await engine_core.abort_requests_async(processed_outputs.reqs_to_abort)
|
| 66 |
+
|
| 67 |
+
# 4) Logging.
|
| 68 |
+
# TODO(rob): make into a coroutine and launch it in
|
| 69 |
+
# background thread once Prometheus overhead is non-trivial.
|
| 70 |
+
if logger_manager:
|
| 71 |
+
logger_manager.record(
|
| 72 |
+
engine_idx=outputs.engine_index,
|
| 73 |
+
scheduler_stats=outputs.scheduler_stats,
|
| 74 |
+
iteration_stats=iteration_stats,
|
| 75 |
+
)
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.exception("AsyncLLM output_handler failed.")
|
| 78 |
+
output_processor.propagate_error(e)
|
| 79 |
+
|
| 80 |
+
from recipe.fault_recover.fault_manager import get_tokens_queue
|
| 81 |
+
|
| 82 |
+
tokens_queue = get_tokens_queue()
|
| 83 |
+
|
| 84 |
+
self.output_handler = asyncio.create_task(output_handler(tokens_queue))
|
ICL/DAPO/verl-recipe/flash_rl_ascend/README.md
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## 在线量化权重:
|
| 2 |
+
|
| 3 |
+
介绍在昇腾设备上,使用 [Flash-RL](https://github.com/yaof20/Flash-RL) 工具,修改推理后端,通过比较 INT8 模型和 BF16 模型,对权重和激活值执行在线量化。下文以 Qwen3-30B int8 为例,在 NPU 上跑通端到端功能。
|
| 4 |
+
|
| 5 |
+
### 环境依赖
|
| 6 |
+
|
| 7 |
+
##
|
| 8 |
+
| PyTorch版本 | torch_npu版本 | CANN版本 | Python版本 |
|
| 9 |
+
| ------------ |-----------| ---------- | ---------- |
|
| 10 |
+
| 2.7.1 | 2.7.1 | 8.5.0 | Python3.10 |
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
#### 1、安装 vllm 和 vllm-ascend
|
| 14 |
+
```bash
|
| 15 |
+
# vllm==0.10.1
|
| 16 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 17 |
+
cd vllm
|
| 18 |
+
git checkout e03940762b43812fccd3c214bda60201cff9d16a
|
| 19 |
+
pip install -r requirements/build.txt
|
| 20 |
+
VLLM_TARGET_DEVICE=empty pip install -v .
|
| 21 |
+
cd ..
|
| 22 |
+
|
| 23 |
+
# vllm-ascend==0.10.1
|
| 24 |
+
git clone https://github.com/vllm-project/vllm-ascend.git
|
| 25 |
+
cd vllm-ascend
|
| 26 |
+
git checkout 7e16b4a7cdb15723c63c1c0efe58672a056eace8
|
| 27 |
+
pip install -r requirements.txt
|
| 28 |
+
export COMPILE_CUSTOM_KERNELS=1
|
| 29 |
+
python setup.py install
|
| 30 |
+
cd ..
|
| 31 |
+
|
| 32 |
+
# 源码安装transformers
|
| 33 |
+
git clone -b v4.57.6 https://github.com/huggingface/transformers.git
|
| 34 |
+
cd transformers
|
| 35 |
+
pip install -e .
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
#### 2、安装 MindSpeed 与 Megatron
|
| 39 |
+
```bash
|
| 40 |
+
# MindSpeed
|
| 41 |
+
git clone https://gitcode.com/Ascend/MindSpeed.git
|
| 42 |
+
cd MindSpeed
|
| 43 |
+
git checkout 1cdd0abd75e40936ad31721c092f57c695dd72c4
|
| 44 |
+
pip install -e .
|
| 45 |
+
cd ..
|
| 46 |
+
|
| 47 |
+
# Megatron
|
| 48 |
+
pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
#### 3、安装 verl
|
| 52 |
+
```bash
|
| 53 |
+
# verl
|
| 54 |
+
git clone https://github.com/volcengine/verl.git
|
| 55 |
+
cd verl
|
| 56 |
+
pip install -e .
|
| 57 |
+
cd ..
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
### 使用步骤:
|
| 61 |
+
|
| 62 |
+
#### 1、安装包:
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
pip install flash-llm-rl # need to be installed in all nodes in multi-node training
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
#### 2、打patch
|
| 69 |
+
|
| 70 |
+
安装 FlashRL 后,默认采用自动 patch,推荐改用手动方式,减少过程中的错误:
|
| 71 |
+
|
| 72 |
+
1. 在 `verl/verl/__init__.py` 文件中添加 `import flash_rl`;
|
| 73 |
+
2. 在 shell 脚本中添加 `flashrl cleanup`,这将禁用自动 patch;
|
| 74 |
+
|
| 75 |
+
#### 3、生成性能分析文件
|
| 76 |
+
|
| 77 |
+
具体来说,profile 文件会比较 bf16 模型和 int8 模型,以确定如何对更新后的模型执行在线量化:
|
| 78 |
+
|
| 79 |
+
```
|
| 80 |
+
flashrl profile -m Qwen3-30B-A3B -q Qwen3-30B-A3B-w8a8 -o ${PROFILE_PATH:-"$HOME/profile.30b.pt"} --fn int8
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
`-m` 参数后是 bf16 模型路径,`-q` 参数后是 int8 模型路径,`-o` 参数后是生成文件路径;
|
| 84 |
+
[RedHatAI](https://huggingface.co/RedHatAI/collections) 提供了各种量化模型;
|
| 85 |
+
|
| 86 |
+
#### 4、生成配置文件
|
| 87 |
+
|
| 88 |
+
通过以下命令生成 yaml 配置文件,供 patch 程序使用:
|
| 89 |
+
|
| 90 |
+
```
|
| 91 |
+
flashrl setup -m Qwen3-30B-A3B-w8a8 -p $HOME/profile.30b.pt --fn int8 -o ${CONFIG_PATH:-"$HOME/.flashrl_config.30b.yaml"}
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
`-m` 参数后是 int8 模型路径,`-p` 参数后是 profile 文件路径,`-o` 参数后是生成文件路径;
|
| 95 |
+
|
| 96 |
+
(可选)为了缩小 rollout 生成和梯度计算之间的差距,FlashRL 提供了在 DP 工作线程间以混合方式进行 16 位和 8 位 rollout 生成的功能。具体来说,运行以下命令会将第二个配置附加到现有的 yaml 配置文件中。
|
| 97 |
+
|
| 98 |
+
```
|
| 99 |
+
flashrl setup -a --fn bf16 -o ${CONFIG_PATH:-"$HOME/.flashrl_config.30b.yaml"}
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
#### 5、开始训练
|
| 103 |
+
|
| 104 |
+
脚本中添加以下环境变量:
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
# 打印详细日志,查看是否 patch 成功:
|
| 108 |
+
export FLASHRL_LOGGING_LEVEL=DEBUG
|
| 109 |
+
# 指定配置文件:
|
| 110 |
+
export FLASHRL_CONFIG=$HOME/.flashrl_config.30b.yaml
|
| 111 |
+
# 强制 lm-head 使用 bf16,减小精度损失:
|
| 112 |
+
export FLASHRL_LMHEAD_FP32=1
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
以上步骤已在 `test_qwen3-30b_int8_npu.sh` 提供实例,修改脚本中的模型路径即可自动执行,有具体问题可根据上述步骤排查;
|
| 116 |
+
|
| 117 |
+
在 `run.sh` 文件中补充机器 IP、网络接口,运行以下命令启动训练:
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
bash ./run.sh
|
| 121 |
+
```
|
ICL/DAPO/verl-recipe/flowrl/README.md
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center" style="color:#1976D2; font-size:42px; font-weight:bold; margin-bottom:0;">
|
| 2 |
+
FlowRL
|
| 3 |
+
</h1>
|
| 4 |
+
|
| 5 |
+
<p align="center" style="color:#42A5F5; font-size:16px; margin-top:0;">
|
| 6 |
+
Matching Reward Distributions via Flow Balance
|
| 7 |
+
</p>
|
| 8 |
+
<p align="center" style="color:#42A5F5; font-size:15px; margin-top:4px;">
|
| 9 |
+
<a href="https://arxiv.org/abs/2509.15207" target="_blank">📄 arXiv Paper</a> |
|
| 10 |
+
<a href="https://huggingface.co/papers/2509.15207" target="_blank">🤗 #1 Paper of the Day</a>
|
| 11 |
+
</p>
|
| 12 |
+
<p align="center" style="color:#42A5F5; font-size:14px; margin-top:4px;">
|
| 13 |
+
<a href="https://x.com/RoverHM/status/1969113890878259518" target="_blank">𝕏 Post 1</a> |
|
| 14 |
+
<a href="https://x.com/zdhnarsil/status/1969049940774023428" target="_blank">𝕏 Post 2</a> |
|
| 15 |
+
<a href="https://x.com/_akhaliq/status/1968901977376505929" target="_blank">𝕏 Post 3</a> |
|
| 16 |
+
<a href="https://x.com/zhu_xuekai/status/1968942580197941563" target="_blank">𝕏 Post 4</a>
|
| 17 |
+
</p>
|
| 18 |
+
|
| 19 |
+
<p align="center">
|
| 20 |
+
<img src="figures/flowrl.png" alt="FlowRL Overview" width="95%"/>
|
| 21 |
+
</p>
|
| 22 |
+
|
| 23 |
+
## Table of Contents
|
| 24 |
+
|
| 25 |
+
- [FlowRL Objective](#flowrl-objective)
|
| 26 |
+
- [Trained Models & Experiment Logs](#trained-models--experiment-logs)
|
| 27 |
+
- [Quick Start](#quick-start)
|
| 28 |
+
- [Option 1: Original Paper Reproduction (verl 0.4.0)](#option-1-original-paper-reproduction-verl-040--recommended)
|
| 29 |
+
- [Step 1: Installation](#step-1-installation)
|
| 30 |
+
- [Step 2: Data Preparation](#step-2-data-preparation)
|
| 31 |
+
- [Step 3: Model Preparation](#step-3-model-preparation)
|
| 32 |
+
- [Step 4: Training Scripts](#step-4-training-scripts)
|
| 33 |
+
- [Option 2: Latest verl Recipe FlowRL](#option-3-latest-verl-recipe-flowrl)
|
| 34 |
+
- [Step 1: Prepare Data and Model](#step-1-prepare-data-and-model)
|
| 35 |
+
- [Step 2: Run Training](#step-2-run-training)
|
| 36 |
+
- [Option 3: Implement FlowRL Yourself](#option-4-implement-flowrl-yourself)
|
| 37 |
+
- [Testing](#testing)
|
| 38 |
+
- [Citation](#citation)
|
| 39 |
+
|
| 40 |
+
## FlowRL Objective
|
| 41 |
+
|
| 42 |
+
$$
|
| 43 |
+
\mathcal{L}_{\text{FlowRL}} = w \cdot \left( \log Z_{\phi}(x) + \frac{1}{|y|} \log \pi_{\theta}(y \mid x) - \beta \hat{r}(x, y) - \frac{1}{|y|} \log \pi_{\text{ref}}(y \mid x) \right)^2
|
| 44 |
+
$$
|
| 45 |
+
|
| 46 |
+
FlowRL is a flow-balanced reinforcement learning method that matches full reward distributions instead of maximizing rewards, promoting diverse exploration and generalizable reasoning trajectories in LLMs.
|
| 47 |
+
|
| 48 |
+
## Trained Models & Experiment Logs
|
| 49 |
+
|
| 50 |
+
| Base Model | Domain | WandB Logs | Hugging Face Model |
|
| 51 |
+
|-------|--------|------------|-------------------|
|
| 52 |
+
| Qwen2.5-7B | Math | [🔗 View Run](https://wandb.ai/xuekaizhu0/FlowRL/runs/pa62rs4x?nw=nwuserxuekaizhu0) | [🤗 Model](https://huggingface.co/xuekai/FlowRL-Qwen2.5-7B-math) |
|
| 53 |
+
| DeepSeek-7B | Code | [🔗 View Run](https://wandb.ai/xuekaizhu0/FlowRL/runs/wbw72gdv?nw=nwuserxuekaizhu0) | [🤗 Model](https://huggingface.co/xuekai/FlowRL-DeepSeek-7B-code) |
|
| 54 |
+
| Qwen2.5-32B | Math | - | [🤗 Model](https://huggingface.co/xuekai/FlowRL-Qwen2.5-32B-math) |
|
| 55 |
+
|
| 56 |
+
## Quick Start
|
| 57 |
+
|
| 58 |
+
There are three ways to use FlowRL:
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
**⭐ We recommend using Option 1 as the default choice.** Since verl updates frequently, the newest versions may have unstable factors such as training and inference mismatches. Option 1 uses verl 0.4.0, which is stable and has been thoroughly tested with our paper results.
|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
### Option 1: Original Paper Reproduction (verl 0.4.0) ⭐ Recommended
|
| 67 |
+
|
| 68 |
+
For exact reproduction of results from the paper, use the original repository with verl 0.4.0:
|
| 69 |
+
|
| 70 |
+
👉 **Original Code:** [https://github.com/Xuekai-Zhu/FlowRL](https://github.com/Xuekai-Zhu/FlowRL)
|
| 71 |
+
|
| 72 |
+
#### Step 1: Installation
|
| 73 |
+
|
| 74 |
+
Install [verl](https://github.com/volcengine/verl) first before using FlowRL.
|
| 75 |
+
|
| 76 |
+
#### Step 2: Data Preparation
|
| 77 |
+
|
| 78 |
+
```bash
|
| 79 |
+
# Option A: Download our pre-processed datasets directly
|
| 80 |
+
bash preprocess/down_load_dataset.sh
|
| 81 |
+
# Move data to default directory
|
| 82 |
+
mv data/xuekai/flowrl-data-collection/math_data data/math_data
|
| 83 |
+
mv data/xuekai/flowrl-data-collection/code_data data/code_data
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
# Option B: Process data from original sources
|
| 88 |
+
# For detailed processing instructions, see data/README.md
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
#### Step 3: Model Preparation
|
| 92 |
+
|
| 93 |
+
For Math Tasks: `Qwen/Qwen2.5-7B` (default in script) ; `Qwen/Qwen2.5-32B`
|
| 94 |
+
|
| 95 |
+
For Code Tasks: `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B`
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
# Download default model (Qwen2.5-7B for math)
|
| 99 |
+
bash preprocess/down_load_model.sh
|
| 100 |
+
|
| 101 |
+
# For other models, modify MODEL_NAME in the script before running
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
#### Step 4: Training Scripts
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
cd verl_FlowRL
|
| 108 |
+
|
| 109 |
+
# For 7B math training
|
| 110 |
+
bash command/training/math/flowrl_7B_math.sh
|
| 111 |
+
|
| 112 |
+
# For 32B math training
|
| 113 |
+
bash command/training/math/flowrl_32B_math.sh
|
| 114 |
+
|
| 115 |
+
# For 7B code training
|
| 116 |
+
bash command/training/code/flowrl_7B_code.sh
|
| 117 |
+
```
|
| 118 |
+
----
|
| 119 |
+
### Option 2: Latest verl Recipe FlowRL
|
| 120 |
+
|
| 121 |
+
For running FlowRL using the latest verl framework:
|
| 122 |
+
|
| 123 |
+
**Latest verl:**
|
| 124 |
+
|
| 125 |
+
- verl recipe: [https://github.com/volcengine/verl/tree/main/recipe/flowrl](https://github.com/volcengine/verl/tree/main/recipe/flowrl)
|
| 126 |
+
|
| 127 |
+
#### Step 1: Prepare Data and Model
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
# Prepare dataset
|
| 131 |
+
bash recipe/flowrl/prepare/prepare_data.sh
|
| 132 |
+
|
| 133 |
+
# Prepare model
|
| 134 |
+
bash recipe/flowrl/prepare/prepare_model.sh
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
#### Step 2: Run Training
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
# Train FlowRL with Qwen2.5-7B
|
| 141 |
+
bash recipe/flowrl/run_flowrl_qwen2.5_7b.sh
|
| 142 |
+
```
|
| 143 |
+
----
|
| 144 |
+
### Option 3: Implement FlowRL Yourself
|
| 145 |
+
|
| 146 |
+
If you want to implement FlowRL in your own codebase, we provide a detailed implementation guide:
|
| 147 |
+
|
| 148 |
+
📖 **[FlowRL Implementation Guide](FLOWRL_SIMPLE_GUIDE.md)**
|
| 149 |
+
|
| 150 |
+
This guide walks you through the key components and steps needed to integrate FlowRL into your existing training pipeline.
|
| 151 |
+
|
| 152 |
+
## Testing
|
| 153 |
+
|
| 154 |
+
After training your FlowRL models, you can evaluate them using the following commands:
|
| 155 |
+
|
| 156 |
+
```bash
|
| 157 |
+
cd verl_Test
|
| 158 |
+
|
| 159 |
+
# First merge the model
|
| 160 |
+
bash command/eval/merge_model.sh
|
| 161 |
+
|
| 162 |
+
# For math testing
|
| 163 |
+
bash command/eval/math/flowrl_math_test.sh
|
| 164 |
+
|
| 165 |
+
# For code testing
|
| 166 |
+
bash command/eval/code/flowrl_code_test.sh
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
**Reference:** For verl v0.5.0.dev merge model script, see [merge_model.sh](https://github.com/Xuekai-Zhu/verl_FlowRL/blob/flowrl-v0.5.0.dev/recipe/flowrl/eval/merge_model.sh)
|
| 170 |
+
|
| 171 |
+
## Citation
|
| 172 |
+
|
| 173 |
+
If you think this repo helps you, please kindly consider citing our paper:
|
| 174 |
+
|
| 175 |
+
```bibtex
|
| 176 |
+
@article{zhu2025flowrl,
|
| 177 |
+
title={FlowRL: Matching Reward Distributions for LLM Reasoning},
|
| 178 |
+
author={Zhu, Xuekai and Cheng, Daixuan and Zhang, Dinghuai and Li, Hengli and Zhang, Kaiyan and Jiang, Che and Sun, Youbang and Hua, Ermo and Zuo, Yuxin and Lv, Xingtai and others},
|
| 179 |
+
journal={arXiv preprint arXiv:2509.15207},
|
| 180 |
+
year={2025}
|
| 181 |
+
}
|
| 182 |
+
```
|
ICL/DAPO/verl-recipe/flowrl/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""FlowRL recipe package."""
|
| 16 |
+
|
| 17 |
+
__all__ = []
|
ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""FlowRL FSDP Worker that uses FlowRLActor instead of standard DPActor."""
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
import warnings
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed
|
| 23 |
+
from peft import LoraConfig, TaskType, get_peft_model
|
| 24 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
# for torch 2.5+
|
| 28 |
+
from torch.distributed.tensor import DTensor
|
| 29 |
+
except ImportError:
|
| 30 |
+
from torch.distributed._tensor import DTensor
|
| 31 |
+
|
| 32 |
+
from recipe.flowrl.flowrl_actor import FlowRLActor, ProjZModule
|
| 33 |
+
|
| 34 |
+
from verl.models.transformers.monkey_patch import apply_monkey_patch
|
| 35 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 36 |
+
from verl.utils import hf_processor, hf_tokenizer
|
| 37 |
+
from verl.utils.activation_offload import enable_activation_offloading
|
| 38 |
+
from verl.utils.config import omega_conf_to_dataclass
|
| 39 |
+
from verl.utils.device import (
|
| 40 |
+
get_device_id,
|
| 41 |
+
get_torch_device,
|
| 42 |
+
set_expandable_segments,
|
| 43 |
+
)
|
| 44 |
+
from verl.utils.fsdp_utils import (
|
| 45 |
+
CPUOffloadPolicy,
|
| 46 |
+
MixedPrecisionPolicy,
|
| 47 |
+
apply_fsdp2,
|
| 48 |
+
collect_lora_params,
|
| 49 |
+
fsdp2_load_full_state_dict,
|
| 50 |
+
get_fsdp_wrap_policy,
|
| 51 |
+
get_init_weight_context_manager,
|
| 52 |
+
get_shard_placement_fn,
|
| 53 |
+
init_fn,
|
| 54 |
+
load_fsdp_model_to_gpu,
|
| 55 |
+
offload_fsdp_model_to_cpu,
|
| 56 |
+
replace_lora_wrapper,
|
| 57 |
+
)
|
| 58 |
+
from verl.utils.memory_utils import aggressive_empty_cache
|
| 59 |
+
from verl.utils.model import convert_weight_keys
|
| 60 |
+
from verl.utils.profiler import log_gpu_memory_usage
|
| 61 |
+
from verl.utils.py_functional import convert_to_regular_types
|
| 62 |
+
from verl.workers.config import FSDPEngineConfig
|
| 63 |
+
from verl.workers.fsdp_workers import ActorRolloutRefWorker, get_sharding_strategy, get_vl_model_vision_tower
|
| 64 |
+
|
| 65 |
+
logger = logging.getLogger(__file__)
|
| 66 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class FlowRLActorRolloutRefWorker(ActorRolloutRefWorker):
|
| 70 |
+
"""
|
| 71 |
+
FlowRL version of ActorRolloutRefWorker that uses FlowRLActor.
|
| 72 |
+
|
| 73 |
+
This worker adds FlowRL-specific modifications:
|
| 74 |
+
- ProjZModule for log Z estimation (added in _build_model_optimizer)
|
| 75 |
+
- FlowRLActor with trajectory balance loss (replaces standard DPActor)
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def _build_model_optimizer(
|
| 79 |
+
self,
|
| 80 |
+
model_path,
|
| 81 |
+
fsdp_config: FSDPEngineConfig,
|
| 82 |
+
optim_config,
|
| 83 |
+
override_model_config,
|
| 84 |
+
use_remove_padding=False,
|
| 85 |
+
use_fused_kernels=False,
|
| 86 |
+
enable_gradient_checkpointing=False,
|
| 87 |
+
trust_remote_code=False,
|
| 88 |
+
use_liger=False,
|
| 89 |
+
role="actor",
|
| 90 |
+
enable_activation_offload=False,
|
| 91 |
+
):
|
| 92 |
+
from torch import optim
|
| 93 |
+
from torch.distributed.fsdp import CPUOffload, MixedPrecision
|
| 94 |
+
from transformers import (
|
| 95 |
+
AutoConfig,
|
| 96 |
+
AutoModel,
|
| 97 |
+
AutoModelForCausalLM,
|
| 98 |
+
AutoModelForImageTextToText,
|
| 99 |
+
AutoModelForVision2Seq,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
from verl.utils.model import get_generation_config, print_model_size, update_model_config
|
| 103 |
+
from verl.utils.torch_dtypes import PrecisionType
|
| 104 |
+
|
| 105 |
+
assert role in ["actor", "ref"]
|
| 106 |
+
|
| 107 |
+
log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
|
| 108 |
+
local_path = model_path
|
| 109 |
+
|
| 110 |
+
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
|
| 111 |
+
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
|
| 112 |
+
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
|
| 113 |
+
self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
|
| 114 |
+
|
| 115 |
+
if self.config.model.get("custom_chat_template", None) is not None:
|
| 116 |
+
if self.processor is not None:
|
| 117 |
+
self.processor.chat_template = self.config.model.custom_chat_template
|
| 118 |
+
else:
|
| 119 |
+
self.tokenizer.chat_template = self.config.model.custom_chat_template
|
| 120 |
+
|
| 121 |
+
vllm_dtype = PrecisionType.to_dtype(self.config.rollout.dtype)
|
| 122 |
+
torch_dtype = fsdp_config.get("model_dtype", None)
|
| 123 |
+
if torch_dtype is None:
|
| 124 |
+
torch_dtype = torch.float32 if self._is_actor else vllm_dtype
|
| 125 |
+
else:
|
| 126 |
+
torch_dtype = PrecisionType.to_dtype(torch_dtype)
|
| 127 |
+
|
| 128 |
+
# override model kwargs
|
| 129 |
+
actor_model_config = AutoConfig.from_pretrained(
|
| 130 |
+
local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
|
| 131 |
+
)
|
| 132 |
+
# TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53
|
| 133 |
+
# which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids
|
| 134 |
+
# Maybe support Ulysses in VisionAttention in the future and remove this patch
|
| 135 |
+
if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"):
|
| 136 |
+
actor_model_config.vision_config._attn_implementation = "eager"
|
| 137 |
+
|
| 138 |
+
# patch for kimi-vl
|
| 139 |
+
if getattr(actor_model_config, "model_type", None) == "kimi_vl":
|
| 140 |
+
actor_model_config.text_config.topk_method = "greedy"
|
| 141 |
+
|
| 142 |
+
self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
|
| 143 |
+
|
| 144 |
+
override_config_kwargs = {
|
| 145 |
+
"bos_token_id": self.tokenizer.bos_token_id,
|
| 146 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
| 147 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 148 |
+
}
|
| 149 |
+
override_config_kwargs.update(override_model_config)
|
| 150 |
+
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
|
| 151 |
+
if self.rank == 0:
|
| 152 |
+
print(f"Model config after override: {actor_model_config}")
|
| 153 |
+
|
| 154 |
+
# NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
|
| 155 |
+
init_context = get_init_weight_context_manager(
|
| 156 |
+
use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
with init_context(), warnings.catch_warnings():
|
| 160 |
+
warnings.simplefilter("ignore")
|
| 161 |
+
has_remote_code = hasattr(actor_model_config, "auto_map") and any(
|
| 162 |
+
actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values()
|
| 163 |
+
)
|
| 164 |
+
if has_remote_code:
|
| 165 |
+
auto_class = next(
|
| 166 |
+
k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v
|
| 167 |
+
)
|
| 168 |
+
match auto_class:
|
| 169 |
+
case "AutoModelForVision2Seq":
|
| 170 |
+
actor_module_class = AutoModelForVision2Seq
|
| 171 |
+
case "AutoModelForCausalLM":
|
| 172 |
+
actor_module_class = AutoModelForCausalLM
|
| 173 |
+
case "AutoModelForImageTextToText":
|
| 174 |
+
actor_module_class = AutoModelForImageTextToText
|
| 175 |
+
case _:
|
| 176 |
+
actor_module_class = AutoModel
|
| 177 |
+
else:
|
| 178 |
+
if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
|
| 179 |
+
actor_module_class = AutoModelForVision2Seq
|
| 180 |
+
elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys():
|
| 181 |
+
actor_module_class = AutoModelForCausalLM
|
| 182 |
+
elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys():
|
| 183 |
+
actor_module_class = AutoModelForImageTextToText
|
| 184 |
+
else:
|
| 185 |
+
actor_module_class = AutoModel
|
| 186 |
+
|
| 187 |
+
actor_module = actor_module_class.from_pretrained(
|
| 188 |
+
pretrained_model_name_or_path=local_path,
|
| 189 |
+
torch_dtype=torch_dtype,
|
| 190 |
+
config=actor_model_config,
|
| 191 |
+
trust_remote_code=trust_remote_code,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# ==== FlowRL: inject ProjZ BEFORE FSDP wrap ====
|
| 195 |
+
if role == "actor" and self._is_actor:
|
| 196 |
+
n_dim = actor_module.config.hidden_size
|
| 197 |
+
proj_layers = getattr(self.config.actor, "proj_layer", 3)
|
| 198 |
+
actor_module.add_module("proj_z", ProjZModule(n_dim, num_layers=proj_layers))
|
| 199 |
+
|
| 200 |
+
if self.rank == 0:
|
| 201 |
+
print(f"[FlowRL] Added proj_z (layers={proj_layers}, hidden={n_dim}) BEFORE FSDP wrap")
|
| 202 |
+
# ===============================================
|
| 203 |
+
|
| 204 |
+
# Apply Liger kernel to the model if use_liger is set to True
|
| 205 |
+
if use_liger:
|
| 206 |
+
from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
|
| 207 |
+
|
| 208 |
+
_apply_liger_kernel_to_instance(model=actor_module)
|
| 209 |
+
|
| 210 |
+
fused_kernel_options = self.config.model.get("fused_kernel_options", None)
|
| 211 |
+
fused_kernels_backend = (
|
| 212 |
+
fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
apply_monkey_patch(
|
| 216 |
+
model=actor_module,
|
| 217 |
+
use_remove_padding=use_remove_padding,
|
| 218 |
+
ulysses_sp_size=self.ulysses_sequence_parallel_size,
|
| 219 |
+
use_fused_kernels=use_fused_kernels,
|
| 220 |
+
fused_kernels_backend=fused_kernels_backend,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
|
| 224 |
+
actor_module.to(torch_dtype)
|
| 225 |
+
|
| 226 |
+
if enable_gradient_checkpointing:
|
| 227 |
+
actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
| 228 |
+
if self._is_lora:
|
| 229 |
+
print("Applying LoRA to actor module")
|
| 230 |
+
actor_module.enable_input_require_grads()
|
| 231 |
+
# Convert config to regular Python types before creating PEFT model
|
| 232 |
+
lora_config = {
|
| 233 |
+
"task_type": TaskType.CAUSAL_LM,
|
| 234 |
+
"r": self.config.model.lora_rank,
|
| 235 |
+
"lora_alpha": self.config.model.lora_alpha,
|
| 236 |
+
"target_modules": convert_to_regular_types(self.config.model.target_modules),
|
| 237 |
+
"exclude_modules": convert_to_regular_types(self.config.model.exclude_modules),
|
| 238 |
+
"bias": "none",
|
| 239 |
+
}
|
| 240 |
+
actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
|
| 241 |
+
|
| 242 |
+
self.use_orig_params = fsdp_config.get("use_orig_params", False)
|
| 243 |
+
if self.config.actor.get("freeze_vision_tower", False):
|
| 244 |
+
vision_tower = get_vl_model_vision_tower(actor_module)
|
| 245 |
+
if vision_tower is not None:
|
| 246 |
+
vision_tower.requires_grad_(False)
|
| 247 |
+
self.use_orig_params = True
|
| 248 |
+
if self.rank == 0:
|
| 249 |
+
print("[actor model] Vision tower is set to not trainable.")
|
| 250 |
+
else:
|
| 251 |
+
if self.rank == 0:
|
| 252 |
+
print("[actor model] No vision tower found.")
|
| 253 |
+
|
| 254 |
+
torch.distributed.barrier()
|
| 255 |
+
|
| 256 |
+
if self.rank == 0:
|
| 257 |
+
print_model_size(actor_module)
|
| 258 |
+
|
| 259 |
+
log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)
|
| 260 |
+
|
| 261 |
+
# We wrap FSDP for rollout as well
|
| 262 |
+
mixed_precision_config = fsdp_config.get("mixed_precision", None)
|
| 263 |
+
if mixed_precision_config is not None:
|
| 264 |
+
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
|
| 265 |
+
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
|
| 266 |
+
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
|
| 267 |
+
else:
|
| 268 |
+
param_dtype = PrecisionType.to_dtype(self.config.actor.get("dtype", "bfloat16"))
|
| 269 |
+
reduce_dtype = torch.float32
|
| 270 |
+
buffer_dtype = torch.float32
|
| 271 |
+
|
| 272 |
+
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
|
| 273 |
+
|
| 274 |
+
auto_wrap_policy = get_fsdp_wrap_policy(
|
| 275 |
+
module=actor_module,
|
| 276 |
+
config=fsdp_config.get("wrap_policy", None),
|
| 277 |
+
is_lora=self.config.model.get("lora_rank", 0) > 0,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
if self._is_rollout and self.config.rollout.name == "hf":
|
| 281 |
+
# TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
|
| 282 |
+
auto_wrap_policy = None
|
| 283 |
+
|
| 284 |
+
if self.rank == 0:
|
| 285 |
+
print(f"wrap_policy: {auto_wrap_policy}")
|
| 286 |
+
|
| 287 |
+
fsdp_mesh = self.device_mesh
|
| 288 |
+
sharding_strategy = get_sharding_strategy(fsdp_mesh)
|
| 289 |
+
|
| 290 |
+
# TODO: add transformer policy
|
| 291 |
+
# We force reference policy to use CPUOffload to save memory.
|
| 292 |
+
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
|
| 293 |
+
cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
|
| 294 |
+
fsdp_strategy = self.config.actor.strategy
|
| 295 |
+
if fsdp_strategy == "fsdp":
|
| 296 |
+
actor_module_fsdp = FSDP(
|
| 297 |
+
actor_module,
|
| 298 |
+
cpu_offload=cpu_offload,
|
| 299 |
+
param_init_fn=init_fn,
|
| 300 |
+
auto_wrap_policy=auto_wrap_policy,
|
| 301 |
+
device_id=get_device_id(),
|
| 302 |
+
sharding_strategy=sharding_strategy, # zero3
|
| 303 |
+
mixed_precision=mixed_precision,
|
| 304 |
+
sync_module_states=True,
|
| 305 |
+
device_mesh=self.device_mesh,
|
| 306 |
+
use_orig_params=self.use_orig_params,
|
| 307 |
+
forward_prefetch=fsdp_config.get("forward_prefetch", False),
|
| 308 |
+
)
|
| 309 |
+
elif fsdp_strategy == "fsdp2":
|
| 310 |
+
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
|
| 311 |
+
mp_policy = MixedPrecisionPolicy(
|
| 312 |
+
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
|
| 313 |
+
)
|
| 314 |
+
if role == "actor" and fsdp_config.offload_policy:
|
| 315 |
+
cpu_offload = CPUOffloadPolicy(pin_memory=True)
|
| 316 |
+
self._is_offload_param = False
|
| 317 |
+
self._is_offload_optimizer = False
|
| 318 |
+
else:
|
| 319 |
+
cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)
|
| 320 |
+
|
| 321 |
+
fsdp_kwargs = {
|
| 322 |
+
"mesh": fsdp_mesh,
|
| 323 |
+
"mp_policy": mp_policy,
|
| 324 |
+
"offload_policy": cpu_offload,
|
| 325 |
+
"reshard_after_forward": fsdp_config.reshard_after_forward,
|
| 326 |
+
"shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
|
| 327 |
+
}
|
| 328 |
+
full_state = actor_module.state_dict()
|
| 329 |
+
apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
|
| 330 |
+
fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
|
| 331 |
+
actor_module_fsdp = actor_module
|
| 332 |
+
else:
|
| 333 |
+
raise NotImplementedError(f"not implement {fsdp_strategy}")
|
| 334 |
+
|
| 335 |
+
if enable_activation_offload:
|
| 336 |
+
enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)
|
| 337 |
+
|
| 338 |
+
log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
|
| 339 |
+
|
| 340 |
+
# TODO: add more optimizer args into config
|
| 341 |
+
if role == "actor" and optim_config is not None:
|
| 342 |
+
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
|
| 343 |
+
|
| 344 |
+
actor_optimizer = optim.AdamW(
|
| 345 |
+
actor_module_fsdp.parameters(),
|
| 346 |
+
lr=optim_config.lr,
|
| 347 |
+
betas=optim_config.get("betas", (0.9, 0.999)),
|
| 348 |
+
weight_decay=optim_config.get("weight_decay", 1e-2),
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
total_steps = optim_config.get("total_training_steps", 0)
|
| 352 |
+
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
|
| 353 |
+
warmup_style = optim_config.get("warmup_style", "constant")
|
| 354 |
+
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
|
| 355 |
+
num_cycles = optim_config.get("num_cycles", 0.5)
|
| 356 |
+
if num_warmup_steps < 0:
|
| 357 |
+
num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
|
| 358 |
+
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
|
| 359 |
+
|
| 360 |
+
if self.rank == 0:
|
| 361 |
+
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
|
| 362 |
+
|
| 363 |
+
if warmup_style == "constant":
|
| 364 |
+
actor_lr_scheduler = get_constant_schedule_with_warmup(
|
| 365 |
+
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
|
| 366 |
+
)
|
| 367 |
+
elif warmup_style == "cosine":
|
| 368 |
+
actor_lr_scheduler = get_cosine_schedule_with_warmup(
|
| 369 |
+
optimizer=actor_optimizer,
|
| 370 |
+
num_warmup_steps=num_warmup_steps,
|
| 371 |
+
num_training_steps=total_steps,
|
| 372 |
+
min_lr_ratio=min_lr_ratio,
|
| 373 |
+
num_cycles=num_cycles,
|
| 374 |
+
)
|
| 375 |
+
else:
|
| 376 |
+
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
|
| 377 |
+
|
| 378 |
+
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
|
| 379 |
+
else:
|
| 380 |
+
actor_optimizer = None
|
| 381 |
+
actor_lr_scheduler = None
|
| 382 |
+
|
| 383 |
+
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
|
| 384 |
+
|
| 385 |
+
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
| 386 |
+
def init_model(self):
|
| 387 |
+
"""Override init_model to use FlowRLActor instead of DataParallelPPOActor."""
|
| 388 |
+
# Call parent's init_model to set up the FSDP model (with proj_z already added)
|
| 389 |
+
super().init_model()
|
| 390 |
+
|
| 391 |
+
# Replace the actor with FlowRLActor if this worker is an actor
|
| 392 |
+
if self._is_actor:
|
| 393 |
+
if self.rank == 0:
|
| 394 |
+
print("[FlowRL] Replacing DataParallelPPOActor with FlowRLActor")
|
| 395 |
+
|
| 396 |
+
# Convert actor config to dataclass
|
| 397 |
+
actor_cfg = omega_conf_to_dataclass(self.config.actor)
|
| 398 |
+
|
| 399 |
+
# Create FlowRLActor with trajectory balance loss
|
| 400 |
+
self.actor = FlowRLActor(
|
| 401 |
+
config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
async def rollout_mode(self):
|
| 405 |
+
"""
|
| 406 |
+
Override rollout_mode to filter out proj_z parameters before syncing to vLLM.
|
| 407 |
+
|
| 408 |
+
FlowRL's proj_z module is only needed during training for estimating log Z.
|
| 409 |
+
It should not be loaded into vLLM since vLLM is only used for rollout generation.
|
| 410 |
+
"""
|
| 411 |
+
aggressive_empty_cache(force_sync=True)
|
| 412 |
+
|
| 413 |
+
log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)
|
| 414 |
+
if self._is_offload_param:
|
| 415 |
+
load_fsdp_model_to_gpu(self.actor_module_fsdp)
|
| 416 |
+
log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)
|
| 417 |
+
|
| 418 |
+
peft_config = None
|
| 419 |
+
peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
|
| 420 |
+
if hasattr(peft_model, "peft_config"): # LoRA
|
| 421 |
+
peft_config = peft_model.peft_config.get("default", None)
|
| 422 |
+
params = collect_lora_params(
|
| 423 |
+
module=self.actor_module_fsdp,
|
| 424 |
+
layered_summon=self.config.rollout.get("layered_summon", False),
|
| 425 |
+
base_sync_done=self.base_sync_done,
|
| 426 |
+
)
|
| 427 |
+
if not self.base_sync_done:
|
| 428 |
+
params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
|
| 429 |
+
else:
|
| 430 |
+
params = self.actor_module_fsdp.state_dict()
|
| 431 |
+
|
| 432 |
+
# ==== FlowRL: Filter out proj_z parameters ====
|
| 433 |
+
params = {k: v for k, v in params.items() if not k.startswith("proj_z")}
|
| 434 |
+
num_proj_z_filtered = len([k for k in self.actor_module_fsdp.state_dict().keys() if k.startswith("proj_z")])
|
| 435 |
+
if num_proj_z_filtered > 0 and self.rank == 0:
|
| 436 |
+
print(f"[FlowRL] Filtered {num_proj_z_filtered} proj_z parameters before syncing to vLLM")
|
| 437 |
+
# ===============================================
|
| 438 |
+
|
| 439 |
+
params = convert_weight_keys(
|
| 440 |
+
params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# Special handling for LoRA with sleep_level=2:
|
| 444 |
+
if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
|
| 445 |
+
base_model_params = collect_lora_params(
|
| 446 |
+
module=self.actor_module_fsdp,
|
| 447 |
+
layered_summon=self.layered_summon,
|
| 448 |
+
base_sync_done=False,
|
| 449 |
+
)
|
| 450 |
+
base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
|
| 451 |
+
# Filter proj_z from base model params as well
|
| 452 |
+
base_model_params = {k: v for k, v in base_model_params.items() if not k.startswith("proj_z")}
|
| 453 |
+
base_model_params = convert_weight_keys(
|
| 454 |
+
base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
|
| 458 |
+
if self._is_offload_param:
|
| 459 |
+
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
|
| 460 |
+
log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)
|
| 461 |
+
|
| 462 |
+
set_expandable_segments(False)
|
| 463 |
+
|
| 464 |
+
if peft_config is not None and self.base_sync_done:
|
| 465 |
+
per_tensor_param = params
|
| 466 |
+
else:
|
| 467 |
+
device = get_device_id()
|
| 468 |
+
per_tensor_param = (
|
| 469 |
+
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
|
| 470 |
+
for name, param in params.items()
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
if self.config.rollout.free_cache_engine:
|
| 474 |
+
await self.rollout.resume(tags=["weights"])
|
| 475 |
+
log_gpu_memory_usage("After resume weights", logger=logger)
|
| 476 |
+
|
| 477 |
+
if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
|
| 478 |
+
per_tensor_base_params = (
|
| 479 |
+
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
|
| 480 |
+
for name, param in base_model_params.items()
|
| 481 |
+
)
|
| 482 |
+
await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
|
| 483 |
+
del base_model_params, per_tensor_base_params
|
| 484 |
+
|
| 485 |
+
await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
|
| 486 |
+
log_gpu_memory_usage("After update_weights", logger=logger)
|
| 487 |
+
del params, per_tensor_param
|
| 488 |
+
aggressive_empty_cache(force_sync=True)
|
| 489 |
+
if self.config.rollout.free_cache_engine:
|
| 490 |
+
await self.rollout.resume(tags=["kv_cache"])
|
| 491 |
+
log_gpu_memory_usage("After resume kv_cache", logger=logger)
|
| 492 |
+
|
| 493 |
+
self.base_sync_done = True
|
| 494 |
+
self.torch_random_states = get_torch_device().get_rng_state()
|
| 495 |
+
get_torch_device().set_rng_state(self.gen_random_states)
|
ICL/DAPO/verl-recipe/flowrl/main_flowrl.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Main training script for FlowRL algorithm."""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import socket
|
| 20 |
+
|
| 21 |
+
import hydra
|
| 22 |
+
import ray
|
| 23 |
+
from omegaconf import OmegaConf
|
| 24 |
+
|
| 25 |
+
from verl.trainer.ppo.reward import load_reward_manager
|
| 26 |
+
from verl.utils.device import is_cuda_available
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@hydra.main(config_path="config", config_name="flowrl_trainer", version_base=None)
|
| 30 |
+
def main(config):
|
| 31 |
+
run_flowrl(config)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def run_flowrl(config) -> None:
|
| 35 |
+
if not ray.is_initialized():
|
| 36 |
+
# this is for local ray cluster
|
| 37 |
+
default_runtime_env = {
|
| 38 |
+
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
|
| 39 |
+
}
|
| 40 |
+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
|
| 41 |
+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
|
| 42 |
+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
|
| 43 |
+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
|
| 44 |
+
print(f"ray init kwargs: {ray_init_kwargs}")
|
| 45 |
+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
if (
|
| 49 |
+
is_cuda_available
|
| 50 |
+
and config.global_profiler.tool == "nsys"
|
| 51 |
+
and OmegaConf.select(config.global_profiler, "steps") is not None
|
| 52 |
+
and len(OmegaConf.select(config.global_profiler, "steps")) > 0
|
| 53 |
+
):
|
| 54 |
+
nsight_options = OmegaConf.to_container(
|
| 55 |
+
config.global_profiler.global_tool_config.nsys.controller_nsight_options
|
| 56 |
+
)
|
| 57 |
+
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
|
| 58 |
+
else:
|
| 59 |
+
runner = TaskRunner.remote()
|
| 60 |
+
ray.get(runner.run.remote(config))
|
| 61 |
+
finally:
|
| 62 |
+
if ray.is_initialized():
|
| 63 |
+
ray.shutdown()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
|
| 67 |
+
class TaskRunner:
|
| 68 |
+
def run(self, config):
|
| 69 |
+
# print initial config
|
| 70 |
+
from pprint import pprint
|
| 71 |
+
|
| 72 |
+
from omegaconf import OmegaConf
|
| 73 |
+
|
| 74 |
+
from verl.utils.fs import copy_to_local
|
| 75 |
+
|
| 76 |
+
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
|
| 77 |
+
|
| 78 |
+
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
|
| 79 |
+
OmegaConf.resolve(config)
|
| 80 |
+
|
| 81 |
+
# download the checkpoint from hdfs
|
| 82 |
+
local_path = copy_to_local(config.actor_rollout_ref.model.path)
|
| 83 |
+
|
| 84 |
+
# instantiate tokenizer
|
| 85 |
+
from verl.utils import hf_processor, hf_tokenizer
|
| 86 |
+
|
| 87 |
+
tokenizer = hf_tokenizer(local_path)
|
| 88 |
+
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
|
| 89 |
+
|
| 90 |
+
from verl.single_controller.ray import RayWorkerGroup
|
| 91 |
+
|
| 92 |
+
# define worker classes
|
| 93 |
+
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
|
| 94 |
+
assert config.critic.strategy in {"fsdp", "fsdp2"}
|
| 95 |
+
|
| 96 |
+
# Use FlowRL custom worker instead of standard worker
|
| 97 |
+
from recipe.flowrl.flowrl_fsdp_worker import FlowRLActorRolloutRefWorker
|
| 98 |
+
|
| 99 |
+
from verl.workers.fsdp_workers import CriticWorker # , ActorRolloutRefWorker
|
| 100 |
+
|
| 101 |
+
ActorRolloutRefWorker = FlowRLActorRolloutRefWorker
|
| 102 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 103 |
+
|
| 104 |
+
elif config.actor_rollout_ref.actor.strategy == "megatron":
|
| 105 |
+
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
|
| 106 |
+
from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
|
| 107 |
+
|
| 108 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 109 |
+
|
| 110 |
+
else:
|
| 111 |
+
raise NotImplementedError
|
| 112 |
+
|
| 113 |
+
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
|
| 114 |
+
|
| 115 |
+
role_worker_mapping = {
|
| 116 |
+
Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
|
| 117 |
+
Role.Critic: ray.remote(CriticWorker),
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
global_pool_id = "global_pool"
|
| 121 |
+
resource_pool_spec = {
|
| 122 |
+
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
|
| 123 |
+
}
|
| 124 |
+
mapping = {
|
| 125 |
+
Role.ActorRollout: global_pool_id,
|
| 126 |
+
Role.Critic: global_pool_id,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
# we should adopt a multi-source reward function here
|
| 130 |
+
# - for rule-based rm, we directly call a reward score
|
| 131 |
+
# - for model-based rm, we call a model
|
| 132 |
+
# - for code related prompt, we send to a sandbox if there are test cases
|
| 133 |
+
# - finally, we combine all the rewards together
|
| 134 |
+
# - The reward type depends on the tag of the data
|
| 135 |
+
if config.reward_model.enable:
|
| 136 |
+
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
|
| 137 |
+
from verl.workers.fsdp_workers import RewardModelWorker
|
| 138 |
+
elif config.reward_model.strategy == "megatron":
|
| 139 |
+
from verl.workers.megatron_workers import RewardModelWorker
|
| 140 |
+
else:
|
| 141 |
+
raise NotImplementedError
|
| 142 |
+
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
|
| 143 |
+
mapping[Role.RewardModel] = global_pool_id
|
| 144 |
+
|
| 145 |
+
# reference model
|
| 146 |
+
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
|
| 147 |
+
role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
|
| 148 |
+
mapping[Role.RefPolicy] = global_pool_id
|
| 149 |
+
|
| 150 |
+
reward_fn = load_reward_manager(
|
| 151 |
+
config,
|
| 152 |
+
tokenizer,
|
| 153 |
+
0,
|
| 154 |
+
max_resp_len=config.data.max_response_length,
|
| 155 |
+
overlong_buffer_cfg=config.reward_model.overlong_buffer,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Note that we always use function-based RM for validation
|
| 159 |
+
val_reward_fn = load_reward_manager(
|
| 160 |
+
config,
|
| 161 |
+
tokenizer,
|
| 162 |
+
1,
|
| 163 |
+
max_resp_len=config.data.max_response_length,
|
| 164 |
+
overlong_buffer_cfg=config.reward_model.overlong_buffer,
|
| 165 |
+
)
|
| 166 |
+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
|
| 167 |
+
|
| 168 |
+
from recipe.flowrl.flowrl_ray_trainer import RayFlowRLTrainer
|
| 169 |
+
|
| 170 |
+
trainer = RayFlowRLTrainer(
|
| 171 |
+
config=config,
|
| 172 |
+
tokenizer=tokenizer,
|
| 173 |
+
processor=processor,
|
| 174 |
+
role_worker_mapping=role_worker_mapping,
|
| 175 |
+
resource_pool_manager=resource_pool_manager,
|
| 176 |
+
ray_worker_group_cls=ray_worker_group_cls,
|
| 177 |
+
reward_fn=reward_fn,
|
| 178 |
+
val_reward_fn=val_reward_fn,
|
| 179 |
+
)
|
| 180 |
+
trainer.init_workers()
|
| 181 |
+
trainer.fit()
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
if __name__ == "__main__":
|
| 185 |
+
main()
|
ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -xeuo pipefail
|
| 3 |
+
|
| 4 |
+
project_name='FlowRL'
|
| 5 |
+
exp_name='FlowRL-Qwen2.5-7B'
|
| 6 |
+
|
| 7 |
+
# Algorithm settings
|
| 8 |
+
adv_estimator=grpo
|
| 9 |
+
|
| 10 |
+
# KL settings (ref policy needed for FlowRL, but KL penalty disabled)
|
| 11 |
+
use_kl_in_reward=False # Enable ref policy for ref_log_prob (needed for FlowRL loss)
|
| 12 |
+
kl_coef=0.0
|
| 13 |
+
use_kl_loss=True
|
| 14 |
+
kl_loss_coef=0.0
|
| 15 |
+
|
| 16 |
+
# Clip parameters
|
| 17 |
+
clip_ratio_low=0.2
|
| 18 |
+
clip_ratio_high=0.28
|
| 19 |
+
|
| 20 |
+
# Sequence lengths
|
| 21 |
+
max_prompt_length=$((1024 * 2))
|
| 22 |
+
max_response_length=$((1024 * 8))
|
| 23 |
+
|
| 24 |
+
# Overlong buffer for very long responses
|
| 25 |
+
enable_overlong_buffer=True
|
| 26 |
+
overlong_buffer_len=$((1024 * 4))
|
| 27 |
+
overlong_penalty_factor=1.0
|
| 28 |
+
|
| 29 |
+
# Batch sizes
|
| 30 |
+
train_prompt_bsz=512
|
| 31 |
+
gen_prompt_bsz=$((train_prompt_bsz * 3))
|
| 32 |
+
n_resp_per_prompt=8
|
| 33 |
+
train_prompt_mini_bsz=32
|
| 34 |
+
|
| 35 |
+
# Checkpoint saving frequency (-1 to disable periodic saves)
|
| 36 |
+
save_freq=-1
|
| 37 |
+
|
| 38 |
+
# Ray
|
| 39 |
+
RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
| 40 |
+
WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
| 41 |
+
RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
| 42 |
+
NNODES=${NNODES:-1}
|
| 43 |
+
|
| 44 |
+
# Paths
|
| 45 |
+
MODEL_PATH=${MODEL_PATH:-"${WORKING_DIR}/downloads/models/Qwen/Qwen2.5-7B"}
|
| 46 |
+
CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/outputs/ckpts/${project_name}/${exp_name}"}
|
| 47 |
+
TRAIN_FILE=${TRAIN_FILE:-"${WORKING_DIR}/downloads/data/dapo-math-17k.parquet"}
|
| 48 |
+
TEST_FILE=${TEST_FILE:-"${WORKING_DIR}/downloads/data/aime-2024.parquet"}
|
| 49 |
+
|
| 50 |
+
# Sampling
|
| 51 |
+
temperature=1.0
|
| 52 |
+
top_p=1.0
|
| 53 |
+
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
| 54 |
+
val_top_p=0.7
|
| 55 |
+
|
| 56 |
+
# Performance Related Parameter
|
| 57 |
+
n_gpus=8
|
| 58 |
+
sp_size=1
|
| 59 |
+
use_dynamic_bsz=True
|
| 60 |
+
actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
|
| 61 |
+
infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
|
| 62 |
+
offload=False
|
| 63 |
+
gen_tp=1
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
python3 -m recipe.flowrl.main_flowrl \
|
| 67 |
+
data.train_files="${TRAIN_FILE}" \
|
| 68 |
+
data.val_files="${TEST_FILE}" \
|
| 69 |
+
data.prompt_key=prompt \
|
| 70 |
+
data.truncation='left' \
|
| 71 |
+
data.max_prompt_length=${max_prompt_length} \
|
| 72 |
+
data.max_response_length=${max_response_length} \
|
| 73 |
+
data.gen_batch_size=${gen_prompt_bsz} \
|
| 74 |
+
data.train_batch_size=${train_prompt_bsz} \
|
| 75 |
+
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
| 76 |
+
algorithm.adv_estimator=${adv_estimator} \
|
| 77 |
+
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
| 78 |
+
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
| 79 |
+
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
| 80 |
+
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
| 81 |
+
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
| 82 |
+
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
| 83 |
+
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
| 84 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 85 |
+
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
| 86 |
+
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
| 87 |
+
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
| 88 |
+
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
| 89 |
+
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
| 90 |
+
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
| 91 |
+
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
| 92 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 93 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 94 |
+
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
| 95 |
+
actor_rollout_ref.actor.optim.warmup_style='constant' \
|
| 96 |
+
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
| 97 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
| 98 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
|
| 99 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
|
| 100 |
+
actor_rollout_ref.actor.entropy_coeff=0 \
|
| 101 |
+
actor_rollout_ref.actor.grad_clip=1.0 \
|
| 102 |
+
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
| 103 |
+
actor_rollout_ref.rollout.calculate_log_probs=True \
|
| 104 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
| 105 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
| 106 |
+
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
| 107 |
+
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
| 108 |
+
actor_rollout_ref.rollout.temperature=${temperature} \
|
| 109 |
+
actor_rollout_ref.rollout.top_p=${top_p} \
|
| 110 |
+
actor_rollout_ref.rollout.top_k="${top_k}" \
|
| 111 |
+
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
| 112 |
+
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
| 113 |
+
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
| 114 |
+
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
| 115 |
+
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
| 116 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 117 |
+
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
|
| 118 |
+
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
| 119 |
+
actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
|
| 120 |
+
reward_model.reward_manager=dapo \
|
| 121 |
+
reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
|
| 122 |
+
reward_model.overlong_buffer.len=${overlong_buffer_len} \
|
| 123 |
+
reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
|
| 124 |
+
trainer.logger='["console","wandb"]' \
|
| 125 |
+
trainer.project_name="${project_name}" \
|
| 126 |
+
trainer.experiment_name="${exp_name}" \
|
| 127 |
+
trainer.n_gpus_per_node=${n_gpus} \
|
| 128 |
+
trainer.nnodes="${NNODES}" \
|
| 129 |
+
trainer.val_before_train=True \
|
| 130 |
+
trainer.test_freq=10 \
|
| 131 |
+
trainer.save_freq=${save_freq} \
|
| 132 |
+
trainer.total_epochs=1 \
|
| 133 |
+
trainer.default_local_dir="${CKPTS_DIR}" \
|
| 134 |
+
trainer.resume_mode=auto
|
ICL/DAPO/verl-recipe/infigui-g1/README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Recipe for InfiGUI-G1
|
| 2 |
+
|
| 3 |
+
This directory contains the official implementation for the paper [InfiGUI-G1: Advancing GUI Grounding with Adaptive Exploration Policy Optimization](https://arxiv.org/abs/2508.05731).
|
| 4 |
+
|
| 5 |
+
This work introduces Adaptive Exploration Policy Optimization (AEPO), a policy optimization framework designed to enhance GUI grounding in Multimodal Large Language Models (MLLMs). AEPO improves exploration efficiency by employing a multi-answer generation strategy and a theoretically grounded Adaptive Exploration Reward (AER) function. This approach effectively addresses the challenge of semantic alignment in complex GUI grounding tasks.
|
| 6 |
+
|
| 7 |
+
We provide training scripts for both 3B and 7B models, configured for a single machine with 8 GPUs by default.
|
| 8 |
+
|
| 9 |
+
## Environment Setup
|
| 10 |
+
|
| 11 |
+
Please follow the main environment setup guide for `verl`.
|
| 12 |
+
|
| 13 |
+
The provided scripts use the following Docker image: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`
|
| 14 |
+
|
| 15 |
+
## Data Preparation
|
| 16 |
+
|
| 17 |
+
Before starting the training, you need to download the example dataset. This dataset is a filtered version of [omniact](https://huggingface.co/datasets/Writer/omniact), containing only grounding tasks and excluding easy samples.
|
| 18 |
+
|
| 19 |
+
The data is hosted on the Hugging Face. You can download it using the `huggingface-cli`:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
huggingface-cli download --repo-type dataset --resume-download InfiX-ai/omniact_grounding_filtered --local-dir data/omniact_grounding_filtered
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
This command will download the training and validation parquet files into the `data/omniact_grounding_filtered` directory, which is the default path used by the scripts.
|
| 26 |
+
|
| 27 |
+
## Training
|
| 28 |
+
|
| 29 |
+
We provide scripts to train the 3B and 7B models. Please run them from the root directory of `verl`.
|
| 30 |
+
|
| 31 |
+
- **Train the 3B model:**
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
bash recipe/infigui-g1/run_3b.sh
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
- **Train the 7B model:**
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
bash recipe/infigui-g1/run_7b.sh
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
## Using Custom Data
|
| 44 |
+
|
| 45 |
+
If you wish to train on your own dataset, please format your data to match the structure of the example files located in `data/omniact_grounding_filtered`.
|
| 46 |
+
|
| 47 |
+
Once your data is ready, you need to update the data path arguments in the training script.
|
| 48 |
+
|
| 49 |
+
In `run_3b.sh` or `run_7b.sh`, modify the following lines:
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
data.train_files=./path/to/your/train_data.parquet \
|
| 53 |
+
data.val_files=./path/to/your/val_data.parquet \
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Replace the paths with the location of your custom data files.
|
ICL/DAPO/verl-recipe/langgraph_agent/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Ref: https://python.langchain.com/docs/how_to/custom_chat_model/
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import uuid
|
| 23 |
+
from typing import Any, Optional
|
| 24 |
+
|
| 25 |
+
from langchain_core.language_models import BaseChatModel
|
| 26 |
+
from langchain_core.language_models.base import LanguageModelInput
|
| 27 |
+
from langchain_core.messages import (
|
| 28 |
+
AIMessage,
|
| 29 |
+
BaseMessage,
|
| 30 |
+
convert_to_openai_messages,
|
| 31 |
+
)
|
| 32 |
+
from langchain_core.messages.tool import InvalidToolCall, ToolCall
|
| 33 |
+
from langchain_core.outputs import ChatGeneration, ChatResult
|
| 34 |
+
from langchain_core.runnables import Runnable, RunnableConfig
|
| 35 |
+
from langchain_core.tools import StructuredTool
|
| 36 |
+
from langchain_core.utils.function_calling import convert_to_openai_tool
|
| 37 |
+
from pydantic import Field
|
| 38 |
+
|
| 39 |
+
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
|
| 40 |
+
from verl.experimental.agent_loop.tool_parser import ToolParser
|
| 41 |
+
from verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__file__)
|
| 44 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class MaxTokenExceededError(Exception):
|
| 48 |
+
"""Indicate that history chat messages + tool message exceeds LLM max_tokens."""
|
| 49 |
+
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ChatModel(BaseChatModel):
|
| 54 |
+
model_name: str = Field(alias="model")
|
| 55 |
+
"""The name of the model"""
|
| 56 |
+
|
| 57 |
+
client: AsyncLLMServerManager
|
| 58 |
+
"""AsyncLLM server manager"""
|
| 59 |
+
|
| 60 |
+
tokenizer: Any
|
| 61 |
+
"""Tokenizer for the model"""
|
| 62 |
+
|
| 63 |
+
max_tokens: int
|
| 64 |
+
"""Max tokens to generate"""
|
| 65 |
+
|
| 66 |
+
tool_parser: str = "hermes"
|
| 67 |
+
"""Tool parser for the model"""
|
| 68 |
+
|
| 69 |
+
max_parallel_calls: int = 1
|
| 70 |
+
"""Max parallel tool calls"""
|
| 71 |
+
|
| 72 |
+
temperature: float = 1.0
|
| 73 |
+
"""Temperature for sampling"""
|
| 74 |
+
|
| 75 |
+
top_p: float = 1.0
|
| 76 |
+
"""Top p for sampling"""
|
| 77 |
+
|
| 78 |
+
repetition_penalty: float = 1.0
|
| 79 |
+
"""Repetition penalty for sampling"""
|
| 80 |
+
|
| 81 |
+
def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
|
| 82 |
+
"""Bind tools to the model.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
tools: Sequence of tools to bind to the model.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
A Runnable that returns a message.
|
| 89 |
+
"""
|
| 90 |
+
formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]
|
| 91 |
+
|
| 92 |
+
# used to remove system prompt prefix when encoding tool response
|
| 93 |
+
system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
|
| 94 |
+
kwargs["system_prompt"] = system_prompt
|
| 95 |
+
|
| 96 |
+
return self.bind(tools=formatted_tools, **kwargs)
|
| 97 |
+
|
| 98 |
+
def with_structured_output(
|
| 99 |
+
self,
|
| 100 |
+
schema: dict | type,
|
| 101 |
+
*,
|
| 102 |
+
include_raw: bool = False,
|
| 103 |
+
**kwargs: Any,
|
| 104 |
+
) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
|
| 105 |
+
"""Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/"""
|
| 106 |
+
raise NotImplementedError
|
| 107 |
+
|
| 108 |
+
def _generate(
|
| 109 |
+
self,
|
| 110 |
+
messages: list[BaseMessage],
|
| 111 |
+
stop: Optional[list[str]] = None,
|
| 112 |
+
**kwargs: Any,
|
| 113 |
+
) -> ChatResult:
|
| 114 |
+
raise NotImplementedError
|
| 115 |
+
|
| 116 |
+
async def _agenerate(
|
| 117 |
+
self,
|
| 118 |
+
messages: list[BaseMessage],
|
| 119 |
+
stop: Optional[list[str]] = None,
|
| 120 |
+
**kwargs: Any,
|
| 121 |
+
) -> ChatResult:
|
| 122 |
+
"""Asynchronously generate chat completion message.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
messages (list[BaseMessage]): List of list of messages.
|
| 126 |
+
stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the
|
| 127 |
+
first occurrence of any of these substrings. Defaults to None.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
ChatResult: Chat result.
|
| 131 |
+
"""
|
| 132 |
+
request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)
|
| 133 |
+
|
| 134 |
+
sampling_params = {
|
| 135 |
+
"temperature": self.temperature,
|
| 136 |
+
"top_p": self.top_p,
|
| 137 |
+
"repetition_penalty": self.repetition_penalty,
|
| 138 |
+
}
|
| 139 |
+
if "sampling_params" in kwargs:
|
| 140 |
+
sampling_params.update(kwargs["sampling_params"])
|
| 141 |
+
|
| 142 |
+
output = await self.client.generate(
|
| 143 |
+
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
message = await self._postprocess(request_id, prompt_ids, response_mask, output.token_ids, **kwargs)
|
| 147 |
+
generation = ChatGeneration(message=message)
|
| 148 |
+
return ChatResult(generations=[generation])
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def _llm_type(self) -> str:
|
| 152 |
+
"""Get the type of language model used by this chat model."""
|
| 153 |
+
return self.model_name
|
| 154 |
+
|
| 155 |
+
async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
|
| 156 |
+
"""Preprocess messages for chat completion.
|
| 157 |
+
|
| 158 |
+
To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out
|
| 159 |
+
instead of messages list.
|
| 160 |
+
|
| 161 |
+
But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory
|
| 162 |
+
(prompt_ids, response_mask) in lastest AIMessage.response_metadata.
|
| 163 |
+
|
| 164 |
+
1. Encode ToolMessage to token ids.
|
| 165 |
+
2. Retrieve trajectory (prompt_ids, response_mask) from lastest AIMessage.response_metadata.
|
| 166 |
+
3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.
|
| 167 |
+
|
| 168 |
+
Ref: https://python.langchain.com/docs/concepts/chat_history/
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
messages (list[BaseMessage]): List of messages.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.
|
| 175 |
+
"""
|
| 176 |
+
# messages: [system], human, ai, human|tool, ai, human|tool, ...
|
| 177 |
+
assert messages[-1].type in ["human", "tool"], (
|
| 178 |
+
f"Last message must be human or tool, but got {messages[-1].type}"
|
| 179 |
+
)
|
| 180 |
+
loop = asyncio.get_running_loop()
|
| 181 |
+
|
| 182 |
+
# Case 1: initial chat completion: [system], human
|
| 183 |
+
if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
|
| 184 |
+
prompt_ids = await loop.run_in_executor(
|
| 185 |
+
None,
|
| 186 |
+
lambda: self.tokenizer.apply_chat_template(
|
| 187 |
+
convert_to_openai_messages(messages),
|
| 188 |
+
tools=kwargs.get("tools"),
|
| 189 |
+
add_generation_prompt=True,
|
| 190 |
+
tokenize=True,
|
| 191 |
+
),
|
| 192 |
+
)
|
| 193 |
+
return str(uuid.uuid4()), prompt_ids, []
|
| 194 |
+
|
| 195 |
+
# Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...
|
| 196 |
+
for i in range(len(messages) - 1, -1, -1):
|
| 197 |
+
if messages[i].type == "ai":
|
| 198 |
+
break
|
| 199 |
+
assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
|
| 200 |
+
assert "response_mask" in messages[i].response_metadata, (
|
| 201 |
+
"Last message must have response_mask in response_metadata"
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# encode tool response
|
| 205 |
+
tool_responses = convert_to_openai_messages(messages[i + 1 :])
|
| 206 |
+
if self.tool_parser == "hermes":
|
| 207 |
+
tool_response_ids = await loop.run_in_executor(
|
| 208 |
+
None,
|
| 209 |
+
lambda messages=tool_responses: self.tokenizer.apply_chat_template(
|
| 210 |
+
messages, add_generation_prompt=True, tokenize=True
|
| 211 |
+
),
|
| 212 |
+
)
|
| 213 |
+
tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
|
| 214 |
+
elif self.tool_parser == "gpt-oss":
|
| 215 |
+
# Format tool responses manually
|
| 216 |
+
# since gpt-oss chat template requires tool call messages to parse tool response messages
|
| 217 |
+
# we need to format the tool response messages manually
|
| 218 |
+
tool_response_texts = []
|
| 219 |
+
for tool_msg in tool_responses:
|
| 220 |
+
if tool_msg["role"] == "tool":
|
| 221 |
+
# Use tool message's name if available (for multiple tool calls)
|
| 222 |
+
actual_tool_name = tool_msg.get("name", "unknown")
|
| 223 |
+
if actual_tool_name == "unknown":
|
| 224 |
+
logger.error(f"actual_tool_name: {actual_tool_name}")
|
| 225 |
+
formatted = format_gpt_oss_tool_response_manually(tool_msg["content"], actual_tool_name)
|
| 226 |
+
tool_response_texts.append(formatted)
|
| 227 |
+
|
| 228 |
+
# Tokenize the manually formatted tool responses
|
| 229 |
+
tool_response_text = "".join(tool_response_texts)
|
| 230 |
+
# need to add generation tokens for gpt-oss manually since add_generation_prompt is True
|
| 231 |
+
tool_response_text = add_generation_prompt_for_gpt_oss(tool_response_text)
|
| 232 |
+
logger.debug(f"tool_response_text: {tool_response_text}")
|
| 233 |
+
|
| 234 |
+
tool_response_ids = await loop.run_in_executor(
|
| 235 |
+
None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"Unsupported tool parser: {self.tool_parser}")
|
| 239 |
+
|
| 240 |
+
# stop generation if response length exceeds max response length
|
| 241 |
+
if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
|
| 242 |
+
raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")
|
| 243 |
+
|
| 244 |
+
# append tool response to prompt
|
| 245 |
+
request_id = messages[i].response_metadata.pop("request_id")
|
| 246 |
+
prompt_ids = messages[i].response_metadata.pop("prompt_ids")
|
| 247 |
+
response_mask = messages[i].response_metadata.pop("response_mask")
|
| 248 |
+
prompt_ids += tool_response_ids
|
| 249 |
+
response_mask += [0] * len(tool_response_ids)
|
| 250 |
+
|
| 251 |
+
return request_id, prompt_ids, response_mask
|
| 252 |
+
|
| 253 |
+
async def _postprocess(
|
| 254 |
+
self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
|
| 255 |
+
) -> AIMessage:
|
| 256 |
+
"""Postprocess response_ids when chat completion is done.
|
| 257 |
+
|
| 258 |
+
1. Decode response_ids, parse tool calls to AIMessage.
|
| 259 |
+
2. Append response_ids to prompt_ids, and append 1 to response_mask.
|
| 260 |
+
3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
request_id (str): Unique request id.
|
| 264 |
+
prompt_ids (list[int]): Input prompt token ids in this chat completion.
|
| 265 |
+
response_mask (list[int]): Response mask before this chat completion.
|
| 266 |
+
response_ids (list[int]): LLM generated token ids in this chat completion.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
AIMessage: Postprocessed message.
|
| 270 |
+
"""
|
| 271 |
+
prompt_ids += response_ids
|
| 272 |
+
response_mask += [1] * len(response_ids)
|
| 273 |
+
|
| 274 |
+
tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
|
| 275 |
+
content, function_calls = await tool_parser.extract_tool_calls(response_ids)
|
| 276 |
+
|
| 277 |
+
tool_calls, invalid_tool_calls = [], []
|
| 278 |
+
|
| 279 |
+
for function_call in function_calls:
|
| 280 |
+
error = None
|
| 281 |
+
try:
|
| 282 |
+
args = json.loads(function_call.arguments)
|
| 283 |
+
if not isinstance(args, dict):
|
| 284 |
+
error = f"Tool arguments must be a JSON object, got {type(args).__name__}"
|
| 285 |
+
except json.JSONDecodeError as e:
|
| 286 |
+
error = f"Invalid JSON tool arguments: {e}"
|
| 287 |
+
|
| 288 |
+
if error:
|
| 289 |
+
logger.warning(error)
|
| 290 |
+
invalid_tool_calls.append(
|
| 291 |
+
InvalidToolCall(
|
| 292 |
+
name=function_call.name,
|
| 293 |
+
args=function_call.arguments,
|
| 294 |
+
id=str(uuid.uuid4()),
|
| 295 |
+
error=error,
|
| 296 |
+
)
|
| 297 |
+
)
|
| 298 |
+
else:
|
| 299 |
+
tool_calls.append(
|
| 300 |
+
ToolCall(
|
| 301 |
+
name=function_call.name,
|
| 302 |
+
args=args,
|
| 303 |
+
id=str(uuid.uuid4()),
|
| 304 |
+
)
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
message = AIMessage(
|
| 308 |
+
content=content,
|
| 309 |
+
tool_calls=tool_calls[: self.max_parallel_calls],
|
| 310 |
+
invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
|
| 311 |
+
response_metadata={
|
| 312 |
+
"request_id": request_id,
|
| 313 |
+
"prompt_ids": prompt_ids,
|
| 314 |
+
"response_mask": response_mask,
|
| 315 |
+
},
|
| 316 |
+
)
|
| 317 |
+
return message
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class TruncateStructuredTool(StructuredTool):
|
| 321 |
+
"""Structured tool with response truncation."""
|
| 322 |
+
|
| 323 |
+
tool_response_truncate_side: str
|
| 324 |
+
"""truncate side of tool response: left, middle, right"""
|
| 325 |
+
|
| 326 |
+
max_tool_response_length: int
|
| 327 |
+
"""max length of tool response"""
|
| 328 |
+
|
| 329 |
+
async def _arun(
|
| 330 |
+
self,
|
| 331 |
+
*args: Any,
|
| 332 |
+
config: RunnableConfig,
|
| 333 |
+
**kwargs: Any,
|
| 334 |
+
) -> Any:
|
| 335 |
+
tool_response = await super()._arun(*args, config=config, **kwargs)
|
| 336 |
+
tool_response = str(tool_response)
|
| 337 |
+
|
| 338 |
+
if len(tool_response) > self.max_tool_response_length:
|
| 339 |
+
if self.tool_response_truncate_side == "left":
|
| 340 |
+
tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
|
| 341 |
+
elif self.tool_response_truncate_side == "right":
|
| 342 |
+
tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
|
| 343 |
+
else:
|
| 344 |
+
length = self.max_tool_response_length // 2
|
| 345 |
+
tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]
|
| 346 |
+
|
| 347 |
+
return tool_response
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
|
| 351 |
+
"""Convert messages to AgentLoopOutput.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
messages (List[BaseMessage]): List of messages, last message must be assistant
|
| 355 |
+
with response_metadata containing `prompt_ids` and `response_mask`.
|
| 356 |
+
response_length (int): Max length of response.
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
AgentLoopOutput: agent loop output trajectory used for training.
|
| 360 |
+
"""
|
| 361 |
+
# skip last tool calls
|
| 362 |
+
for i in range(len(messages) - 1, -1, -1):
|
| 363 |
+
if messages[i].type != "tool":
|
| 364 |
+
break
|
| 365 |
+
last_message = messages[i]
|
| 366 |
+
assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
|
| 367 |
+
assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
|
| 368 |
+
assert "response_mask" in last_message.response_metadata, (
|
| 369 |
+
"Last message must have response_mask in response_metadata"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
num_turns = 0
|
| 373 |
+
for i in range(len(messages)):
|
| 374 |
+
if messages[i].type == "system":
|
| 375 |
+
continue
|
| 376 |
+
# parallel tool calls are in single turn
|
| 377 |
+
if i == 0 or messages[i].type != messages[i - 1].type:
|
| 378 |
+
num_turns += 1
|
| 379 |
+
|
| 380 |
+
prompt_ids = last_message.response_metadata["prompt_ids"]
|
| 381 |
+
response_mask = last_message.response_metadata["response_mask"]
|
| 382 |
+
|
| 383 |
+
response_ids = prompt_ids[-len(response_mask) :]
|
| 384 |
+
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
|
| 385 |
+
|
| 386 |
+
output = AgentLoopOutput(
|
| 387 |
+
prompt_ids=prompt_ids,
|
| 388 |
+
response_ids=response_ids[:response_length],
|
| 389 |
+
response_mask=response_mask[:response_length],
|
| 390 |
+
num_turns=num_turns,
|
| 391 |
+
metrics={},
|
| 392 |
+
)
|
| 393 |
+
return output
|
ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
LangGraph React Agent Loop.
|
| 16 |
+
|
| 17 |
+
This implementation is exact same as `ToolAgentLoop`.
|
| 18 |
+
|
| 19 |
+
Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
from typing import Any, Literal
|
| 24 |
+
|
| 25 |
+
from langchain_core.messages import AIMessage
|
| 26 |
+
from langchain_core.runnables import RunnableConfig
|
| 27 |
+
from langgraph.graph import END, MessagesState, StateGraph
|
| 28 |
+
from langgraph.prebuilt import ToolNode
|
| 29 |
+
from recipe.langgraph_agent.chat_model import (
|
| 30 |
+
ChatModel,
|
| 31 |
+
MaxTokenExceededError,
|
| 32 |
+
convert_to_agent_output,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
async def call_model(state: MessagesState, config: RunnableConfig):
|
| 41 |
+
model = config["configurable"]["model"]
|
| 42 |
+
sampling_params = config["configurable"]["sampling_params"]
|
| 43 |
+
try:
|
| 44 |
+
message = await model.ainvoke(state["messages"], sampling_params=sampling_params)
|
| 45 |
+
return {"messages": [message]}
|
| 46 |
+
except MaxTokenExceededError:
|
| 47 |
+
# last message is ToolMessage
|
| 48 |
+
return {"messages": []}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]:
|
| 52 |
+
# Safely extract max_assistant_turns from config
|
| 53 |
+
max_assistant_turns = None
|
| 54 |
+
try:
|
| 55 |
+
if config and "configurable" in config:
|
| 56 |
+
max_assistant_turns = config["configurable"].get("max_assistant_turns")
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.warning(f"Failed to extract max_assistant_turns from config: {e}")
|
| 59 |
+
|
| 60 |
+
num_assistant_turns = 0
|
| 61 |
+
for message in state["messages"]:
|
| 62 |
+
if message.type == "ai":
|
| 63 |
+
num_assistant_turns += 1
|
| 64 |
+
|
| 65 |
+
last_message = state["messages"][-1]
|
| 66 |
+
|
| 67 |
+
# LLM call failed, e.g: max response length exceeded
|
| 68 |
+
if last_message.type == "tool":
|
| 69 |
+
return END
|
| 70 |
+
|
| 71 |
+
# max assistant turns exceeded
|
| 72 |
+
# Use a reasonable default limit (25) if max_assistant_turns is not set
|
| 73 |
+
# This prevents infinite loops
|
| 74 |
+
effective_max_turns = max_assistant_turns if max_assistant_turns is not None else 25
|
| 75 |
+
if num_assistant_turns >= effective_max_turns:
|
| 76 |
+
return END
|
| 77 |
+
|
| 78 |
+
# no tool calls
|
| 79 |
+
if not getattr(last_message, "tool_calls", None):
|
| 80 |
+
return END
|
| 81 |
+
|
| 82 |
+
return "tools"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ReactAgentLoop(AgentLoopBase):
|
| 86 |
+
# Recursion limit calculation constants
|
| 87 |
+
DEFAULT_MAX_ASSISTANT_TURNS = 25
|
| 88 |
+
MIN_RECURSION_LIMIT = 50
|
| 89 |
+
NODES_PER_TURN = 2 # Each AI turn involves agent + tools nodes
|
| 90 |
+
RECURSION_LIMIT_SAFETY_FACTOR = 1.5 # 50% buffer for edge cases
|
| 91 |
+
|
| 92 |
+
@classmethod
|
| 93 |
+
def init_class(cls, config, tokenizer, **kwargs):
|
| 94 |
+
if cls._class_initialized:
|
| 95 |
+
return
|
| 96 |
+
cls._class_initialized = True
|
| 97 |
+
print("Performing class-level ReactAgentLoop initialization")
|
| 98 |
+
|
| 99 |
+
# build graph
|
| 100 |
+
cls.graph = cls.build_graph()
|
| 101 |
+
|
| 102 |
+
@classmethod
|
| 103 |
+
def build_graph(cls) -> StateGraph:
|
| 104 |
+
workflow = StateGraph(MessagesState)
|
| 105 |
+
|
| 106 |
+
workflow.add_node("agent", call_model)
|
| 107 |
+
workflow.add_node("tools", ToolNode(cls.tools))
|
| 108 |
+
workflow.set_entry_point("agent")
|
| 109 |
+
workflow.add_conditional_edges(
|
| 110 |
+
"agent",
|
| 111 |
+
should_continue,
|
| 112 |
+
{
|
| 113 |
+
"tools": "tools",
|
| 114 |
+
END: END,
|
| 115 |
+
},
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
workflow.add_edge("tools", "agent")
|
| 119 |
+
graph = workflow.compile()
|
| 120 |
+
return graph
|
| 121 |
+
|
| 122 |
+
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
|
| 123 |
+
messages = list(kwargs["raw_prompt"])
|
| 124 |
+
|
| 125 |
+
model_path = self.config.actor_rollout_ref.model.path
|
| 126 |
+
model_name = "/".join(model_path.split("/")[-2:])
|
| 127 |
+
|
| 128 |
+
rollout = self.config.actor_rollout_ref.rollout
|
| 129 |
+
model = ChatModel(
|
| 130 |
+
model=model_name,
|
| 131 |
+
client=self.server_manager,
|
| 132 |
+
tokenizer=self.tokenizer,
|
| 133 |
+
max_tokens=rollout.response_length,
|
| 134 |
+
max_parallel_calls=rollout.multi_turn.max_parallel_calls,
|
| 135 |
+
tool_parser=rollout.multi_turn.format,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
model = model.bind_tools(self.tools, tool_choice="any")
|
| 139 |
+
|
| 140 |
+
# Calculate recursion_limit dynamically based on max_assistant_turns
|
| 141 |
+
max_assistant_turns = (
|
| 142 |
+
rollout.multi_turn.max_assistant_turns
|
| 143 |
+
if rollout.multi_turn.max_assistant_turns
|
| 144 |
+
else self.DEFAULT_MAX_ASSISTANT_TURNS
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Formula: nodes_per_turn * max_turns * safety_buffer, with minimum threshold
|
| 148 |
+
recursion_limit = max(
|
| 149 |
+
self.MIN_RECURSION_LIMIT,
|
| 150 |
+
int(max_assistant_turns * self.NODES_PER_TURN * self.RECURSION_LIMIT_SAFETY_FACTOR),
|
| 151 |
+
)
|
| 152 |
+
logger.info(f"Configured recursion_limit={recursion_limit} (max_assistant_turns={max_assistant_turns})")
|
| 153 |
+
|
| 154 |
+
config = {
|
| 155 |
+
"configurable": {
|
| 156 |
+
"model": model,
|
| 157 |
+
"sampling_params": sampling_params,
|
| 158 |
+
"max_user_turns": rollout.multi_turn.max_user_turns,
|
| 159 |
+
"max_assistant_turns": rollout.multi_turn.max_assistant_turns,
|
| 160 |
+
},
|
| 161 |
+
"recursion_limit": recursion_limit,
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# TODO: how to handle multiple trajectories in an graph invocation?
|
| 165 |
+
# Each graph node may has its own LLM calls and state, e.g:
|
| 166 |
+
# https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart
|
| 167 |
+
try:
|
| 168 |
+
state = await self.graph.ainvoke(input={"messages": messages}, config=config)
|
| 169 |
+
except Exception as e:
|
| 170 |
+
logger.error(f"Agent loop execution failed: {type(e).__name__}: {e}")
|
| 171 |
+
logger.error("Falling back to a minimal dummy trajectory.")
|
| 172 |
+
|
| 173 |
+
# Fallback to a minimal assistant message so that
|
| 174 |
+
# convert_to_agent_output and downstream padding logic
|
| 175 |
+
# can still run without crashing.
|
| 176 |
+
dummy_id = 0
|
| 177 |
+
fallback_message = AIMessage(
|
| 178 |
+
content="[Agent execution failed - no valid trajectory]",
|
| 179 |
+
response_metadata={
|
| 180 |
+
"request_id": "fallback",
|
| 181 |
+
"prompt_ids": [dummy_id, dummy_id],
|
| 182 |
+
"response_mask": [1],
|
| 183 |
+
},
|
| 184 |
+
)
|
| 185 |
+
state = {"messages": [fallback_message]}
|
| 186 |
+
|
| 187 |
+
output = convert_to_agent_output(state["messages"], rollout.response_length)
|
| 188 |
+
return output
|
ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import pytest
|
| 19 |
+
import ray
|
| 20 |
+
from langchain_core.tools import tool
|
| 21 |
+
from omegaconf import DictConfig
|
| 22 |
+
from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
|
| 23 |
+
from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
|
| 24 |
+
|
| 25 |
+
from verl.protocol import DataProto
|
| 26 |
+
from verl.utils import hf_tokenizer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
|
| 30 |
+
def init_config() -> DictConfig:
|
| 31 |
+
from hydra import compose, initialize_config_dir
|
| 32 |
+
|
| 33 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 34 |
+
config = compose(config_name="ppo_trainer")
|
| 35 |
+
model_path = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 36 |
+
config.actor_rollout_ref.model.path = model_path
|
| 37 |
+
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
|
| 38 |
+
config.actor_rollout_ref.rollout.mode = "async"
|
| 39 |
+
config.actor_rollout_ref.rollout.prompt_length = 4096
|
| 40 |
+
config.actor_rollout_ref.rollout.response_length = 4096
|
| 41 |
+
config.actor_rollout_ref.rollout.n = 4
|
| 42 |
+
config.actor_rollout_ref.rollout.agent.num_workers = 2
|
| 43 |
+
|
| 44 |
+
config.actor_rollout_ref.actor.use_dynamic_bsz = True
|
| 45 |
+
# test sleep/wake_up with fsdp offload
|
| 46 |
+
config.actor_rollout_ref.actor.fsdp_config.param_offload = True
|
| 47 |
+
config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
|
| 48 |
+
|
| 49 |
+
return config
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@tool(parse_docstring=True)
|
| 53 |
+
def get_current_temperature(location: str, unit: str = "celsius"):
|
| 54 |
+
"""Get current temperature at a location.
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
location: The location to get the temperature for, in the format "City, State, Country".
|
| 58 |
+
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
the temperature, the location, and the unit in a dict
|
| 62 |
+
"""
|
| 63 |
+
print(f"[DEBUG] get_current_temperature: {location}, {unit}")
|
| 64 |
+
return {
|
| 65 |
+
"temperature": 26.1,
|
| 66 |
+
"location": location,
|
| 67 |
+
"unit": unit,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@tool(parse_docstring=True)
|
| 72 |
+
def get_temperature_date(location: str, date: str, unit: str = "celsius"):
|
| 73 |
+
"""Get temperature at a location and date.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
location: The location to get the temperature for, in the format "City, State, Country".
|
| 77 |
+
date: The date to get the temperature for, in the format "Year-Month-Day".
|
| 78 |
+
unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
the temperature, the location, the date and the unit in a dict
|
| 82 |
+
"""
|
| 83 |
+
print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
|
| 84 |
+
return {
|
| 85 |
+
"temperature": 25.9,
|
| 86 |
+
"location": location,
|
| 87 |
+
"date": date,
|
| 88 |
+
"unit": unit,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TestReactAgentLoop(ReactAgentLoop):
|
| 93 |
+
@classmethod
|
| 94 |
+
def init_class(cls, config, tokenizer, **kwargs):
|
| 95 |
+
# TODO: find better way to configure tools
|
| 96 |
+
cls.tools = [get_current_temperature, get_temperature_date]
|
| 97 |
+
super().init_class(config, tokenizer, **kwargs)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def test_react_agent(init_config):
|
| 101 |
+
ray.init(
|
| 102 |
+
runtime_env={
|
| 103 |
+
"env_vars": {
|
| 104 |
+
"TOKENIZERS_PARALLELISM": "true",
|
| 105 |
+
"NCCL_DEBUG": "WARN",
|
| 106 |
+
"VLLM_LOGGING_LEVEL": "INFO",
|
| 107 |
+
"VLLM_USE_V1": "1",
|
| 108 |
+
}
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# =========================== 1. Init rollout manager ===========================
|
| 113 |
+
agent_loop_config = [
|
| 114 |
+
{
|
| 115 |
+
"_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop",
|
| 116 |
+
"name": "react_agent",
|
| 117 |
+
},
|
| 118 |
+
]
|
| 119 |
+
agent_loop_config_path = "/tmp/agent_loop_config.json"
|
| 120 |
+
with open(agent_loop_config_path, "w") as f:
|
| 121 |
+
json.dump(agent_loop_config, f)
|
| 122 |
+
|
| 123 |
+
n = 2
|
| 124 |
+
init_config.actor_rollout_ref.rollout.n = n
|
| 125 |
+
# init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
|
| 126 |
+
init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
|
| 127 |
+
init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path
|
| 128 |
+
agent_loop_manager = init_agent_loop_manager(init_config)
|
| 129 |
+
|
| 130 |
+
# =========================== 2. Generate sequences ===========================
|
| 131 |
+
raw_prompts = [
|
| 132 |
+
[
|
| 133 |
+
{"role": "user", "content": "How are you?"},
|
| 134 |
+
],
|
| 135 |
+
[
|
| 136 |
+
{"role": "user", "content": "What's the temperature in Los Angeles now?"},
|
| 137 |
+
],
|
| 138 |
+
[
|
| 139 |
+
{"role": "user", "content": "What's the temperature in New York now?"},
|
| 140 |
+
],
|
| 141 |
+
[
|
| 142 |
+
{
|
| 143 |
+
"role": "system",
|
| 144 |
+
"content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
|
| 145 |
+
"Current Date: 2024-09-30",
|
| 146 |
+
},
|
| 147 |
+
{"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
|
| 148 |
+
],
|
| 149 |
+
]
|
| 150 |
+
batch = DataProto(
|
| 151 |
+
non_tensor_batch={
|
| 152 |
+
"raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
|
| 153 |
+
"agent_name": np.array(["react_agent"] * len(raw_prompts)),
|
| 154 |
+
"data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
|
| 155 |
+
"reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
|
| 156 |
+
},
|
| 157 |
+
)
|
| 158 |
+
batch = batch.repeat(n)
|
| 159 |
+
result = agent_loop_manager.generate_sequences(prompts=batch)
|
| 160 |
+
assert len(result) == len(raw_prompts) * n
|
| 161 |
+
|
| 162 |
+
# Check turns
|
| 163 |
+
num_turns = result.non_tensor_batch["__num_turns__"]
|
| 164 |
+
print(f"num_turns: {num_turns}")
|
| 165 |
+
for i in range(len(num_turns)):
|
| 166 |
+
if i // n == 0:
|
| 167 |
+
# [user, assistant]
|
| 168 |
+
assert num_turns[i] == 2
|
| 169 |
+
else:
|
| 170 |
+
# [user, assistant, tool, assistant]
|
| 171 |
+
assert num_turns[i] == 4
|
| 172 |
+
|
| 173 |
+
# Check response_mask
|
| 174 |
+
tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
|
| 175 |
+
responses = result.batch["responses"]
|
| 176 |
+
response_mask = result.batch["response_mask"]
|
| 177 |
+
attention_mask = result.batch["attention_mask"]
|
| 178 |
+
assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
|
| 179 |
+
response_length = response_mask.size(1)
|
| 180 |
+
|
| 181 |
+
for i in range(len(responses)):
|
| 182 |
+
# response with tool response
|
| 183 |
+
valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
|
| 184 |
+
response_with_obs = tokenizer.decode(valid_tokens)
|
| 185 |
+
|
| 186 |
+
# response without tool response
|
| 187 |
+
valid_tokens = responses[i][response_mask[i].bool()]
|
| 188 |
+
response_without_obs = tokenizer.decode(valid_tokens)
|
| 189 |
+
|
| 190 |
+
assert "<tool_response>" not in response_without_obs, (
|
| 191 |
+
f"found <tool_response> in response: {response_without_obs}"
|
| 192 |
+
)
|
| 193 |
+
assert "</tool_response>" not in response_without_obs, (
|
| 194 |
+
f"found </tool_response> in response: {response_without_obs}"
|
| 195 |
+
)
|
| 196 |
+
print("=========================")
|
| 197 |
+
print(response_with_obs)
|
| 198 |
+
print("---")
|
| 199 |
+
print(response_without_obs)
|
| 200 |
+
|
| 201 |
+
print("Test passed!")
|
| 202 |
+
ray.shutdown()
|
ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2023-2024 SGLang Team
|
| 3 |
+
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import re
|
| 22 |
+
from typing import Optional
|
| 23 |
+
|
| 24 |
+
import datasets
|
| 25 |
+
import torch
|
| 26 |
+
from omegaconf import DictConfig, ListConfig
|
| 27 |
+
from PIL import Image
|
| 28 |
+
from torch.utils.data import Dataset
|
| 29 |
+
from torchvision import transforms
|
| 30 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
| 31 |
+
|
| 32 |
+
import verl.utils.torch_functional as verl_F
|
| 33 |
+
from verl.utils.dataset.vision_utils import process_image
|
| 34 |
+
from verl.utils.model import compute_position_id_with_mask
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_transform():
|
| 40 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
|
| 41 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
|
| 42 |
+
return transforms.Compose(
|
| 43 |
+
[
|
| 44 |
+
transforms.ToTensor(),
|
| 45 |
+
transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
| 46 |
+
]
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None):
|
| 51 |
+
if new_schema:
|
| 52 |
+
start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id)
|
| 53 |
+
end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id)
|
| 54 |
+
else:
|
| 55 |
+
start_cond = input_ids == tokenizer.im_start_id
|
| 56 |
+
end_cond = input_ids == tokenizer.im_end_id
|
| 57 |
+
image_start_tokens = torch.where(start_cond)[0]
|
| 58 |
+
image_start_tokens += 1
|
| 59 |
+
image_end_tokens = torch.where(end_cond)[0]
|
| 60 |
+
if len(image_start_tokens) != len(image_end_tokens):
|
| 61 |
+
logger.error("image start token != image end tokens")
|
| 62 |
+
raise Exception("image start token != image end tokens")
|
| 63 |
+
if len(image_start_tokens) > 0:
|
| 64 |
+
image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
|
| 65 |
+
else:
|
| 66 |
+
image_bound = []
|
| 67 |
+
return image_bound
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def preprocess(
|
| 71 |
+
images_dict,
|
| 72 |
+
conversations,
|
| 73 |
+
tokenizer,
|
| 74 |
+
transform,
|
| 75 |
+
query_nums=64,
|
| 76 |
+
slice_config=None,
|
| 77 |
+
llm_type=None,
|
| 78 |
+
patch_size=14,
|
| 79 |
+
batch_vision=False,
|
| 80 |
+
max_length=2048,
|
| 81 |
+
truncation="error",
|
| 82 |
+
apply_chat_template_kwargs=None,
|
| 83 |
+
logger=None,
|
| 84 |
+
):
|
| 85 |
+
"""
|
| 86 |
+
single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
|
| 87 |
+
"""
|
| 88 |
+
conversations = copy.deepcopy(conversations)
|
| 89 |
+
assert conversations[0]["role"] == "user", "the first role must be user"
|
| 90 |
+
|
| 91 |
+
if slice_config is not None:
|
| 92 |
+
assert isinstance(slice_config, dict)
|
| 93 |
+
assert "patch_size" in slice_config
|
| 94 |
+
assert "max_slice_nums" in slice_config
|
| 95 |
+
assert "scale_resolution" in slice_config
|
| 96 |
+
default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
|
| 97 |
+
new_schema = False
|
| 98 |
+
use_image_id = False
|
| 99 |
+
if llm_type == "qwen":
|
| 100 |
+
new_schema = True
|
| 101 |
+
use_image_id = True
|
| 102 |
+
image_placeholder_dict = {}
|
| 103 |
+
images = []
|
| 104 |
+
image_id_cnt = 0
|
| 105 |
+
for img_name, image in images_dict.items():
|
| 106 |
+
if slice_config:
|
| 107 |
+
source_image, patches, best_grid = slice_image(
|
| 108 |
+
image,
|
| 109 |
+
slice_config["max_slice_nums"],
|
| 110 |
+
slice_config["scale_resolution"],
|
| 111 |
+
slice_config["patch_size"],
|
| 112 |
+
)
|
| 113 |
+
images.append(source_image)
|
| 114 |
+
image_placeholder = default_image_placeholder
|
| 115 |
+
if len(patches) > 0:
|
| 116 |
+
for i in range(len(patches)):
|
| 117 |
+
for j in range(len(patches[0])):
|
| 118 |
+
images.append(patches[i][j])
|
| 119 |
+
if use_image_id:
|
| 120 |
+
image_placeholder = (
|
| 121 |
+
f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
|
| 122 |
+
)
|
| 123 |
+
image_id_cnt += 1
|
| 124 |
+
image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema)
|
| 125 |
+
image_placeholder_dict[img_name] = image_placeholder
|
| 126 |
+
else:
|
| 127 |
+
images.append(image)
|
| 128 |
+
if use_image_id:
|
| 129 |
+
image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
|
| 130 |
+
image_id_cnt += 1
|
| 131 |
+
else:
|
| 132 |
+
image_placeholder = default_image_placeholder
|
| 133 |
+
image_placeholder_dict[img_name] = image_placeholder
|
| 134 |
+
|
| 135 |
+
images = [transform(i) for i in images]
|
| 136 |
+
|
| 137 |
+
if len(images_dict) == 1 and "<image>" in images_dict:
|
| 138 |
+
if "<image>" in conversations[0]["content"]:
|
| 139 |
+
conversations[0]["content"] = conversations[0]["content"].replace("<image>", image_placeholder)
|
| 140 |
+
else:
|
| 141 |
+
conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"]
|
| 142 |
+
else:
|
| 143 |
+
pattern = r"<image_\d+>"
|
| 144 |
+
new_conversations = []
|
| 145 |
+
for conversation in conversations:
|
| 146 |
+
content = conversation["content"]
|
| 147 |
+
parts = re.split(f"({pattern})", content)
|
| 148 |
+
for i, part in enumerate(parts):
|
| 149 |
+
if not part.strip():
|
| 150 |
+
continue
|
| 151 |
+
if re.match(pattern, part):
|
| 152 |
+
if part in image_placeholder_dict:
|
| 153 |
+
parts[i] = image_placeholder_dict[part]
|
| 154 |
+
else:
|
| 155 |
+
raise Exception(f"not found {part} in image dict")
|
| 156 |
+
conversation["content"] = "\n".join(parts)
|
| 157 |
+
new_conversations.append(conversation)
|
| 158 |
+
conversations = new_conversations
|
| 159 |
+
|
| 160 |
+
# TODO change role in conversation for different llm
|
| 161 |
+
prompt_with_chat_template = tokenizer.apply_chat_template(
|
| 162 |
+
conversations, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {})
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
|
| 166 |
+
prompt=prompt_with_chat_template,
|
| 167 |
+
tokenizer=tokenizer,
|
| 168 |
+
max_length=max_length,
|
| 169 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 170 |
+
left_pad=True,
|
| 171 |
+
truncation=truncation,
|
| 172 |
+
)
|
| 173 |
+
position_ids = compute_position_id_with_mask(attention_mask)
|
| 174 |
+
image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger)
|
| 175 |
+
|
| 176 |
+
input_dict = {
|
| 177 |
+
"input_ids": input_ids[0],
|
| 178 |
+
"attention_mask": attention_mask[0],
|
| 179 |
+
"position_ids": position_ids[0],
|
| 180 |
+
"image_bound": image_bound,
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
if batch_vision:
|
| 184 |
+
tgt_sizes = []
|
| 185 |
+
reshape_images = []
|
| 186 |
+
for image in images:
|
| 187 |
+
H, W = image.shape[1:]
|
| 188 |
+
reshape_image = reshape_by_patch(image, patch_size)
|
| 189 |
+
reshape_images.append(reshape_image)
|
| 190 |
+
tgt_sizes.append([H // patch_size, W // patch_size])
|
| 191 |
+
if tgt_sizes:
|
| 192 |
+
tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
|
| 193 |
+
|
| 194 |
+
input_dict["pixel_values"] = reshape_images
|
| 195 |
+
input_dict["tgt_sizes"] = tgt_sizes
|
| 196 |
+
|
| 197 |
+
else:
|
| 198 |
+
input_dict["pixel_values"] = images
|
| 199 |
+
input_dict["tgt_sizes"] = []
|
| 200 |
+
|
| 201 |
+
return input_dict
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
|
| 205 |
+
original_size = image.size
|
| 206 |
+
original_width, original_height = original_size
|
| 207 |
+
log_ratio = math.log(original_width / original_height)
|
| 208 |
+
ratio = original_width * original_height / (scale_resolution * scale_resolution)
|
| 209 |
+
multiple = min(math.ceil(ratio), max_slice_nums)
|
| 210 |
+
|
| 211 |
+
source_image = None
|
| 212 |
+
best_grid = None
|
| 213 |
+
patches = []
|
| 214 |
+
|
| 215 |
+
if multiple <= 1 or never_split:
|
| 216 |
+
# dont need to slice, upsample
|
| 217 |
+
best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
|
| 218 |
+
source_image = image.resize(best_size, Image.Resampling.BICUBIC)
|
| 219 |
+
else:
|
| 220 |
+
candidate_split_grids_nums = []
|
| 221 |
+
for i in [multiple - 1, multiple, multiple + 1]:
|
| 222 |
+
if i == 1 or i > max_slice_nums:
|
| 223 |
+
continue
|
| 224 |
+
candidate_split_grids_nums.append(i)
|
| 225 |
+
|
| 226 |
+
# source image, down-sampling and ensure divided by patch_size
|
| 227 |
+
best_resize = find_best_resize(original_size, scale_resolution, patch_size)
|
| 228 |
+
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
|
| 229 |
+
candidate_grids = []
|
| 230 |
+
|
| 231 |
+
# find best grid
|
| 232 |
+
for split_grids_nums in candidate_split_grids_nums:
|
| 233 |
+
m = 1
|
| 234 |
+
while m <= split_grids_nums:
|
| 235 |
+
if split_grids_nums % m == 0:
|
| 236 |
+
candidate_grids.append([m, split_grids_nums // m])
|
| 237 |
+
m += 1
|
| 238 |
+
|
| 239 |
+
best_grid = [1, 1]
|
| 240 |
+
min_error = float("inf")
|
| 241 |
+
for grid in candidate_grids:
|
| 242 |
+
error = abs(log_ratio - math.log(grid[0] / grid[1]))
|
| 243 |
+
if error < min_error:
|
| 244 |
+
best_grid = grid
|
| 245 |
+
min_error = error
|
| 246 |
+
|
| 247 |
+
refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True)
|
| 248 |
+
|
| 249 |
+
refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
|
| 250 |
+
patches = split_to_patches(refine_image, best_grid)
|
| 251 |
+
|
| 252 |
+
return source_image, patches, best_grid
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def ensure_divide(length, patch_size):
|
| 256 |
+
return max(round(length / patch_size) * patch_size, patch_size)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
|
| 260 |
+
width, height = original_size
|
| 261 |
+
if (width * height > scale_resolution * scale_resolution) or allow_upscale:
|
| 262 |
+
r = width / height
|
| 263 |
+
height = int(scale_resolution / math.sqrt(r))
|
| 264 |
+
width = int(height * r)
|
| 265 |
+
best_width = ensure_divide(width, patch_size)
|
| 266 |
+
best_height = ensure_divide(height, patch_size)
|
| 267 |
+
return (best_width, best_height)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):
|
| 271 |
+
width, height = original_size
|
| 272 |
+
grid_x, grid_y = grid
|
| 273 |
+
|
| 274 |
+
refine_width = ensure_divide(width, grid_x)
|
| 275 |
+
refine_height = ensure_divide(height, grid_y)
|
| 276 |
+
|
| 277 |
+
grid_width = refine_width / grid_x
|
| 278 |
+
grid_height = refine_height / grid_y
|
| 279 |
+
|
| 280 |
+
best_grid_size = find_best_resize(
|
| 281 |
+
(grid_width, grid_height),
|
| 282 |
+
scale_resolution,
|
| 283 |
+
patch_size,
|
| 284 |
+
allow_upscale=allow_upscale,
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
|
| 288 |
+
|
| 289 |
+
return refine_size
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def split_to_patches(image, grid):
|
| 293 |
+
patches = []
|
| 294 |
+
width, height = image.size
|
| 295 |
+
grid_x = int(width / grid[0])
|
| 296 |
+
grid_y = int(height / grid[1])
|
| 297 |
+
|
| 298 |
+
for i in range(0, height, grid_y):
|
| 299 |
+
images = []
|
| 300 |
+
for j in range(0, width, grid_x):
|
| 301 |
+
box = (j, i, j + grid_x, i + grid_y)
|
| 302 |
+
patch = image.crop(box)
|
| 303 |
+
images.append(patch)
|
| 304 |
+
patches.append(images)
|
| 305 |
+
|
| 306 |
+
return patches
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
|
| 310 |
+
if new_schema:
|
| 311 |
+
image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
|
| 312 |
+
else:
|
| 313 |
+
image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
|
| 314 |
+
|
| 315 |
+
cols = grid[0]
|
| 316 |
+
rows = grid[1]
|
| 317 |
+
slices = []
|
| 318 |
+
for i in range(rows):
|
| 319 |
+
lines = []
|
| 320 |
+
for j in range(cols):
|
| 321 |
+
lines.append(image_placeholder)
|
| 322 |
+
slices.append("".join(lines))
|
| 323 |
+
if new_schema:
|
| 324 |
+
slice_placeholder = "\n".join(slices)
|
| 325 |
+
else:
|
| 326 |
+
slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
|
| 327 |
+
return slice_placeholder
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def reshape_by_patch(image_tensor, patch_size):
|
| 331 |
+
"""
|
| 332 |
+
:param image_tensor: shape [3, H, W]
|
| 333 |
+
:param patch_size:
|
| 334 |
+
:return: [3, patch_size, HW/patch_size]
|
| 335 |
+
"""
|
| 336 |
+
patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size))
|
| 337 |
+
|
| 338 |
+
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
|
| 339 |
+
patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)
|
| 340 |
+
return patches
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def init_minicpmo_config(processor, config):
|
| 344 |
+
"""Initialize MiniCPM-o specific configuration"""
|
| 345 |
+
minicpmo_config = {
|
| 346 |
+
"transform": build_transform(),
|
| 347 |
+
"patch_size": config.get("patch_size", 14),
|
| 348 |
+
"query_nums": config.get("query_nums", 64),
|
| 349 |
+
"slice_config": config.get(
|
| 350 |
+
"slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448}
|
| 351 |
+
),
|
| 352 |
+
"llm_type": config.get("llm_type", "qwen"),
|
| 353 |
+
"batch_vision": config.get("batch_vision", True),
|
| 354 |
+
}
|
| 355 |
+
return minicpmo_config
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def process_minicpmo_data(
|
| 359 |
+
row_dict,
|
| 360 |
+
messages,
|
| 361 |
+
tokenizer,
|
| 362 |
+
minicpmo_config,
|
| 363 |
+
image_key,
|
| 364 |
+
max_prompt_length,
|
| 365 |
+
truncation,
|
| 366 |
+
apply_chat_template_kwargs,
|
| 367 |
+
logger,
|
| 368 |
+
):
|
| 369 |
+
"""Process data for MiniCPM-o model"""
|
| 370 |
+
if len(row_dict[image_key]) == 1:
|
| 371 |
+
multi_modal_data = {}
|
| 372 |
+
image = process_image(row_dict.pop(image_key)[0])
|
| 373 |
+
multi_modal_data["image"] = [image]
|
| 374 |
+
images_dict = {"<image>": image}
|
| 375 |
+
else:
|
| 376 |
+
raise NotImplementedError
|
| 377 |
+
|
| 378 |
+
model_inputs = preprocess(
|
| 379 |
+
images_dict,
|
| 380 |
+
messages,
|
| 381 |
+
tokenizer,
|
| 382 |
+
minicpmo_config["transform"],
|
| 383 |
+
query_nums=minicpmo_config["query_nums"],
|
| 384 |
+
slice_config=minicpmo_config["slice_config"],
|
| 385 |
+
llm_type=minicpmo_config["llm_type"],
|
| 386 |
+
patch_size=minicpmo_config["patch_size"],
|
| 387 |
+
batch_vision=minicpmo_config["batch_vision"],
|
| 388 |
+
max_length=max_prompt_length,
|
| 389 |
+
truncation=truncation,
|
| 390 |
+
apply_chat_template_kwargs=apply_chat_template_kwargs,
|
| 391 |
+
logger=logger,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
raw_prompt = tokenizer.apply_chat_template(
|
| 395 |
+
messages, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {})
|
| 396 |
+
)
|
| 397 |
+
raw_prompt = raw_prompt.replace("<image>", "(<image>./</image>)")
|
| 398 |
+
|
| 399 |
+
return model_inputs, multi_modal_data, raw_prompt
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class RLHFDataset(Dataset):
|
| 403 |
+
"""
|
| 404 |
+
Load and preprocess RLHF data from Parquet files.
|
| 405 |
+
|
| 406 |
+
- Caches files locally.
|
| 407 |
+
- Reads into a HuggingFace Dataset and tokenizes prompts.
|
| 408 |
+
- Optionally handles images/videos via a ProcessorMixin.
|
| 409 |
+
- Filters prompts over a max length.
|
| 410 |
+
- Supports resuming from checkpoints.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
data_files (str or list): Path(s) to Parquet file(s).
|
| 414 |
+
tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
|
| 415 |
+
config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
|
| 416 |
+
processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.
|
| 417 |
+
"""
|
| 418 |
+
|
| 419 |
+
def __init__(
|
| 420 |
+
self,
|
| 421 |
+
data_files: str | list[str],
|
| 422 |
+
tokenizer: PreTrainedTokenizer,
|
| 423 |
+
config: DictConfig,
|
| 424 |
+
processor: Optional[ProcessorMixin] = None,
|
| 425 |
+
):
|
| 426 |
+
if not isinstance(data_files, list | ListConfig):
|
| 427 |
+
data_files = [data_files]
|
| 428 |
+
|
| 429 |
+
self.data_files = copy.deepcopy(data_files)
|
| 430 |
+
self.original_data_files = copy.deepcopy(data_files) # use for resume
|
| 431 |
+
self.tokenizer = tokenizer
|
| 432 |
+
self.processor = processor
|
| 433 |
+
self.config = config
|
| 434 |
+
|
| 435 |
+
self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
|
| 436 |
+
self.prompt_key = config.get("prompt_key", "prompt")
|
| 437 |
+
self.image_key = config.get("image_key", "images")
|
| 438 |
+
self.video_key = config.get("video_key", "videos")
|
| 439 |
+
self.max_prompt_length = config.get("max_prompt_length", 1024)
|
| 440 |
+
self.return_raw_chat = config.get("return_raw_chat", False)
|
| 441 |
+
self.return_full_prompt = config.get("return_full_prompt", False)
|
| 442 |
+
self.truncation = config.get("truncation", "error")
|
| 443 |
+
self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
|
| 444 |
+
self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {})
|
| 445 |
+
|
| 446 |
+
self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
|
| 447 |
+
self.num_workers = min(self.num_workers, os.cpu_count())
|
| 448 |
+
self.use_shm = config.get("use_shm", False)
|
| 449 |
+
self.chat_template_func = config.get("chat_template_func", None)
|
| 450 |
+
self.need_tools_kwargs = config.get("need_tools_kwargs", False)
|
| 451 |
+
self.filter_prompts = config.get("filter_prompts", True)
|
| 452 |
+
self.serialize_dataset = False
|
| 453 |
+
self.minicpmo_config = init_minicpmo_config(self.processor, config)
|
| 454 |
+
self._download()
|
| 455 |
+
self._read_files_and_tokenize()
|
| 456 |
+
|
| 457 |
+
def _download(self, use_origin_parquet=False):
|
| 458 |
+
from verl.utils.fs import copy_to_local
|
| 459 |
+
|
| 460 |
+
data_files = self.data_files if not use_origin_parquet else self.original_data_files
|
| 461 |
+
for i, parquet_file in enumerate(data_files):
|
| 462 |
+
self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)
|
| 463 |
+
|
| 464 |
+
def _read_files_and_tokenize(self):
|
| 465 |
+
dataframes = []
|
| 466 |
+
for parquet_file in self.data_files:
|
| 467 |
+
# read parquet files and cache
|
| 468 |
+
dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
|
| 469 |
+
dataframes.append(dataframe)
|
| 470 |
+
self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
|
| 471 |
+
|
| 472 |
+
print(f"dataset len: {len(self.dataframe)}")
|
| 473 |
+
|
| 474 |
+
def resume_dataset_state(self):
|
| 475 |
+
self.serialize_dataset = not hasattr(self, "original_data_files")
|
| 476 |
+
# resume dataframe if not it's serialized in data.pt
|
| 477 |
+
if not self.serialize_dataset:
|
| 478 |
+
self._download(use_origin_parquet=True) # download and resume from original parquet files
|
| 479 |
+
self._read_files_and_tokenize()
|
| 480 |
+
else:
|
| 481 |
+
print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")
|
| 482 |
+
|
| 483 |
+
def __len__(self):
|
| 484 |
+
return len(self.dataframe)
|
| 485 |
+
|
| 486 |
+
def _build_messages(self, example: dict):
|
| 487 |
+
return example.pop(self.prompt_key)
|
| 488 |
+
|
| 489 |
+
def __getitem__(self, item):
|
| 490 |
+
"""
|
| 491 |
+
Note that we also return the raw_input_ids so that it can be combined with other chat template
|
| 492 |
+
"""
|
| 493 |
+
row_dict: dict = self.dataframe[item]
|
| 494 |
+
messages = self._build_messages(row_dict)
|
| 495 |
+
model_inputs = {}
|
| 496 |
+
|
| 497 |
+
if self.processor is not None:
|
| 498 |
+
model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data(
|
| 499 |
+
row_dict,
|
| 500 |
+
messages,
|
| 501 |
+
self.tokenizer,
|
| 502 |
+
self.minicpmo_config,
|
| 503 |
+
self.image_key,
|
| 504 |
+
self.max_prompt_length,
|
| 505 |
+
self.truncation,
|
| 506 |
+
self.apply_chat_template_kwargs,
|
| 507 |
+
logger,
|
| 508 |
+
)
|
| 509 |
+
input_ids = model_inputs.pop("input_ids")
|
| 510 |
+
attention_mask = model_inputs.pop("attention_mask")
|
| 511 |
+
position_ids = model_inputs.pop("position_ids")
|
| 512 |
+
|
| 513 |
+
# There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
|
| 514 |
+
row_dict["multi_modal_data"] = multi_modal_data
|
| 515 |
+
row_dict["multi_modal_inputs"] = dict(model_inputs)
|
| 516 |
+
else:
|
| 517 |
+
raw_prompt = self.tokenizer.apply_chat_template(
|
| 518 |
+
messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
|
| 519 |
+
)
|
| 520 |
+
model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
|
| 521 |
+
input_ids = model_inputs.pop("input_ids")
|
| 522 |
+
attention_mask = model_inputs.pop("attention_mask")
|
| 523 |
+
position_ids = compute_position_id_with_mask(attention_mask)
|
| 524 |
+
|
| 525 |
+
row_dict["input_ids"] = input_ids
|
| 526 |
+
row_dict["attention_mask"] = attention_mask
|
| 527 |
+
row_dict["position_ids"] = position_ids
|
| 528 |
+
|
| 529 |
+
raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
|
| 530 |
+
if len(raw_prompt_ids) > self.max_prompt_length:
|
| 531 |
+
if self.truncation == "left":
|
| 532 |
+
raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
|
| 533 |
+
elif self.truncation == "right":
|
| 534 |
+
raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
|
| 535 |
+
elif self.truncation == "middle":
|
| 536 |
+
left_half = self.max_prompt_length // 2
|
| 537 |
+
right_half = self.max_prompt_length - left_half
|
| 538 |
+
raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
|
| 539 |
+
elif self.truncation == "error":
|
| 540 |
+
raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
|
| 541 |
+
|
| 542 |
+
row_dict["raw_prompt_ids"] = raw_prompt_ids
|
| 543 |
+
# encode prompts without chat template
|
| 544 |
+
if self.return_raw_chat:
|
| 545 |
+
row_dict["raw_prompt"] = messages
|
| 546 |
+
|
| 547 |
+
# get prompts with chat template
|
| 548 |
+
if self.return_full_prompt:
|
| 549 |
+
row_dict["full_prompts"] = raw_prompt # array of strings
|
| 550 |
+
|
| 551 |
+
# add index for each prompt
|
| 552 |
+
index = row_dict.get("extra_info", {}).get("index", 0)
|
| 553 |
+
tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
|
| 554 |
+
interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {})
|
| 555 |
+
need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
|
| 556 |
+
if need_tools_kwargs and not tools_kwargs:
|
| 557 |
+
logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
|
| 558 |
+
row_dict["index"] = index
|
| 559 |
+
row_dict["tools_kwargs"] = tools_kwargs
|
| 560 |
+
row_dict["interaction_kwargs"] = interaction_kwargs
|
| 561 |
+
return row_dict
|
| 562 |
+
|
| 563 |
+
def __getstate__(self):
|
| 564 |
+
if not self.serialize_dataset:
|
| 565 |
+
state = self.__dict__.copy()
|
| 566 |
+
|
| 567 |
+
if "dataframe" in state:
|
| 568 |
+
del state["dataframe"]
|
| 569 |
+
return state
|
| 570 |
+
|
| 571 |
+
return self.__dict__.copy()
|
ICL/DAPO/verl-recipe/prime/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 PRIME team and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
ICL/DAPO/verl-recipe/prime/prime_core_algos.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 PRIME team and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
import verl
|
| 18 |
+
import verl.utils.torch_functional as verl_F
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
|
| 22 |
+
# calculate rloo reward on different reward sources, and sum again
|
| 23 |
+
def masked_rloo(reward_tensor_original, mask_tensor):
|
| 24 |
+
reward_tensor = reward_tensor_original.clone()
|
| 25 |
+
reward_tensor[~mask_tensor] = 0
|
| 26 |
+
for start_pos in range(0, reward_tensor.shape[0], n_samples):
|
| 27 |
+
cur_rewards_mean = torch.cat(
|
| 28 |
+
[
|
| 29 |
+
reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True)
|
| 30 |
+
for pos in range(start_pos, start_pos + n_samples)
|
| 31 |
+
],
|
| 32 |
+
dim=0,
|
| 33 |
+
)
|
| 34 |
+
cur_rewards_sum = cur_rewards_mean.sum()
|
| 35 |
+
cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
|
| 36 |
+
reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = (
|
| 37 |
+
reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]]
|
| 38 |
+
* (n_samples / (n_samples - 1))
|
| 39 |
+
- cur_reward_baseline
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
return reward_tensor
|
| 43 |
+
|
| 44 |
+
reward_tensors = []
|
| 45 |
+
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:
|
| 48 |
+
reward_tensor = data.batch["rm_scores"]
|
| 49 |
+
reward_mask = response_mask.bool()
|
| 50 |
+
|
| 51 |
+
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
|
| 52 |
+
|
| 53 |
+
if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:
|
| 54 |
+
reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
|
| 55 |
+
reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
|
| 56 |
+
|
| 57 |
+
prompt_ids = data.batch["prompts"]
|
| 58 |
+
prompt_length = prompt_ids.shape[-1]
|
| 59 |
+
valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1)
|
| 60 |
+
|
| 61 |
+
reward_mask[
|
| 62 |
+
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
|
| 63 |
+
valid_response_length - 1,
|
| 64 |
+
] = True
|
| 65 |
+
reward_tensor[
|
| 66 |
+
torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
|
| 67 |
+
valid_response_length - 1,
|
| 68 |
+
] = data.batch["acc"]
|
| 69 |
+
|
| 70 |
+
reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)
|
| 71 |
+
|
| 72 |
+
final_reward_tensor = sum(reward_tensors)
|
| 73 |
+
|
| 74 |
+
returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
|
| 75 |
+
|
| 76 |
+
advantages = returns.clone()
|
| 77 |
+
advantages = verl_F.masked_whiten(advantages, response_mask)
|
| 78 |
+
|
| 79 |
+
return advantages, returns
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
|
| 83 |
+
cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
|
| 84 |
+
cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
|
| 85 |
+
return cur_dpo_loss
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"):
|
| 89 |
+
# we always assume that the BoN size equals n_samples
|
| 90 |
+
# mode1: use acc as rm
|
| 91 |
+
# mode2: use Q as rm
|
| 92 |
+
cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta
|
| 93 |
+
other_Q = torch.zeros_like(cur_Q)
|
| 94 |
+
for i in range(token_level_scores.shape[0]):
|
| 95 |
+
Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]]
|
| 96 |
+
if len(Q_chosen) > 0:
|
| 97 |
+
other_Q[i] = Q_chosen.mean() * beta
|
| 98 |
+
else:
|
| 99 |
+
other_Q[i] = 0
|
| 100 |
+
dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))
|
| 101 |
+
if bon_mode == "none":
|
| 102 |
+
dpo_loss = dpo_loss.mean()
|
| 103 |
+
else:
|
| 104 |
+
weight = torch.zeros_like(dpo_loss)
|
| 105 |
+
n_samples = acc_bc.shape[1]
|
| 106 |
+
if bon_mode == "bon_rm":
|
| 107 |
+
for i in range(token_level_scores.shape[0]):
|
| 108 |
+
weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)
|
| 109 |
+
elif bon_mode == "bon_acc":
|
| 110 |
+
for i in range(token_level_scores.shape[0]):
|
| 111 |
+
weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)
|
| 112 |
+
else:
|
| 113 |
+
raise NotImplementedError
|
| 114 |
+
dpo_loss = (dpo_loss * weight).sum()
|
| 115 |
+
|
| 116 |
+
return dpo_loss
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
|
| 120 |
+
dpo_acc = []
|
| 121 |
+
for start_id in range(0, token_level_scores.shape[0], n_samples):
|
| 122 |
+
cur_scores = (
|
| 123 |
+
token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]
|
| 124 |
+
).sum(dim=1)
|
| 125 |
+
|
| 126 |
+
def get_upper_triangle(tensor_x):
|
| 127 |
+
diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
|
| 128 |
+
upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
|
| 129 |
+
return diff_matrix[upper_tri_indices]
|
| 130 |
+
|
| 131 |
+
cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1]
|
| 132 |
+
cur_score_diff = get_upper_triangle(cur_scores) # in R
|
| 133 |
+
cur_score_prediction = (cur_score_diff > 0).float() # in [0,1]
|
| 134 |
+
if cur_acc_diff.abs().sum() == 0:
|
| 135 |
+
cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5
|
| 136 |
+
else:
|
| 137 |
+
cur_acc = (
|
| 138 |
+
((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()
|
| 139 |
+
).sum() / cur_acc_diff.abs().sum()
|
| 140 |
+
|
| 141 |
+
dpo_acc.append(cur_acc.unsqueeze(0))
|
| 142 |
+
|
| 143 |
+
return torch.cat(dpo_acc, dim=0).mean()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):
|
| 147 |
+
return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
|
ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
set -x
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data
|
| 5 |
+
code_train_path=$HOME/data/code/train.parquet
|
| 6 |
+
code_test_path=$HOME/data/code/test.parquet
|
| 7 |
+
|
| 8 |
+
train_files="['$code_train_path']"
|
| 9 |
+
test_files="['$code_test_path']"
|
| 10 |
+
|
| 11 |
+
model_path=PRIME-RL/Eurus-2-7B-SFT
|
| 12 |
+
# model_path=Qwen/Qwen2.5-0.5B-Instruct
|
| 13 |
+
|
| 14 |
+
python3 -m recipe.prime.main_prime \
|
| 15 |
+
data.train_files="$train_files" \
|
| 16 |
+
data.val_files="$test_files" \
|
| 17 |
+
data.train_batch_size=64 \
|
| 18 |
+
data.val_batch_size=6312 \
|
| 19 |
+
data.max_prompt_length=1024 \
|
| 20 |
+
data.max_response_length=3072 \
|
| 21 |
+
data.filter_overlong_prompts=True \
|
| 22 |
+
data.filter_accuracy=True \
|
| 23 |
+
data.accuracy_lower_bound=0.2 \
|
| 24 |
+
data.accuracy_upper_bound=0.8 \
|
| 25 |
+
data.oversample_factor=4 \
|
| 26 |
+
actor_rollout_ref.model.path=$model_path \
|
| 27 |
+
actor_rollout_ref.actor.optim.lr=5e-7 \
|
| 28 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 29 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
|
| 30 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
|
| 31 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 32 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=True \
|
| 33 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
|
| 34 |
+
actor_rollout_ref.actor.use_kl_loss=False \
|
| 35 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
|
| 36 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 37 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 38 |
+
actor_rollout_ref.rollout.n=4 \
|
| 39 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
| 40 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
|
| 41 |
+
algorithm.adv_estimator=rloo \
|
| 42 |
+
algorithm.use_kl_in_reward=True \
|
| 43 |
+
algorithm.kl_penalty=kl \
|
| 44 |
+
algorithm.kl_ctrl.kl_coef=0.001 \
|
| 45 |
+
reward_model.model.path=$model_path \
|
| 46 |
+
reward_model.micro_batch_size_per_gpu=1 \
|
| 47 |
+
reward_model.model.update=before \
|
| 48 |
+
reward_model.model.beta_train=0.05 \
|
| 49 |
+
reward_model.model.optim.lr=1e-6 \
|
| 50 |
+
reward_model.model.optim.grad_clip=10.0 \
|
| 51 |
+
reward_model.model.input_tokenizer=null \
|
| 52 |
+
reward_model.mini_batch_size=64 \
|
| 53 |
+
trainer.val_before_train=False \
|
| 54 |
+
trainer.logger='["console","wandb"]' \
|
| 55 |
+
trainer.project_name='prime_example' \
|
| 56 |
+
trainer.experiment_name='Eurus-2-7B-SFT-code' \
|
| 57 |
+
trainer.n_gpus_per_node=8 \
|
| 58 |
+
trainer.nnodes=1 \
|
| 59 |
+
trainer.save_freq=64 \
|
| 60 |
+
trainer.test_freq=64 \
|
| 61 |
+
trainer.total_epochs=15 $@
|
ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_PATH=Qwen/DeepSeek-R1-Distill-Qwen-1.5B
|
| 2 |
+
DATA_PATH=/workspace/datasets/r1_bench
|
| 3 |
+
|
| 4 |
+
# Eval Data Process
|
| 5 |
+
python3 -m recipe.r1.data_process \
|
| 6 |
+
--local_dir $DATA_PATH \
|
| 7 |
+
--tasks all
|
| 8 |
+
|
| 9 |
+
# Generation
|
| 10 |
+
python3 -m verl.trainer.main_generation \
|
| 11 |
+
trainer.nnodes=1 \
|
| 12 |
+
trainer.n_gpus_per_node=8 \
|
| 13 |
+
data.path=$DATA_PATH/test.parquet \
|
| 14 |
+
data.prompt_key=prompt \
|
| 15 |
+
data.batch_size=1024 \
|
| 16 |
+
data.n_samples=8 \
|
| 17 |
+
data.output_path=$DATA_PATH/test-output-8.parquet \
|
| 18 |
+
model.path=$MODEL_PATH \
|
| 19 |
+
rollout.temperature=0.6 \
|
| 20 |
+
rollout.top_p=0.95 \
|
| 21 |
+
rollout.prompt_length=1024 \
|
| 22 |
+
rollout.response_length=32768 \
|
| 23 |
+
rollout.tensor_model_parallel_size=1 \
|
| 24 |
+
rollout.gpu_memory_utilization=0.9 \
|
| 25 |
+
rollout.max_num_batched_tokens=65536
|
| 26 |
+
|
| 27 |
+
# Evaluation
|
| 28 |
+
python3 -m recipe.r1.main_eval \
|
| 29 |
+
data.path=$DATA_PATH/test-output-8.parquet \
|
| 30 |
+
data.prompt_key=prompt \
|
| 31 |
+
data.response_key=responses \
|
| 32 |
+
custom_reward_function.path=recipe/r1/reward_score.py \
|
| 33 |
+
custom_reward_function.name=reward_func
|
ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
FROM quay.io/ascend/cann:8.2.rc1-a3-openeuler22.03-py3.11
|
| 16 |
+
|
| 17 |
+
ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
|
| 18 |
+
ARG COMPILE_CUSTOM_KERNELS=1
|
| 19 |
+
|
| 20 |
+
# Define environments
|
| 21 |
+
ENV DEBIAN_FRONTED=noninteractive
|
| 22 |
+
ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS}
|
| 23 |
+
|
| 24 |
+
RUN yum install -y patch
|
| 25 |
+
|
| 26 |
+
WORKDIR /workspace
|
| 27 |
+
|
| 28 |
+
RUN pip config set global.index-url ${PIP_INDEX_URL}
|
| 29 |
+
|
| 30 |
+
# Install torch and torch-npu
|
| 31 |
+
RUN python3 -m pip install torch==2.5.1 torch-npu==2.5.1.post1
|
| 32 |
+
|
| 33 |
+
# Compile/Install apex
|
| 34 |
+
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \
|
| 35 |
+
source /usr/local/Ascend/nnal/atb/set_env.sh && \
|
| 36 |
+
source /usr/local/Ascend/nnal/asdsip/set_env.sh && \
|
| 37 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
|
| 38 |
+
git clone -b master https://gitcode.com/ascend/apex.git && \
|
| 39 |
+
cd apex/ && bash scripts/build.sh --python=3.11 && \
|
| 40 |
+
cd apex/dist/ && \
|
| 41 |
+
python3 -m pip install --upgrade apex-0.1+ascend-*.whl
|
| 42 |
+
|
| 43 |
+
# verl
|
| 44 |
+
RUN git clone https://github.com/volcengine/verl.git
|
| 45 |
+
|
| 46 |
+
# MindSpeed
|
| 47 |
+
RUN git clone https://gitcode.com/Ascend/MindSpeed.git && \
|
| 48 |
+
cd MindSpeed && \
|
| 49 |
+
git checkout f6688 && \
|
| 50 |
+
pip install -r requirements.txt && \
|
| 51 |
+
cp -r mindspeed ../verl
|
| 52 |
+
|
| 53 |
+
# Install vLLM
|
| 54 |
+
RUN git clone https://github.com/vllm-project/vllm.git && \
|
| 55 |
+
cd vllm && \
|
| 56 |
+
git checkout v0.9.1 && \
|
| 57 |
+
cp -r vllm ../verl
|
| 58 |
+
# In x86, triton will be installed by vllm. But in Ascend, triton doesn't work correctly. we need to uninstall it.
|
| 59 |
+
RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \
|
| 60 |
+
python3 -m pip uninstall -y triton && \
|
| 61 |
+
python3 -m pip cache purge
|
| 62 |
+
|
| 63 |
+
# Install vllm-ascend
|
| 64 |
+
RUN git clone https://github.com/vllm-project/vllm-ascend.git && \
|
| 65 |
+
cd vllm-ascend && \
|
| 66 |
+
git checkout 8c7bc45 && \
|
| 67 |
+
cp -r vllm_ascend ../verl
|
| 68 |
+
|
| 69 |
+
# Append `libascebd_hal.so` path (devlib) to LD_LIBRARY_PATH
|
| 70 |
+
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \
|
| 71 |
+
source /usr/local/Ascend/nnal/atb/set_env.sh && \
|
| 72 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
|
| 73 |
+
export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/include/c++/12:/usr/include/c++/12/`uname -i`-openEuler-linux && \
|
| 74 |
+
python3 -m pip install -v -e /workspace/vllm-ascend/ --exists-action=i --extra-index https://download.pytorch.org/whl/cpu/ && \
|
| 75 |
+
python3 -m pip cache purge
|
| 76 |
+
|
| 77 |
+
# Install modelscope (for fast download) and ray (for multinode) and Megatron-LM and others
|
| 78 |
+
RUN python3 -m pip install modelscope ray cache purge "transformers<4.54.0" mathruler cbor2 && \
|
| 79 |
+
pip install pybase64 fastapi zmq uvicorn openai msgspec blake3 py-cpuinfo gguf openai-harmony && \
|
| 80 |
+
pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
|
| 81 |
+
|
| 82 |
+
CMD ["/bin/bash"]
|
ICL/DAPO/verl-recipe/r1_ascend/README.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepSeek-R1-Zero on Ascend NPU
|
| 2 |
+
This recipe provides a sample for fine-tuning the Deepseek-V3-Base model using Reinforcement Learning from Human Feedback (RLHF) on Ascend NPUs, specifically utilizing the GRPO algorithm with rule-based rewards on the deepscaler dataset.
|
| 3 |
+
|
| 4 |
+
## Implementation Details
|
| 5 |
+
To implement RL training for the DeepSeek model on Ascend NPUs, this example includes the following key code additions and modifications:
|
| 6 |
+
- We implemented a simple rule-based reward function in `deepscaler.py`, referencing `verl/utils/reward_score/gsm8k.py`.
|
| 7 |
+
- We provided a dataset file conversion script `json_to_parquet.py`, which adds a template to the prompts to stimulate model thinking during the data file format conversion.
|
| 8 |
+
- Due to potential incomplete memory offloading during sleep operations for vLLM on NPUs, we added patches to manually handle the offloading and onloading of the rollout model and KVcache on NPUs. The related code is in `vllm_rollout_spmd.py` and `megatron_workers.py`.
|
| 9 |
+
- To enable vLLM to utilize all ranks for expert parallelism, support for vLLM's data parallelism was necessary. For this purpose, we added patches to construct the correct data parallel communication group. The related code is in `vllm_parallel_state.py` and `vllm_rollout_spmd.py`. Additionally, the `VLLM_DP_SIZE` environment variable must be correctly set to `world_size / vllm_tp_size`.
|
| 10 |
+
- The MindSpeed training framework for NPUs invalidates torch.compile to avoid compilation failures during training, but this prevents its use for accelerating inference. To resolve this, we added patches that allow compilation during inference but not during training. The related code is in `megatron_workers.py`.
|
| 11 |
+
- During RL training, multiple KV cache scheduling operations in vLLM on NPUs could lead to inconsistent memory allocation causing memory trampling. The fix for this issue is patched in `engine_core.py`.
|
| 12 |
+
|
| 13 |
+
By searching globally for `# NPU-ADAPTATION`, you can see the actual changes made by the patch code.
|
| 14 |
+
|
| 15 |
+
For more technical details, please refer to [the Technical Report (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/docs/deepseek/deepseek_rl_train_optimization.md).
|
| 16 |
+
|
| 17 |
+
## Training Details
|
| 18 |
+
### Hyperparameters
|
| 19 |
+
This example fine-tunes the DeepSeek-671B Base model on the deepscaler dataset using a combination of simple format rewards and answer accuracy rewards. The key hyperparameters are as follows:
|
| 20 |
+
|
| 21 |
+
| iteration | learning rate | global batchsize | n_samples | temperature | kl-coef | prompt_max_len | response_max_len | rule reward | reward model |
|
| 22 |
+
|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|
| 23 |
+
| 70 | 1e-6 (constant) | 512 | 16 | 1.0 | 0.001 | 1024 | 2048 | format + acc | - |
|
| 24 |
+
|
| 25 |
+
### Resource Allocation and Performance
|
| 26 |
+
This recipe was trained on an Ascend Atlas 800T A3 hyper-node server, utilizing 128 A3 NPUs, which is equivalent to 256 accelerator ranks. The specific deployment strategy is as follows:
|
| 27 |
+
|
| 28 |
+
| Rollout Deployment | Actor Deployment | Reference Deployment | Offload Strategy |
|
| 29 |
+
|:----:|:----:|:----:|:----:|
|
| 30 |
+
| TP2 EP256 | EP32 PP8 | Same as Actor | Full offload, optimizer utilizes the [Mindspeed Swap Optimizer feature](https://gitee.com/ascend/MindSpeed/blob/master/docs/features/swap-optimizer.md) |
|
| 31 |
+
|
| 32 |
+
The performance metrics for one training step are shown below (throughput varies with the model's response length during training):
|
| 33 |
+
|
| 34 |
+
| step | prompt_len_mean | response_len_mean | timing_step (s) | throughput (tps/A3) | timing_gen (s) | timing_reward (s) | timing_old_prob (s) | timing_ref_prob (s) | timing_update (s) |
|
| 35 |
+
|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|
| 36 |
+
| 2 | 175.1 | 1385.0 | 1044.8 | 95.5 | 482.2 | 20.4 | 105.5 | 92.7 | 342.9 |
|
| 37 |
+
|
| 38 |
+
### Training Metrics
|
| 39 |
+
<div align="center">
|
| 40 |
+
<img src="./figures/rewards.png" width="33%" />
|
| 41 |
+
<img src="./figures/response_len.png" width="33%" />
|
| 42 |
+
<img src="./figures/val_score.png" width="33%" />
|
| 43 |
+
</div>
|
| 44 |
+
|
| 45 |
+
## Quick Start
|
| 46 |
+
|
| 47 |
+
### Environment Setup
|
| 48 |
+
For setting up the Ascend NPU environment for verl, please refer to [ascend_quick_start.rst (in Chinese)](../../docs/ascend_tutorial/ascend_quick_start.rst).
|
| 49 |
+
|
| 50 |
+
Alternatively, you can use the provided Dockerfile to build the project's runtime environment locally: `docker build -f Dockerfile.vllm_ascend.mindspeed.deepseekV3 -t REPOSITORY:TAG ./`
|
| 51 |
+
|
| 52 |
+
Prepare the source code with the following steps:
|
| 53 |
+
```bash
|
| 54 |
+
# Clone verl
|
| 55 |
+
git clone https://github.com/volcengine/verl.git
|
| 56 |
+
|
| 57 |
+
# Clone and setup vLLM (v0.9.1)
|
| 58 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 59 |
+
cd vllm
|
| 60 |
+
git checkout v0.9.1
|
| 61 |
+
cp -r vllm ../verl
|
| 62 |
+
cd ..
|
| 63 |
+
|
| 64 |
+
# Clone and setup vLLM-Ascend (commit 8c7bc45)
|
| 65 |
+
git clone https://github.com/vllm-project/vllm-ascend.git
|
| 66 |
+
cd vllm-ascend
|
| 67 |
+
git checkout 8c7bc45
|
| 68 |
+
cp -r vllm_ascend ../verl
|
| 69 |
+
cd ..
|
| 70 |
+
|
| 71 |
+
# Clone and setup MindSpeed (commit f6688)
|
| 72 |
+
git clone https://gitcode.com/Ascend/MindSpeed.git
|
| 73 |
+
cd MindSpeed
|
| 74 |
+
git checkout f6688
|
| 75 |
+
cp -r mindspeed ../verl
|
| 76 |
+
cd ..
|
| 77 |
+
|
| 78 |
+
# Install Megatron-LM.core and other dependencies
|
| 79 |
+
pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
|
| 80 |
+
pip install mathruler
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### Prepare the Training Dataset
|
| 84 |
+
This example uses the deepscaler dataset. Prepare it as follows:
|
| 85 |
+
- Download the dataset [JSON file](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset/blob/main/deepscaler.json).
|
| 86 |
+
- Generate the `train.parquet` and `test.parquet` files and place them in the `./data/deepscaler` directory:
|
| 87 |
+
```bash
|
| 88 |
+
# Execute from the verl project directory
|
| 89 |
+
python recipe/r1_ascend/json_to_parquet.py --output_dir ./data/deepscaler --json_path path/to/deepscaler.json --train_data_ratio 0.9
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
The processed prompts used during training will include a specific template, for example: `A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Put your final answer within \boxed{}. <|User|>{problem}<|Assistant|>`
|
| 93 |
+
|
| 94 |
+
### Prepare Model Weights
|
| 95 |
+
Prepare the DeepSeek-V3-Base model weights as follows:
|
| 96 |
+
- Place the model configuration files (excluding the weights) into the `./DeepSeek-V3-hf` directory. The `config.json` file needs to be replaced to remove quantization and MTP configurations. Refer to [this link (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for details.
|
| 97 |
+
- Download the FP8 model weights from [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base) or [ModelScope](https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3-Base). Ensure the target disk has over 650GB of free space.
|
| 98 |
+
- Convert the FP8 weights to BF16 weights. Refer to [this link (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for instructions. This step requires over 1300GB of free space on the target disk.
|
| 99 |
+
|
| 100 |
+
This example uses pre-sharded distributed weights. Therefore, the following weight sharding step is also required:
|
| 101 |
+
- The distributed weights will be stored in `ckpts/DeepseekV3-dist-ckpts`.
|
| 102 |
+
- Use the script `verl/scripts/converter_hf_to_mcore.py` to shard the original BF16 weights into distributed weights. In practice, we found that 2TB of CPU RAM was insufficient for sharding the 671B model. Therefore, we adapted this script for expert parallelism and performed the weight sharding using a distributed strategy of EP8 PP8 across 64 NPUs.
|
| 103 |
+
|
| 104 |
+
### Other Code Modifications
|
| 105 |
+
In practice, to achieve the above results for on-policy RL training, we need to replace the `old_log_prob = data["old_log_probs"]` code in `verl/workers/actor/megatron_actor.py` with:
|
| 106 |
+
|
| 107 |
+
```python
|
| 108 |
+
on_policy = self.config.ppo_epochs == 1
|
| 109 |
+
if on_policy:
|
| 110 |
+
old_log_prob = log_prob.detach() # guarantee exact numerical equality
|
| 111 |
+
else:
|
| 112 |
+
old_log_prob = data["old_log_probs"]
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### Execute RL Fine-tuning
|
| 116 |
+
```bash
|
| 117 |
+
# Start the RL fine-tuning for DeepSeekV3 from the verl directory
|
| 118 |
+
bash ./recipe/r1_ascend/ray_start_grpo_npu.sh
|
| 119 |
+
```
|
ICL/DAPO/verl-recipe/r1_ascend/README_zh.md
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeepSeek-R1-Zero on Ascend NPU
|
| 2 |
+
本recipe是基于Deepseek-V3-Base模型在NPU上进行RLHF后训练的样例,基于GRPO与规则奖励,使用deepscaler数据集。
|
| 3 |
+
|
| 4 |
+
## 实现细节
|
| 5 |
+
为了在Ascend NPU上实现DeepSeek模型的RL训练,本样例中补充了一些代码,如下所示:
|
| 6 |
+
- 我们参考`verl/utils/reward_score/gsm8k.py`,在`deepscaler.py`中实现了一个简单的规则奖励函数。
|
| 7 |
+
- 我们提供了数据集文件转换脚本`json_to_parquet.py`,在数据文件格式转换的同时给prompt增加了激发模型思考的模板。
|
| 8 |
+
- NPU上vLLM的sleep可能存在内存卸载不干净的问题,因此添加了一些patch,手动实现NPU上Rollout模型与KVcache的卸载与加载。相关代码在`vllm_rollout_spmd.py`以及 `megatron_workers.py`中。
|
| 9 |
+
- 为了实现vLLM利用所有卡进行专家并行,需要支持vLLM的数据并行。为此添加了一些patch构建正确的DP通信域。相关代码在`vllm_parallel_state.py`以及`vllm_rollout_spmd.py`中。此外还需要正确配置`VLLM_DP_SIZE`环境变量为`world_size / vllm_tp_size`。
|
| 10 |
+
- NPU的MindSpeed训练框架会将torch.compile无效化来规避训练侧的compile失败,但这会使推理侧无法利用torch.compile加速。为了解决该问题,本样例添加了一些patch,使推理时可以compile,训练时不compile。相关代码`megatron_workers.py`中。
|
| 11 |
+
- RL训练过程中,NPU上vLLM多次KVcache调度可能引发申请内存不一致导致内存踩踏问题,修复patch在`engine_core.py`中。
|
| 12 |
+
|
| 13 |
+
通过全局搜索`# NPU-ADAPTATION`,可以看到patch代码所做的实际改动。
|
| 14 |
+
|
| 15 |
+
更多技术细节可参考[技术报告](https://gitcode.com/cann/cann-recipes-train/blob/master/docs/deepseek/deepseek_rl_train_optimization.md)。
|
| 16 |
+
|
| 17 |
+
## 训练细节
|
| 18 |
+
### 训练超参
|
| 19 |
+
|
| 20 |
+
本样例基于DeepSeek-671B Base模型在deepscaler数据集上训练,使用简单的格式奖励和结果准确率奖励,训练超参如下:
|
| 21 |
+
|
| 22 |
+
| 迭代 | 学习率 | gbs | 采样数 | 温度 | kl-coef | 输入长度 | 输出长度 | 规则奖励 | 奖励模型 |
|
| 23 |
+
|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|
| 24 |
+
| 70 | 1e-6 (constant) | 512 | 16 | 1.0 | 0.001 | 1024 | 2048 | format + acc | - |
|
| 25 |
+
|
| 26 |
+
### 训练资源与性能
|
| 27 |
+
本样例在昇腾Atlas 800T A3超节点服务器上进行训练,使用了128张A3 NPU,等效于256张加速卡。具体的部署方式如下:
|
| 28 |
+
|
| 29 |
+
| Rollout部署 | Actor部署 | Reference部署 | Offload策略 |
|
| 30 |
+
|:----:|:----:|:----:|:----:|
|
| 31 |
+
| TP2 EP256 | EP32 PP8 | 同Actor | 全offload,优化器使用[Mindspeed卸载特性](https://gitee.com/ascend/MindSpeed/blob/master/docs/features/swap-optimizer.md) |
|
| 32 |
+
|
| 33 |
+
得到一步的训练性能如下(吞吐会随着训练中模型输出长度变化而改变):
|
| 34 |
+
| step | 平均问题长度 | 平均回复长度 | 单步总耗时(s) | 吞吐(tps/A3) | gen耗时(s) | reward耗时(s) | old_prob耗时(s) | ref_prob耗时(s) | update耗时(s) |
|
| 35 |
+
|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
|
| 36 |
+
| 2 | 175.1 | 1385.0 | 1044.8 | 95.5 | 482.2 | 20.4 | 105.5 | 92.7 | 342.9 |
|
| 37 |
+
|
| 38 |
+
### 训练过程记录
|
| 39 |
+
<div align="center">
|
| 40 |
+
<img src="./figures/rewards.png" width="33%" />
|
| 41 |
+
<img src="./figures/response_len.png" width="33%" />
|
| 42 |
+
<img src="./figures/val_score.png" width="33%" />
|
| 43 |
+
</div>
|
| 44 |
+
|
| 45 |
+
## 快速开始
|
| 46 |
+
|
| 47 |
+
### 环境准备
|
| 48 |
+
verl上的NPU环境准备,可参考[ascend_quick_start.rst](../../docs/ascend_tutorial/ascend_quick_start.rst)进行配置。
|
| 49 |
+
|
| 50 |
+
此外,也可使用我们提供的Dockerfile在本地构建项目运行环境:`docker build -f Dockerfile.vllm_ascend.mindspeed.deepseekV3 -t REPOSITORY:TAG ./`
|
| 51 |
+
|
| 52 |
+
本样准备源码的步骤如下:
|
| 53 |
+
```bash
|
| 54 |
+
# verl
|
| 55 |
+
git clone https://github.com/volcengine/verl.git
|
| 56 |
+
|
| 57 |
+
# vLLM (v0.9.1)
|
| 58 |
+
git clone https://github.com/vllm-project/vllm.git
|
| 59 |
+
cd vllm
|
| 60 |
+
git checkout v0.9.1
|
| 61 |
+
cp -r vllm ../verl
|
| 62 |
+
cd ..
|
| 63 |
+
|
| 64 |
+
# vLLM-Ascend (v0.9.1-dev)
|
| 65 |
+
git clone https://github.com/vllm-project/vllm-ascend.git
|
| 66 |
+
cd vllm-ascend
|
| 67 |
+
git checkout 8c7bc45
|
| 68 |
+
cp -r vllm_ascend ../verl
|
| 69 |
+
cd ..
|
| 70 |
+
|
| 71 |
+
# MindSpeed (commit-id: f6688)
|
| 72 |
+
git clone https://gitcode.com/Ascend/MindSpeed.git
|
| 73 |
+
cd MindSpeed
|
| 74 |
+
git checkout f6688
|
| 75 |
+
cp -r mindspeed ../verl
|
| 76 |
+
cd ..
|
| 77 |
+
|
| 78 |
+
# Megatron-LM.core and others
|
| 79 |
+
pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
|
| 80 |
+
pip install mathruler
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### 准备训练数据集
|
| 84 |
+
本样例使用deepscaler数据集。准备方式如下:
|
| 85 |
+
- 下载数据集[json文件](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset/blob/main/deepscaler.json)。
|
| 86 |
+
- 获取`train.parquet`与`test.parquet`文件并放入`./data/deepscaler`路径:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
# 在verl项目目录执行
|
| 90 |
+
python recipe/r1_ascend/json_to_parquet.py --output_dir ./data/deepscaler --json_path path/to/deepscaler.json --train_data_ratio 0.9
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
训练中经过处理的prompt将包含模板,例如:`A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Put your final answer within \boxed{}. <|User|>{problem}<|Assistant|>`
|
| 94 |
+
|
| 95 |
+
### 准备模型权重
|
| 96 |
+
DeepSeek-V3-Base模型权重准备步骤如下:
|
| 97 |
+
- 需要将模型配置相关文件(不含权重)放入`./DeepSeek-V3-hf`目录,并且`config.json`需要进行替换以去除量化和MTP。该步骤可参考[此链接](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87)。
|
| 98 |
+
- 模型FP8权重下载:[HuggingFace地址](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base),[ModelScope地址](https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3-Base)。此步骤需要目录所在磁盘有650GB以上空间。
|
| 99 |
+
- 将FP8权重转为BF16权重,可参考[此链接](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87)。此步骤需要目录所在磁盘有1300GB以上空间。
|
| 100 |
+
|
| 101 |
+
本样例使用了预先切分的分布式权重,因此还要执行以下的切分权重操作:
|
| 102 |
+
- 分布式权重需存储至`ckpts/DeepseekV3-dist-ckpts`。
|
| 103 |
+
- 使用`verl/scripts/converter_hf_to_mcore.py`对原始的BF16权重切分得到分布式权重。实践中我们发现2T的CPU内存不足以完成671B模型的权重切分处理,为此我们对该脚本进行了专家并行的适配,并在64块NPU上用EP8 PP8分布式策略对权重进行了切分。
|
| 104 |
+
|
| 105 |
+
### 其他代码修改
|
| 106 |
+
实践中为了得到以上on-policy训练的结果,我们将 `verl/workers/actor/megatron_actor.py` 中的代码段 `old_log_prob = data["old_log_probs"]` 替换为如下代码:
|
| 107 |
+
```python
|
| 108 |
+
on_policy = self.config.ppo_epochs == 1
|
| 109 |
+
if on_policy:
|
| 110 |
+
old_log_prob = log_prob.detach() # 确保二者数值完全相等
|
| 111 |
+
else:
|
| 112 |
+
old_log_prob = data["old_log_probs"]
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### 执行RL后训练
|
| 116 |
+
```bash
|
| 117 |
+
# verl目录下启动DeepSeekV3的RL后训练
|
| 118 |
+
bash ./recipe/r1_ascend/ray_start_grpo_npu.sh
|
| 119 |
+
```
|
ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ray stop --force
|
| 2 |
+
|
| 3 |
+
export RAY_DEDUP_LOGS=0 # 0: disable ray's log folding 1: enable ray's log folding
|
| 4 |
+
export HYDRA_FULL_ERROR=1 # display the accurate error stack
|
| 5 |
+
|
| 6 |
+
ulimit -n 32768
|
| 7 |
+
mkdir logs
|
| 8 |
+
|
| 9 |
+
NNODES=16 # number of nodes
|
| 10 |
+
NPUS_PER_NODE=16 # the number of npus for each node
|
| 11 |
+
export WORLD_SIZE=$(($NNODES*$NPUS_PER_NODE))
|
| 12 |
+
|
| 13 |
+
RAY_START_PORT=6766
|
| 14 |
+
RAY_DASHBOARD_PORT=8260
|
| 15 |
+
|
| 16 |
+
MASTER_ADDR="IP FOR MASTER NODE" # modify it to correspond to the IP of the master node
|
| 17 |
+
SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE" # modify it to the communication network card of the current node
|
| 18 |
+
# obtain the current node IP
|
| 19 |
+
CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}')
|
| 20 |
+
export MASTER_PORT=29444
|
| 21 |
+
export HCCL_IF_BASE_PORT=64247
|
| 22 |
+
export TP_SOCKET_IFNAME=$SOCKET_IFNAME
|
| 23 |
+
export HCCL_SOCKET_IFNAME=$SOCKET_IFNAME
|
| 24 |
+
export GLOO_SOCKET_IFNAME=$SOCKET_IFNAME
|
| 25 |
+
|
| 26 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 27 |
+
export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
|
| 28 |
+
export TASK_QUEUE_ENABLE=2 # enable level2 optimization of the sent queue of the ascend operator
|
| 29 |
+
export HCCL_BUFFSIZE=300 # the buffer size of HCCL
|
| 30 |
+
|
| 31 |
+
export HCCL_CONNECT_TIMEOUT=600
|
| 32 |
+
export HCCL_EXEC_TIMEOUT=600
|
| 33 |
+
|
| 34 |
+
export ASCEND_LAUNCH_BLOCKING=0 # debug usage, which seriously affects performance after use, but the error stack is accurate
|
| 35 |
+
|
| 36 |
+
export VLLM_USE_V1=1 # use the V1 engine of vLLM
|
| 37 |
+
export VLLM_ENABLE_GRAPH_MODE=1 # enable vLLM graph mode
|
| 38 |
+
export HCCL_OP_EXPANSION_MODE=AIV # enable the communication mode of AIV
|
| 39 |
+
export VLLM_ENABLE_MC2=1 # enable MC2 communication
|
| 40 |
+
export VLLM_DP_SIZE=128 # configure the DP size of vLLM, this is related to the vLLM instance num
|
| 41 |
+
|
| 42 |
+
# under the configuration of the vLLM log level of INFO, enable this configuration, print the time of prefill and decode
|
| 43 |
+
export VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE=0
|
| 44 |
+
|
| 45 |
+
if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then
|
| 46 |
+
# the master node starts
|
| 47 |
+
ray start --head --port=$RAY_START_PORT --dashboard-host=0.0.0.0 --node-ip-address=$CURRENT_IP --dashboard-port=$RAY_DASHBOARD_PORT --resources='{"NPU": '$NPUS_PER_NODE'}'
|
| 48 |
+
|
| 49 |
+
while true; do
|
| 50 |
+
ray_status_output=$(ray status)
|
| 51 |
+
npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1)
|
| 52 |
+
npu_count_int=$(echo "$npu_count" | awk '{print int($1)}')
|
| 53 |
+
device_count=$((npu_count_int / $NPUS_PER_NODE))
|
| 54 |
+
|
| 55 |
+
# determine whether device_count is equal to NNODES
|
| 56 |
+
if [ "$device_count" -eq "$NNODES" ]; then
|
| 57 |
+
echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script."
|
| 58 |
+
ray status
|
| 59 |
+
bash ./recipe/r1_ascend/run_deepseekv3_671b_grpo_megatron_npu.sh
|
| 60 |
+
break
|
| 61 |
+
else
|
| 62 |
+
echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count"
|
| 63 |
+
sleep 5
|
| 64 |
+
fi
|
| 65 |
+
done
|
| 66 |
+
else
|
| 67 |
+
# the child node attempts to register ray with the master node until successful
|
| 68 |
+
while true; do
|
| 69 |
+
# try to connect to the Ray cluster
|
| 70 |
+
ray start --address="$MASTER_ADDR:$RAY_START_PORT" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP
|
| 71 |
+
|
| 72 |
+
# check if the connection is successful
|
| 73 |
+
ray status
|
| 74 |
+
if [ $? -eq 0 ]; then
|
| 75 |
+
echo "Successfully connected to the Ray cluster!"
|
| 76 |
+
break
|
| 77 |
+
else
|
| 78 |
+
echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..."
|
| 79 |
+
sleep 5
|
| 80 |
+
fi
|
| 81 |
+
done
|
| 82 |
+
fi
|
ICL/DAPO/verl-recipe/r1_ascend/vllm_rollout_spmd.py
ADDED
|
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
# Adapted from
|
| 17 |
+
# https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
from typing import Generator
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.distributed
|
| 25 |
+
from omegaconf import ListConfig
|
| 26 |
+
from torch.distributed.device_mesh import DeviceMesh
|
| 27 |
+
from vllm import LLM, SamplingParams
|
| 28 |
+
from vllm.config import CompilationConfig, CompilationLevel
|
| 29 |
+
|
| 30 |
+
from verl.third_party.vllm import VLLM_SLEEP_LEVEL
|
| 31 |
+
from verl.utils.device import get_device_name
|
| 32 |
+
from verl.utils.memory_utils import aggressive_empty_cache
|
| 33 |
+
from verl.workers.config import HFModelConfig, RolloutConfig
|
| 34 |
+
from verl.workers.rollout.vllm_rollout import vLLMRollout as vLLMRolloutBase
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__file__)
|
| 37 |
+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class vLLMRollout(vLLMRolloutBase):
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
config: RolloutConfig,
|
| 44 |
+
model_config: HFModelConfig,
|
| 45 |
+
device_mesh: DeviceMesh,
|
| 46 |
+
):
|
| 47 |
+
self.config = config
|
| 48 |
+
self.model_config = model_config
|
| 49 |
+
self.device_mesh = device_mesh
|
| 50 |
+
# NPU-ADAPTATION: import vLLM-Ascend patch
|
| 51 |
+
from recipe.r1_ascend import engine_core # noqa: F401
|
| 52 |
+
from vllm_ascend.patch import (
|
| 53 |
+
platform, # noqa: F401
|
| 54 |
+
worker, # noqa: F401
|
| 55 |
+
)
|
| 56 |
+
# NPU-ADAPTATION END
|
| 57 |
+
|
| 58 |
+
if config.layered_summon:
|
| 59 |
+
self.sleep_level = 1
|
| 60 |
+
else:
|
| 61 |
+
self.sleep_level = VLLM_SLEEP_LEVEL
|
| 62 |
+
|
| 63 |
+
model_path = model_config.local_path
|
| 64 |
+
tokenizer = model_config.tokenizer
|
| 65 |
+
model_hf_config = model_config.hf_config
|
| 66 |
+
trust_remote_code = model_config.trust_remote_code
|
| 67 |
+
self.lora_kwargs = (
|
| 68 |
+
{"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank}
|
| 69 |
+
if model_config.lora_rank > 0
|
| 70 |
+
else {}
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
|
| 74 |
+
assert tensor_parallel_size <= torch.distributed.get_world_size(), (
|
| 75 |
+
"tensor parallel size should be less than or equal to the world size"
|
| 76 |
+
)
|
| 77 |
+
max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192)
|
| 78 |
+
|
| 79 |
+
# NPU-ADAPTATION: VLLM_DP_SIZE is configured, the DP communication domain needs to be explicitly initialized
|
| 80 |
+
if int(os.environ.get("VLLM_DP_SIZE", "1")) > 1:
|
| 81 |
+
from recipe.r1_ascend.vllm_parallel_state import init_parallel_state
|
| 82 |
+
|
| 83 |
+
init_parallel_state(tensor_parallel_size)
|
| 84 |
+
# NPU-ADAPTATION END
|
| 85 |
+
|
| 86 |
+
rope_scaling_config = getattr(model_hf_config, "rope_scaling", None)
|
| 87 |
+
if not rope_scaling_config:
|
| 88 |
+
max_position_embeddings = None
|
| 89 |
+
if hasattr(model_hf_config, "max_position_embeddings"):
|
| 90 |
+
max_position_embeddings = model_hf_config.max_position_embeddings
|
| 91 |
+
elif hasattr(model_hf_config, "llm_config") and hasattr(
|
| 92 |
+
model_hf_config.llm_config, "max_position_embeddings"
|
| 93 |
+
):
|
| 94 |
+
max_position_embeddings = model_hf_config.llm_config.max_position_embeddings
|
| 95 |
+
elif hasattr(model_hf_config, "text_config") and hasattr(
|
| 96 |
+
model_hf_config.text_config, "max_position_embeddings"
|
| 97 |
+
):
|
| 98 |
+
max_position_embeddings = model_hf_config.text_config.max_position_embeddings
|
| 99 |
+
if max_position_embeddings is None:
|
| 100 |
+
raise ValueError("max_position_embeddings not found in model_hf_config")
|
| 101 |
+
assert max_position_embeddings >= config.prompt_length + config.response_length, (
|
| 102 |
+
"model context length should be greater than total sequence length"
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
# handle type where there's a length extend factor
|
| 106 |
+
# see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support
|
| 107 |
+
# for using yarn as an example
|
| 108 |
+
rope_scaling_factor = rope_scaling_config.get("factor", 1.0)
|
| 109 |
+
|
| 110 |
+
assert (
|
| 111 |
+
model_hf_config.max_position_embeddings * rope_scaling_factor
|
| 112 |
+
>= config.prompt_length + config.response_length
|
| 113 |
+
), (
|
| 114 |
+
"model context length should be greater than total sequence length, "
|
| 115 |
+
+ f"got rope_scaling_factor={rope_scaling_factor} and "
|
| 116 |
+
+ f"max_position_embeddings={model_hf_config.max_position_embeddings}"
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)
|
| 120 |
+
|
| 121 |
+
load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format
|
| 122 |
+
|
| 123 |
+
# copy it to avoid secretly modifying the engine config
|
| 124 |
+
engine_kwargs = config.get("engine_kwargs", {}).get("vllm", {}) or {}
|
| 125 |
+
|
| 126 |
+
# For each vLLM engine parameter,
|
| 127 |
+
# - `None` means not setting it, so we pop it, and leave it to vLLM default value
|
| 128 |
+
# (which can vary across different vLLM versions);
|
| 129 |
+
# - Otherwise it's the desired value we want to explicitly set.
|
| 130 |
+
engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}
|
| 131 |
+
if config.get("limit_images", None): # support for multi-image data
|
| 132 |
+
engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")}
|
| 133 |
+
|
| 134 |
+
compilation_config = {}
|
| 135 |
+
|
| 136 |
+
cudagraph_capture_sizes = config.get("cudagraph_capture_sizes")
|
| 137 |
+
# enforce_eager must be False to use cudagraph
|
| 138 |
+
if not config.enforce_eager and cudagraph_capture_sizes:
|
| 139 |
+
if isinstance(cudagraph_capture_sizes, ListConfig):
|
| 140 |
+
compilation_config["compilation_config"] = CompilationConfig(
|
| 141 |
+
level=CompilationLevel.PIECEWISE, cudagraph_capture_sizes=cudagraph_capture_sizes
|
| 142 |
+
)
|
| 143 |
+
else:
|
| 144 |
+
logger.warning(f"cudagraph_capture_sizes must be a list, but got {cudagraph_capture_sizes}")
|
| 145 |
+
|
| 146 |
+
VLLM_ENABLE_GRAPGH_MODE = int(os.environ.get("VLLM_ENABLE_GRAPH_MODE", "0"))
|
| 147 |
+
self.inference_engine = LLM(
|
| 148 |
+
model=model_path,
|
| 149 |
+
# NPU-ADAPTATION: Enable inference EP and disable sleep mode.
|
| 150 |
+
enable_sleep_mode=False,
|
| 151 |
+
enable_expert_parallel=True,
|
| 152 |
+
# NPU-ADAPTATION END
|
| 153 |
+
tensor_parallel_size=tensor_parallel_size,
|
| 154 |
+
distributed_executor_backend="external_launcher",
|
| 155 |
+
dtype=config.dtype,
|
| 156 |
+
enforce_eager=config.enforce_eager,
|
| 157 |
+
gpu_memory_utilization=config.gpu_memory_utilization,
|
| 158 |
+
disable_custom_all_reduce=True,
|
| 159 |
+
skip_tokenizer_init=False,
|
| 160 |
+
max_model_len=max_model_len,
|
| 161 |
+
max_num_seqs=config.max_num_seqs,
|
| 162 |
+
load_format=load_format,
|
| 163 |
+
disable_log_stats=config.disable_log_stats,
|
| 164 |
+
max_num_batched_tokens=max_num_batched_tokens,
|
| 165 |
+
enable_chunked_prefill=config.enable_chunked_prefill,
|
| 166 |
+
enable_prefix_caching=False,
|
| 167 |
+
trust_remote_code=trust_remote_code,
|
| 168 |
+
seed=config.get("seed", 0),
|
| 169 |
+
# NPU-ADAPTATION: Enable graph mode and configure the parameters.
|
| 170 |
+
additional_config={
|
| 171 |
+
"torchair_graph_config": {
|
| 172 |
+
"enabled": VLLM_ENABLE_GRAPGH_MODE,
|
| 173 |
+
"use_cached_graph": False,
|
| 174 |
+
"graph_batch_sizes_init": False,
|
| 175 |
+
"graph_batch_sizes": [config.max_num_seqs],
|
| 176 |
+
"enable_multistream_mla": False,
|
| 177 |
+
"enable_multistream_moe": False,
|
| 178 |
+
"enable_view_optimize": False,
|
| 179 |
+
"enable_kv_nz": False,
|
| 180 |
+
"enable_frozen_parameter": False,
|
| 181 |
+
},
|
| 182 |
+
"ascend_scheduler_config": {
|
| 183 |
+
"enabled": True,
|
| 184 |
+
},
|
| 185 |
+
"refresh": True,
|
| 186 |
+
},
|
| 187 |
+
# NPU-ADAPTATION END
|
| 188 |
+
**compilation_config,
|
| 189 |
+
**self.lora_kwargs,
|
| 190 |
+
**engine_kwargs,
|
| 191 |
+
)
|
| 192 |
+
# NPU-ADAPTATION: Weight onload and offload, and initialization configurations such as kv_cache.
|
| 193 |
+
self.model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
|
| 194 |
+
self.kv_cache_configs = None
|
| 195 |
+
self.cpu_model = {}
|
| 196 |
+
self.gpu_buffers = None
|
| 197 |
+
for name, params in self.model.named_parameters():
|
| 198 |
+
self.cpu_model[name] = torch.empty_like(params, device="cpu")
|
| 199 |
+
# NPU-ADAPTATION END
|
| 200 |
+
|
| 201 |
+
kwargs = dict(
|
| 202 |
+
n=1,
|
| 203 |
+
logprobs=0, # can be set to 0 and let actor to recompute
|
| 204 |
+
max_tokens=config.response_length,
|
| 205 |
+
repetition_penalty=config.get("repetition_penalty", 1.0),
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
kwargs["detokenize"] = False
|
| 209 |
+
|
| 210 |
+
# supporting adding any sampling params from the config file
|
| 211 |
+
for k in config.keys():
|
| 212 |
+
if hasattr(SamplingParams(), str(k)) and k != "seed":
|
| 213 |
+
kwargs[k] = config.get(k)
|
| 214 |
+
kwargs["n"] = 1 # already repeat in ray_trainer
|
| 215 |
+
logger.info(f"vllm sampling kwargs: {kwargs}")
|
| 216 |
+
self.sampling_params = SamplingParams(**kwargs)
|
| 217 |
+
|
| 218 |
+
self.pad_token_id = tokenizer.pad_token_id
|
| 219 |
+
|
| 220 |
+
# NPU-ADAPTATION: Weight onload and offload, kv_cache init and free function
|
| 221 |
+
# NOTE: Due to potential incomplete memory offloading during sleep operations for vLLM on NPUs, we add
|
| 222 |
+
# patches to manually handle the off/on loading of the rollout model and KVcache on NPUs.
|
| 223 |
+
def init_cache_engine(self):
|
| 224 |
+
if os.environ["VLLM_USE_V1"] == "1":
|
| 225 |
+
worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker
|
| 226 |
+
if not worker.model_runner.kv_caches:
|
| 227 |
+
# v1 use explicit initialization method
|
| 228 |
+
self.inference_engine.llm_engine.engine_core.engine_core.model_executor.initialize_from_config(
|
| 229 |
+
self.inference_engine.llm_engine.engine_core.engine_core.kv_cache_configs
|
| 230 |
+
)
|
| 231 |
+
self.inference_engine.llm_engine.reset_prefix_cache()
|
| 232 |
+
else:
|
| 233 |
+
if self.inference_engine.llm_engine.model_executor.driver_worker.worker.cache_engine is None:
|
| 234 |
+
self.inference_engine.llm_engine.model_executor.driver_worker.worker._init_cache_engine()
|
| 235 |
+
|
| 236 |
+
def onload_model_weights(self):
|
| 237 |
+
self.gpu_buffers = {}
|
| 238 |
+
for name, param in self.model.named_parameters():
|
| 239 |
+
self.gpu_buffers[name] = torch.empty_like(param, device=get_device_name())
|
| 240 |
+
for name, param in self.model.named_parameters():
|
| 241 |
+
param.data = self.gpu_buffers[name]
|
| 242 |
+
|
| 243 |
+
def offload_model_weights(self):
|
| 244 |
+
for name, params in self.model.named_parameters():
|
| 245 |
+
params.data = self.cpu_model[name]
|
| 246 |
+
if hasattr(self.model.model.layers[0].self_attn, "mla_attn"):
|
| 247 |
+
for i in range(self.model.model.start_layer, self.model.model.end_layer):
|
| 248 |
+
mla = self.model.model.layers[i].self_attn.mla_attn.impl
|
| 249 |
+
if hasattr(mla, "w_kc"):
|
| 250 |
+
mla.w_kc = None
|
| 251 |
+
mla.w_vc = None
|
| 252 |
+
if hasattr(mla, "W_UV"):
|
| 253 |
+
mla.W_UV = None
|
| 254 |
+
mla.W_UK_T = None
|
| 255 |
+
|
| 256 |
+
self.gpu_buffers = None
|
| 257 |
+
aggressive_empty_cache()
|
| 258 |
+
|
| 259 |
+
def free_cache_engine(self):
|
| 260 |
+
if os.environ["VLLM_USE_V1"] == "1":
|
| 261 |
+
worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker
|
| 262 |
+
ctx = worker.model_runner.vllm_config.compilation_config.static_forward_context
|
| 263 |
+
else:
|
| 264 |
+
compilation_config = self.inference_engine.llm_engine.model_executor.driver_worker.worker.compilation_config
|
| 265 |
+
ctx = compilation_config.static_forward_context
|
| 266 |
+
from vllm.attention import AttentionType
|
| 267 |
+
|
| 268 |
+
layer_need_kv_cache = []
|
| 269 |
+
for layer_name in ctx:
|
| 270 |
+
if hasattr(ctx[layer_name], "attn_type") and ctx[layer_name].attn_type in (
|
| 271 |
+
AttentionType.DECODER,
|
| 272 |
+
AttentionType.ENCODER_DECODER,
|
| 273 |
+
):
|
| 274 |
+
layer_need_kv_cache.append(layer_name)
|
| 275 |
+
|
| 276 |
+
pipeline_parallel_size = self.inference_engine.llm_engine.vllm_config.parallel_config.pipeline_parallel_size
|
| 277 |
+
for layer_name in layer_need_kv_cache:
|
| 278 |
+
kv_cache = []
|
| 279 |
+
for _ in range(pipeline_parallel_size):
|
| 280 |
+
kv_cache.append(torch.tensor([]))
|
| 281 |
+
ctx[layer_name].kv_cache = kv_cache
|
| 282 |
+
if os.environ["VLLM_USE_V1"] == "1":
|
| 283 |
+
worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker
|
| 284 |
+
|
| 285 |
+
worker.model_runner.kv_caches = []
|
| 286 |
+
else:
|
| 287 |
+
self.inference_engine.llm_engine.model_executor.driver_worker.worker.cache_engine = None
|
| 288 |
+
self.inference_engine.llm_engine.model_executor.driver_worker.worker.gpu_cache = None
|
| 289 |
+
|
| 290 |
+
if hasattr(self.model.model.layers[0].self_attn, "attn"):
|
| 291 |
+
for i in range(self.model.model.start_layer, self.model.model.end_layer):
|
| 292 |
+
attn_impl = self.model.model.layers[i].self_attn.attn.impl
|
| 293 |
+
if hasattr(attn_impl, "key_cache"):
|
| 294 |
+
attn_impl.key_cache = None
|
| 295 |
+
attn_impl.value_cache = None
|
| 296 |
+
|
| 297 |
+
aggressive_empty_cache()
|
| 298 |
+
|
| 299 |
+
def _process_mla(self, load_weight=False):
|
| 300 |
+
for i in range(self.model.model.start_layer, self.model.model.end_layer):
|
| 301 |
+
mla = self.model.model.layers[i].self_attn.mla_attn.impl
|
| 302 |
+
if hasattr(mla, "w_kc"):
|
| 303 |
+
mla.w_kc = None
|
| 304 |
+
mla.w_vc = None
|
| 305 |
+
if hasattr(mla, "W_UV"):
|
| 306 |
+
mla.W_UV = None
|
| 307 |
+
mla.W_UK_T = None
|
| 308 |
+
if load_weight:
|
| 309 |
+
mla.process_weights_after_loading(None)
|
| 310 |
+
|
| 311 |
+
async def resume(self, tags: list[str]):
|
| 312 |
+
"""Resume rollout weights or kv cache in NPU memory.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
tags: weights or kv_cache.
|
| 316 |
+
"""
|
| 317 |
+
if not self.config.free_cache_engine:
|
| 318 |
+
return
|
| 319 |
+
|
| 320 |
+
if "weights" in tags:
|
| 321 |
+
self.onload_model_weights()
|
| 322 |
+
elif "kv_cache" in tags:
|
| 323 |
+
self.init_cache_engine()
|
| 324 |
+
|
| 325 |
+
async def release(self):
|
| 326 |
+
"""Release weights and kv cache in NPU memory."""
|
| 327 |
+
if not self.config.free_cache_engine:
|
| 328 |
+
return
|
| 329 |
+
|
| 330 |
+
self.free_cache_engine()
|
| 331 |
+
self.offload_model_weights()
|
| 332 |
+
|
| 333 |
+
if hasattr(self.model.model.layers[0].self_attn, "mla_attn"):
|
| 334 |
+
self._process_mla()
|
| 335 |
+
|
| 336 |
+
async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
|
| 337 |
+
"""Update the weights of the rollout model.
|
| 338 |
+
|
| 339 |
+
Args:
|
| 340 |
+
weights: A generator that yields the name of the weight tensor and the tensor itself.
|
| 341 |
+
"""
|
| 342 |
+
await super().update_weights(weights, **kwargs)
|
| 343 |
+
|
| 344 |
+
if hasattr(self.model.model.layers[0].self_attn, "mla_attn"):
|
| 345 |
+
self._process_mla(load_weight=True)
|
| 346 |
+
|
| 347 |
+
# NPU-ADAPTATION END
|
ICL/DAPO/verl-recipe/rep_exp/README.md
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# Representation-Based Exploration for Language Models: <br> From Test-Time to Post-Training
|
| 4 |
+
|
| 5 |
+
[📄 arXiv](https://arxiv.org/abs/2510.11686) [🌐 Website](https://rep-exp.github.io) [🐦 Twitter / X ](https://x.com/JensTuyls/status/1978244454617128993)
|
| 6 |
+
|
| 7 |
+
</div>
|
| 8 |
+
|
| 9 |
+
## Installation 🔌
|
| 10 |
+
|
| 11 |
+
Install the following commit of verl:
|
| 12 |
+
```
|
| 13 |
+
pip install verl@git+https://github.com/volcengine/verl.git@b9bd00efba253ea90072555c45692054cf703de2
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
The only other package to install is scikit-learn, which we'll use for applying a sparse projection.
|
| 17 |
+
```bash
|
| 18 |
+
pip install scikit-learn
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
## Running the Experiments 🚀
|
| 22 |
+
|
| 23 |
+
You can reproduce or extend our experiments by running the following commands:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
# General format
|
| 27 |
+
sh rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
|
| 28 |
+
|
| 29 |
+
# MATH
|
| 30 |
+
sh rep_exp/train_elliptical.sh math 32 0.01 42
|
| 31 |
+
|
| 32 |
+
# GSM8K
|
| 33 |
+
sh rep_exp/train_elliptical.sh gsm8k 32 0.01 42
|
| 34 |
+
|
| 35 |
+
# DAPO-WITH-AIME
|
| 36 |
+
sh rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42
|
| 37 |
+
```
|
| 38 |
+
where `$TASK` is the task name, `$SPARSE_DIM` is the sparse dimension, `$BETA` is the beta parameter, and `$SEED` is the seed.
|
| 39 |
+
|
| 40 |
+
## Evaluation 📊
|
| 41 |
+
Once done training, you can evaluate the model on the test set by following two steps.
|
| 42 |
+
1. Merge the model checkpoint.
|
| 43 |
+
|
| 44 |
+
This is necessary because the model checkpoint is saved in multiple shards (depending on the nubmer of GPUs), and we need to merge them into a single checkpoint.
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
sh rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
2. Evaluate the merged model.
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
sh rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf #where X is the global step of the checkpoint with the best pass@1 on dev
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
The results should be in a folder named `eval` and saved as a JSON file.
|
| 57 |
+
|
| 58 |
+
## Citation 📝
|
| 59 |
+
|
| 60 |
+
```bibtex
|
| 61 |
+
@article{tuyls2025representation,
|
| 62 |
+
title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training},
|
| 63 |
+
author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T},
|
| 64 |
+
journal={arXiv preprint arXiv:2510.11686},
|
| 65 |
+
year={2025}
|
| 66 |
+
}
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Contact 📬
|
| 70 |
+
|
| 71 |
+
If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu).
|
ICL/DAPO/verl-recipe/rep_exp/eval.sh
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK=${1} # math, gsm8k, dapo-with-aime24
|
| 2 |
+
|
| 3 |
+
# Custom model path for evaluation after training
|
| 4 |
+
MODEL_PATH=${2} # /path/to/global_step_X/actor/hf, where X is the global step of the checkpoint with the best pass@1 on dev
|
| 5 |
+
|
| 6 |
+
# If you want to evaluate the base model before training
|
| 7 |
+
# MODEL_PATH=Qwen/Qwen2.5-7B-Instruct
|
| 8 |
+
|
| 9 |
+
train_path=$HOME/data/${TASK}/train.parquet
|
| 10 |
+
train_files="['$train_path']"
|
| 11 |
+
CHECKPOINT_SAVE_CONTENTS='["model"]'
|
| 12 |
+
|
| 13 |
+
if [ ${TASK} == "dapo-with-aime24" ]; then
|
| 14 |
+
MAX_PROMPT_LENGTH=$((1024 * 2))
|
| 15 |
+
MAX_RESPONSE_LENGTH=$((1024 * 8))
|
| 16 |
+
MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
|
| 17 |
+
test_path=$HOME/data/${TASK}/dev.parquet
|
| 18 |
+
else
|
| 19 |
+
MAX_PROMPT_LENGTH=1024
|
| 20 |
+
MAX_RESPONSE_LENGTH=1024
|
| 21 |
+
MAX_NUM_BATCHED_TOKENS=8192
|
| 22 |
+
test_path=$HOME/data/${TASK}/test.parquet
|
| 23 |
+
fi
|
| 24 |
+
|
| 25 |
+
test_files="['$test_path']"
|
| 26 |
+
|
| 27 |
+
# If you're on a cluster with no internet access, set to OFFLINE=True
|
| 28 |
+
OFFLINE=False
|
| 29 |
+
|
| 30 |
+
PYTHONUNBUFFERED=1 WANDB_MODE=disabled TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m rep_exp.main_rep_exp \
|
| 31 |
+
algorithm.adv_estimator=grpo \
|
| 32 |
+
data.train_files="$train_files" \
|
| 33 |
+
data.val_files="$test_files" \
|
| 34 |
+
data.train_batch_size=1024 \
|
| 35 |
+
data.max_prompt_length=$MAX_PROMPT_LENGTH \
|
| 36 |
+
data.max_response_length=$MAX_RESPONSE_LENGTH \
|
| 37 |
+
data.filter_overlong_prompts=True \
|
| 38 |
+
data.truncation='error' \
|
| 39 |
+
data.val_batch_size=128 \
|
| 40 |
+
actor_rollout_ref.model.path="$MODEL_PATH" \
|
| 41 |
+
actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_SAVE_CONTENTS \
|
| 42 |
+
actor_rollout_ref.actor.optim.lr=1e-6 \
|
| 43 |
+
actor_rollout_ref.model.use_remove_padding=True \
|
| 44 |
+
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
|
| 45 |
+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
|
| 46 |
+
actor_rollout_ref.actor.kl_loss_coef=0.0 \
|
| 47 |
+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
| 48 |
+
actor_rollout_ref.actor.entropy_coeff=0 \
|
| 49 |
+
actor_rollout_ref.actor.ppo_epochs=1 \
|
| 50 |
+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
| 51 |
+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
| 52 |
+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
| 53 |
+
actor_rollout_ref.rollout.mode=sync \
|
| 54 |
+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
|
| 55 |
+
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
| 56 |
+
actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \
|
| 57 |
+
actor_rollout_ref.rollout.name=vllm \
|
| 58 |
+
actor_rollout_ref.rollout.gpu_memory_utilization=0.45 \
|
| 59 |
+
actor_rollout_ref.rollout.val_kwargs.n=256 \
|
| 60 |
+
actor_rollout_ref.rollout.n=8 \
|
| 61 |
+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
|
| 62 |
+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
|
| 63 |
+
reward_model.model.path="$MODEL_PATH" \
|
| 64 |
+
reward_model.model.use_remove_padding=False \
|
| 65 |
+
reward_model.model.fsdp_config.param_offload=True \
|
| 66 |
+
reward_model.micro_batch_size_per_gpu=32 \
|
| 67 |
+
reward_model.model.input_tokenizer=null \
|
| 68 |
+
actor_rollout_ref.actor.use_kl_loss=False \
|
| 69 |
+
algorithm.use_kl_in_reward=False \
|
| 70 |
+
trainer.critic_warmup=0 \
|
| 71 |
+
trainer.logger='["console","json_eval"]' \
|
| 72 |
+
trainer.project_name='rep-exp' \
|
| 73 |
+
trainer.experiment_name="${TASK}_eval" \
|
| 74 |
+
trainer.n_gpus_per_node=1 \
|
| 75 |
+
trainer.nnodes=1 \
|
| 76 |
+
trainer.save_freq=-1 \
|
| 77 |
+
trainer.test_freq=1 \
|
| 78 |
+
trainer.total_epochs=100 \
|
| 79 |
+
trainer.val_only=True \
|
| 80 |
+
trainer.resume_mode=disable \
|
| 81 |
+
trainer.resume_from_path=''
|
| 82 |
+
|
| 83 |
+
exit 0
|
ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import socket
|
| 20 |
+
import warnings
|
| 21 |
+
|
| 22 |
+
import hydra
|
| 23 |
+
import ray
|
| 24 |
+
from omegaconf import OmegaConf
|
| 25 |
+
|
| 26 |
+
from verl.experimental.dataset.sampler import AbstractSampler
|
| 27 |
+
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
|
| 28 |
+
from verl.trainer.ppo.reward import load_reward_manager
|
| 29 |
+
from verl.trainer.ppo.utils import need_critic, need_reference_policy
|
| 30 |
+
from verl.utils.config import validate_config
|
| 31 |
+
from verl.utils.device import is_cuda_available
|
| 32 |
+
from verl.utils.import_utils import load_extern_type
|
| 33 |
+
|
| 34 |
+
from .rep_exp_trainer import RayRepExpTrainer
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@hydra.main(config_path="config", config_name="rep_exp_trainer", version_base=None)
|
| 38 |
+
def main(config):
|
| 39 |
+
"""Main entry point for PPO training with Hydra configuration management.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
config_dict: Hydra configuration dictionary containing training parameters.
|
| 43 |
+
"""
|
| 44 |
+
run_ppo(config)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Define a function to run the PPO-like training process
|
| 48 |
+
def run_ppo(config, task_runner_class=None) -> None:
|
| 49 |
+
"""Initialize Ray cluster and run distributed PPO training process.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
config: Training configuration object containing all necessary parameters
|
| 53 |
+
for distributed PPO training including Ray initialization settings,
|
| 54 |
+
model paths, and training hyperparameters.
|
| 55 |
+
task_runner_class: For recipe to change TaskRunner.
|
| 56 |
+
"""
|
| 57 |
+
# Check if Ray is not initialized
|
| 58 |
+
if not ray.is_initialized():
|
| 59 |
+
# Initialize Ray with a local cluster configuration
|
| 60 |
+
# Set environment variables in the runtime environment to control tokenizer parallelism,
|
| 61 |
+
# NCCL debug level, VLLM logging level, and allow runtime LoRA updating
|
| 62 |
+
# `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
|
| 63 |
+
default_runtime_env = get_ppo_ray_runtime_env()
|
| 64 |
+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
|
| 65 |
+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
|
| 66 |
+
|
| 67 |
+
if config.transfer_queue.enable:
|
| 68 |
+
# Add runtime environment variables for transfer queue
|
| 69 |
+
runtime_env_vars = runtime_env_kwargs.get("env_vars", {})
|
| 70 |
+
runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1"
|
| 71 |
+
runtime_env_kwargs["env_vars"] = runtime_env_vars
|
| 72 |
+
|
| 73 |
+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
|
| 74 |
+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
|
| 75 |
+
print(f"ray init kwargs: {ray_init_kwargs}")
|
| 76 |
+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
|
| 77 |
+
|
| 78 |
+
if task_runner_class is None:
|
| 79 |
+
task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head
|
| 80 |
+
|
| 81 |
+
# Create a remote instance of the TaskRunner class, and
|
| 82 |
+
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
|
| 83 |
+
if (
|
| 84 |
+
is_cuda_available
|
| 85 |
+
and config.global_profiler.tool == "nsys"
|
| 86 |
+
and config.global_profiler.get("steps") is not None
|
| 87 |
+
and len(config.global_profiler.get("steps", [])) > 0
|
| 88 |
+
):
|
| 89 |
+
from verl.utils.import_utils import is_nvtx_available
|
| 90 |
+
|
| 91 |
+
assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
|
| 92 |
+
nsight_options = OmegaConf.to_container(
|
| 93 |
+
config.global_profiler.global_tool_config.nsys.controller_nsight_options
|
| 94 |
+
)
|
| 95 |
+
runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
|
| 96 |
+
else:
|
| 97 |
+
runner = task_runner_class.remote()
|
| 98 |
+
ray.get(runner.run.remote(config))
|
| 99 |
+
|
| 100 |
+
# [Optional] get the path of the timeline trace file from the configuration, default to None
|
| 101 |
+
# This file is used for performance analysis
|
| 102 |
+
timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
|
| 103 |
+
if timeline_json_file:
|
| 104 |
+
ray.timeline(filename=timeline_json_file)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class TaskRunner:
|
| 108 |
+
"""Ray remote class for executing distributed PPO training tasks.
|
| 109 |
+
|
| 110 |
+
This class encapsulates the main training logic and runs as a Ray remote actor
|
| 111 |
+
to enable distributed execution across multiple nodes and GPUs.
|
| 112 |
+
|
| 113 |
+
Attributes:
|
| 114 |
+
role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
|
| 115 |
+
mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
def __init__(self):
|
| 119 |
+
self.role_worker_mapping = {}
|
| 120 |
+
self.mapping = {}
|
| 121 |
+
|
| 122 |
+
def add_actor_rollout_worker(self, config):
|
| 123 |
+
"""Add actor rollout worker based on the actor strategy."""
|
| 124 |
+
from verl.single_controller.ray import RayWorkerGroup
|
| 125 |
+
from verl.trainer.ppo.ray_trainer import Role
|
| 126 |
+
|
| 127 |
+
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
|
| 128 |
+
|
| 129 |
+
# use new model engine implementation
|
| 130 |
+
if use_legacy_worker_impl == "disable":
|
| 131 |
+
from verl.workers.engine_workers import ActorRolloutRefWorker
|
| 132 |
+
|
| 133 |
+
actor_rollout_cls = ActorRolloutRefWorker
|
| 134 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 135 |
+
# NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker,
|
| 136 |
+
# while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker.
|
| 137 |
+
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
|
| 138 |
+
role = Role.ActorRolloutRef
|
| 139 |
+
else:
|
| 140 |
+
role = Role.ActorRollout
|
| 141 |
+
self.role_worker_mapping[role] = ray.remote(actor_rollout_cls)
|
| 142 |
+
self.mapping[role] = "global_pool"
|
| 143 |
+
return actor_rollout_cls, ray_worker_group_cls
|
| 144 |
+
|
| 145 |
+
if config.actor_rollout_ref.rollout.mode == "sync":
|
| 146 |
+
warnings.warn("spmd rollout mode is deprecated and will be removed in v0.6.2", stacklevel=2)
|
| 147 |
+
|
| 148 |
+
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
|
| 149 |
+
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
|
| 150 |
+
|
| 151 |
+
actor_rollout_cls = (
|
| 152 |
+
AsyncActorRolloutRefWorker
|
| 153 |
+
if config.actor_rollout_ref.rollout.mode == "async"
|
| 154 |
+
else ActorRolloutRefWorker
|
| 155 |
+
)
|
| 156 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 157 |
+
|
| 158 |
+
elif config.actor_rollout_ref.actor.strategy == "megatron":
|
| 159 |
+
from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
|
| 160 |
+
|
| 161 |
+
actor_rollout_cls = (
|
| 162 |
+
AsyncActorRolloutRefWorker
|
| 163 |
+
if config.actor_rollout_ref.rollout.mode == "async"
|
| 164 |
+
else ActorRolloutRefWorker
|
| 165 |
+
)
|
| 166 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 167 |
+
|
| 168 |
+
else:
|
| 169 |
+
raise NotImplementedError
|
| 170 |
+
|
| 171 |
+
self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
|
| 172 |
+
self.mapping[Role.ActorRollout] = "global_pool"
|
| 173 |
+
return actor_rollout_cls, ray_worker_group_cls
|
| 174 |
+
|
| 175 |
+
def add_critic_worker(self, config):
|
| 176 |
+
"""Add critic worker to role mapping."""
|
| 177 |
+
if config.critic.strategy in {"fsdp", "fsdp2"}:
|
| 178 |
+
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
|
| 179 |
+
if use_legacy_worker_impl in ["auto", "enable"]:
|
| 180 |
+
from verl.workers.fsdp_workers import CriticWorker
|
| 181 |
+
elif use_legacy_worker_impl == "disable":
|
| 182 |
+
from verl.workers.roles import CriticWorker
|
| 183 |
+
|
| 184 |
+
print("Using new worker implementation")
|
| 185 |
+
else:
|
| 186 |
+
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
|
| 187 |
+
|
| 188 |
+
elif config.critic.strategy == "megatron":
|
| 189 |
+
from verl.workers.megatron_workers import CriticWorker
|
| 190 |
+
|
| 191 |
+
else:
|
| 192 |
+
raise NotImplementedError
|
| 193 |
+
|
| 194 |
+
from verl.trainer.ppo.ray_trainer import Role
|
| 195 |
+
|
| 196 |
+
self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
|
| 197 |
+
self.mapping[Role.Critic] = "global_pool"
|
| 198 |
+
|
| 199 |
+
def init_resource_pool_mgr(self, config):
|
| 200 |
+
"""Initialize resource pool manager."""
|
| 201 |
+
|
| 202 |
+
global_pool_id = "global_pool"
|
| 203 |
+
resource_pool_spec = {
|
| 204 |
+
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
|
| 205 |
+
}
|
| 206 |
+
# TODO Here you can use the new registration method to support dynamic registration of roles
|
| 207 |
+
if config.reward_model.enable_resource_pool:
|
| 208 |
+
if config.reward_model.n_gpus_per_node <= 0:
|
| 209 |
+
raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
|
| 210 |
+
if config.reward_model.nnodes <= 0:
|
| 211 |
+
raise ValueError("config.reward_model.nnodes must be greater than 0")
|
| 212 |
+
|
| 213 |
+
reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
|
| 214 |
+
resource_pool_spec["reward_pool"] = reward_pool
|
| 215 |
+
|
| 216 |
+
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
|
| 217 |
+
|
| 218 |
+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
|
| 219 |
+
return resource_pool_manager
|
| 220 |
+
|
| 221 |
+
def add_reward_model_worker(self, config):
|
| 222 |
+
"""Add reward model worker if enabled."""
|
| 223 |
+
from verl.trainer.ppo.ray_trainer import Role
|
| 224 |
+
|
| 225 |
+
if config.reward_model.enable:
|
| 226 |
+
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
|
| 227 |
+
if use_legacy_worker_impl in ["auto", "enable"]:
|
| 228 |
+
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
|
| 229 |
+
if config.reward_model.elliptical:
|
| 230 |
+
from .workers.elliptical_reward_model_worker import (
|
| 231 |
+
EllipticalRewardModelWorker as RewardModelWorker,
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
from verl.workers.fsdp_workers import RewardModelWorker
|
| 235 |
+
elif config.reward_model.strategy == "megatron":
|
| 236 |
+
from verl.workers.megatron_workers import RewardModelWorker
|
| 237 |
+
else:
|
| 238 |
+
raise NotImplementedError
|
| 239 |
+
elif use_legacy_worker_impl == "disable":
|
| 240 |
+
from verl.workers.roles import RewardModelWorker
|
| 241 |
+
|
| 242 |
+
print("Using new worker implementation")
|
| 243 |
+
else:
|
| 244 |
+
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
|
| 245 |
+
|
| 246 |
+
self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
|
| 247 |
+
if config.reward_model.enable_resource_pool:
|
| 248 |
+
self.mapping[Role.RewardModel] = "reward_pool"
|
| 249 |
+
else:
|
| 250 |
+
self.mapping[Role.RewardModel] = "global_pool"
|
| 251 |
+
|
| 252 |
+
def add_ref_policy_worker(self, config, ref_policy_cls):
|
| 253 |
+
"""Add reference policy worker if KL loss or KL reward is used."""
|
| 254 |
+
from verl.trainer.ppo.ray_trainer import Role
|
| 255 |
+
|
| 256 |
+
# Ref policy has been fused into ActorRolloutRefWorker in new model engine,
|
| 257 |
+
# we don't need to add a separate ref policy worker goup.
|
| 258 |
+
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
|
| 259 |
+
if use_legacy_worker_impl == "disable":
|
| 260 |
+
return
|
| 261 |
+
|
| 262 |
+
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
|
| 263 |
+
self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
|
| 264 |
+
self.mapping[Role.RefPolicy] = "global_pool"
|
| 265 |
+
|
| 266 |
+
def run(self, config):
|
| 267 |
+
"""Execute the main PPO training workflow.
|
| 268 |
+
|
| 269 |
+
This method sets up the distributed training environment, initializes
|
| 270 |
+
workers, datasets, and reward functions, then starts the training process.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
config: Training configuration object containing all parameters needed
|
| 274 |
+
for setting up and running the PPO training process.
|
| 275 |
+
"""
|
| 276 |
+
# Print the initial configuration. `resolve=True` will evaluate symbolic values.
|
| 277 |
+
from pprint import pprint
|
| 278 |
+
|
| 279 |
+
from omegaconf import OmegaConf
|
| 280 |
+
|
| 281 |
+
from verl.utils.fs import copy_to_local
|
| 282 |
+
|
| 283 |
+
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
|
| 284 |
+
pprint(OmegaConf.to_container(config, resolve=True))
|
| 285 |
+
OmegaConf.resolve(config)
|
| 286 |
+
|
| 287 |
+
actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
|
| 288 |
+
self.add_critic_worker(config)
|
| 289 |
+
|
| 290 |
+
# We should adopt a multi-source reward function here:
|
| 291 |
+
# - for rule-based rm, we directly call a reward score
|
| 292 |
+
# - for model-based rm, we call a model
|
| 293 |
+
# - for code related prompt, we send to a sandbox if there are test cases
|
| 294 |
+
# finally, we combine all the rewards together
|
| 295 |
+
# The reward type depends on the tag of the data
|
| 296 |
+
self.add_reward_model_worker(config)
|
| 297 |
+
|
| 298 |
+
# Add a reference policy worker if KL loss or KL reward is used.
|
| 299 |
+
self.add_ref_policy_worker(config, actor_rollout_cls)
|
| 300 |
+
|
| 301 |
+
# validate config
|
| 302 |
+
validate_config(
|
| 303 |
+
config=config,
|
| 304 |
+
use_reference_policy=need_reference_policy(self.role_worker_mapping),
|
| 305 |
+
use_critic=need_critic(config),
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Download the checkpoint from HDFS to the local machine.
|
| 309 |
+
# `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
|
| 310 |
+
local_path = copy_to_local(
|
| 311 |
+
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Instantiate the tokenizer and processor.
|
| 315 |
+
from verl.utils import hf_processor, hf_tokenizer
|
| 316 |
+
|
| 317 |
+
trust_remote_code = config.data.get("trust_remote_code", False)
|
| 318 |
+
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
|
| 319 |
+
# Used for multimodal LLM, could be None
|
| 320 |
+
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
|
| 321 |
+
|
| 322 |
+
# Make sure the elliptical reward manager is registered
|
| 323 |
+
from .reward_manager.elliptical_reward_manager import EllipticalRewardManager # noqa: F401
|
| 324 |
+
|
| 325 |
+
# Load the reward manager for training and validation.
|
| 326 |
+
reward_manager_name = config.reward_model.get("reward_manager", "naive")
|
| 327 |
+
reward_fn = load_reward_manager(
|
| 328 |
+
config,
|
| 329 |
+
tokenizer,
|
| 330 |
+
num_examine=0,
|
| 331 |
+
**config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}),
|
| 332 |
+
)
|
| 333 |
+
val_reward_fn = load_reward_manager(
|
| 334 |
+
config,
|
| 335 |
+
tokenizer,
|
| 336 |
+
num_examine=1,
|
| 337 |
+
**config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}),
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
resource_pool_manager = self.init_resource_pool_mgr(config)
|
| 341 |
+
|
| 342 |
+
from verl.utils.dataset.rl_dataset import collate_fn
|
| 343 |
+
|
| 344 |
+
# Create training and validation datasets.
|
| 345 |
+
train_dataset = create_rl_dataset(
|
| 346 |
+
config.data.train_files,
|
| 347 |
+
config.data,
|
| 348 |
+
tokenizer,
|
| 349 |
+
processor,
|
| 350 |
+
is_train=True,
|
| 351 |
+
max_samples=config.data.get("train_max_samples", -1),
|
| 352 |
+
)
|
| 353 |
+
val_dataset = create_rl_dataset(
|
| 354 |
+
config.data.val_files,
|
| 355 |
+
config.data,
|
| 356 |
+
tokenizer,
|
| 357 |
+
processor,
|
| 358 |
+
is_train=False,
|
| 359 |
+
max_samples=config.data.get("val_max_samples", -1),
|
| 360 |
+
)
|
| 361 |
+
train_sampler = create_rl_sampler(config.data, train_dataset)
|
| 362 |
+
|
| 363 |
+
# Initialize the PPO trainer.
|
| 364 |
+
trainer = RayRepExpTrainer(
|
| 365 |
+
config=config,
|
| 366 |
+
tokenizer=tokenizer,
|
| 367 |
+
processor=processor,
|
| 368 |
+
role_worker_mapping=self.role_worker_mapping,
|
| 369 |
+
resource_pool_manager=resource_pool_manager,
|
| 370 |
+
ray_worker_group_cls=ray_worker_group_cls,
|
| 371 |
+
reward_fn=reward_fn,
|
| 372 |
+
val_reward_fn=val_reward_fn,
|
| 373 |
+
train_dataset=train_dataset,
|
| 374 |
+
val_dataset=val_dataset,
|
| 375 |
+
collate_fn=collate_fn,
|
| 376 |
+
train_sampler=train_sampler,
|
| 377 |
+
)
|
| 378 |
+
# Initialize the workers of the trainer.
|
| 379 |
+
trainer.init_workers()
|
| 380 |
+
|
| 381 |
+
# Start the training process.
|
| 382 |
+
trainer.fit()
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1):
|
| 386 |
+
"""Create a dataset.
|
| 387 |
+
|
| 388 |
+
Arguments:
|
| 389 |
+
data_paths: List of paths to data files.
|
| 390 |
+
data_config: The data config.
|
| 391 |
+
tokenizer (Tokenizer): The tokenizer.
|
| 392 |
+
processor (Processor): The processor.
|
| 393 |
+
|
| 394 |
+
Returns:
|
| 395 |
+
dataset (Dataset): The dataset.
|
| 396 |
+
"""
|
| 397 |
+
from torch.utils.data import Dataset
|
| 398 |
+
|
| 399 |
+
from verl.utils.dataset.rl_dataset import RLHFDataset
|
| 400 |
+
|
| 401 |
+
# Check if a custom dataset class is specified in the data configuration
|
| 402 |
+
# and if the path to the custom class is provided
|
| 403 |
+
if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
|
| 404 |
+
# Dynamically load the custom dataset class
|
| 405 |
+
dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
|
| 406 |
+
# Verify that the custom dataset class inherits from torch.utils.data.Dataset
|
| 407 |
+
if not issubclass(dataset_cls, Dataset):
|
| 408 |
+
raise TypeError(
|
| 409 |
+
f"The custom dataset class '{data_config.custom_cls.name}' from "
|
| 410 |
+
f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
|
| 411 |
+
)
|
| 412 |
+
elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train:
|
| 413 |
+
# If a data generation strategy is specified, use the DynamicGenDataset class
|
| 414 |
+
from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset
|
| 415 |
+
|
| 416 |
+
dataset_cls = DynamicGenDataset
|
| 417 |
+
print("Using DynamicGenDataset for data generation.")
|
| 418 |
+
else:
|
| 419 |
+
# Use the default RLHFDataset class if no custom class is specified
|
| 420 |
+
dataset_cls = RLHFDataset
|
| 421 |
+
print(f"Using dataset class: {dataset_cls.__name__}")
|
| 422 |
+
|
| 423 |
+
# Instantiate the dataset using the determined dataset class
|
| 424 |
+
dataset = dataset_cls(
|
| 425 |
+
data_files=data_paths,
|
| 426 |
+
tokenizer=tokenizer,
|
| 427 |
+
processor=processor,
|
| 428 |
+
config=data_config,
|
| 429 |
+
max_samples=max_samples,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
return dataset
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def create_rl_sampler(data_config, dataset):
|
| 436 |
+
"""Create a sampler for the dataset.
|
| 437 |
+
|
| 438 |
+
Arguments:
|
| 439 |
+
data_config: The data config.
|
| 440 |
+
dataset (Dataset): The dataset.
|
| 441 |
+
|
| 442 |
+
Returns:
|
| 443 |
+
sampler (Sampler): The sampler.
|
| 444 |
+
"""
|
| 445 |
+
import torch
|
| 446 |
+
from torch.utils.data import SequentialSampler
|
| 447 |
+
|
| 448 |
+
# torch.utils.data.RandomSampler could not recover properly
|
| 449 |
+
from torchdata.stateful_dataloader.sampler import RandomSampler
|
| 450 |
+
|
| 451 |
+
if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
|
| 452 |
+
curriculum_class = load_extern_type(
|
| 453 |
+
data_config.sampler.class_path,
|
| 454 |
+
data_config.sampler.class_name,
|
| 455 |
+
)
|
| 456 |
+
sampler = curriculum_class(
|
| 457 |
+
data_source=dataset,
|
| 458 |
+
data_config=data_config,
|
| 459 |
+
)
|
| 460 |
+
assert isinstance(sampler, AbstractSampler)
|
| 461 |
+
assert data_config.get("dataloader_num_workers", 8) == 0, (
|
| 462 |
+
"If using curriculum, num_workers must be 0 to prevent data caching. "
|
| 463 |
+
"If the dataloader caches data before the batch is done the "
|
| 464 |
+
"curriculum sampler won't have the opportunity to reorder it. "
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# Use a sampler to facilitate checkpoint resumption.
|
| 468 |
+
# If shuffling is enabled in the data configuration, create a random sampler.
|
| 469 |
+
elif data_config.shuffle:
|
| 470 |
+
train_dataloader_generator = torch.Generator()
|
| 471 |
+
seed = data_config.get("seed")
|
| 472 |
+
if seed is not None:
|
| 473 |
+
train_dataloader_generator.manual_seed(seed)
|
| 474 |
+
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
|
| 475 |
+
else:
|
| 476 |
+
# If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
|
| 477 |
+
sampler = SequentialSampler(data_source=dataset)
|
| 478 |
+
|
| 479 |
+
return sampler
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
if __name__ == "__main__":
|
| 483 |
+
main()
|
ICL/DAPO/verl-recipe/rep_exp/metric_utils.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Metrics related to the RepExp trainer.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from functools import partial
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from verl import DataProto
|
| 26 |
+
from verl.trainer.ppo.metric_utils import _compute_response_info, bootstrap_metric, calc_maj_val
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _compute_three_case_stats(data: DataProto, extrinsic_reward_tensor: torch.Tensor) -> dict:
|
| 30 |
+
"""
|
| 31 |
+
Compute the fraction of samples that have no rollouts correct, some rollouts correct, and all rollouts correct.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
data (DataProto): The data proto containing the batch data.
|
| 35 |
+
extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
dict[str, float]: A dictionary containing the fraction of samples that have no rollouts correct,
|
| 39 |
+
some rollouts correct, and all rollouts correct.
|
| 40 |
+
"""
|
| 41 |
+
no_rollouts_correct = 0
|
| 42 |
+
some_rollouts_correct = 0
|
| 43 |
+
all_rollouts_correct = 0
|
| 44 |
+
|
| 45 |
+
visited_uids = set()
|
| 46 |
+
for uid in data.non_tensor_batch["uid"]:
|
| 47 |
+
if uid in visited_uids:
|
| 48 |
+
continue
|
| 49 |
+
|
| 50 |
+
visited_uids.add(uid)
|
| 51 |
+
mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)
|
| 52 |
+
|
| 53 |
+
# Split into three cases
|
| 54 |
+
if extrinsic_reward_tensor[mask].sum() == 0:
|
| 55 |
+
no_rollouts_correct += 1
|
| 56 |
+
elif extrinsic_reward_tensor[mask].sum() == mask.sum():
|
| 57 |
+
all_rollouts_correct += 1
|
| 58 |
+
elif extrinsic_reward_tensor[mask].sum() > 0 and extrinsic_reward_tensor[mask].sum() < mask.sum():
|
| 59 |
+
some_rollouts_correct += 1
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(f"Invalid extrinsic reward tensor: {extrinsic_reward_tensor[mask].sum()}")
|
| 62 |
+
|
| 63 |
+
# Sanity checks
|
| 64 |
+
assert len(visited_uids) == no_rollouts_correct + some_rollouts_correct + all_rollouts_correct
|
| 65 |
+
|
| 66 |
+
return {
|
| 67 |
+
"no_rollouts_correct_frac": no_rollouts_correct / len(visited_uids),
|
| 68 |
+
"some_rollouts_correct_frac": some_rollouts_correct / len(visited_uids),
|
| 69 |
+
"all_rollouts_correct_frac": all_rollouts_correct / len(visited_uids),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def compute_data_metrics(batch: DataProto, use_critic: bool = True, elliptical: bool = False) -> dict[str, Any]:
|
| 74 |
+
"""
|
| 75 |
+
Computes various metrics from a batch of data for PPO training.
|
| 76 |
+
|
| 77 |
+
This function calculates metrics related to scores, rewards, advantages, returns, values,
|
| 78 |
+
and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
|
| 79 |
+
for each metric category.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
|
| 83 |
+
use_critic: Whether to include critic-specific metrics. Defaults to True.
|
| 84 |
+
elliptical: Whether to include elliptical-specific metrics. Defaults to False.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
A dictionary of metrics including:
|
| 88 |
+
- critic/score/mean, max, min: Statistics about sequence scores
|
| 89 |
+
- critic/rewards/mean, max, min: Statistics about sequence rewards
|
| 90 |
+
- critic/advantages/mean, max, min: Statistics about advantages
|
| 91 |
+
- critic/returns/mean, max, min: Statistics about returns
|
| 92 |
+
- critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
|
| 93 |
+
- critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
|
| 94 |
+
- response_length/mean, max, min, clip_ratio: Statistics about response lengths
|
| 95 |
+
- prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
|
| 96 |
+
- num_turns/mean, max, min: Statistics about the number of multi-turn conversations
|
| 97 |
+
"""
|
| 98 |
+
sequence_score = batch.batch["token_level_scores"].sum(-1)
|
| 99 |
+
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
|
| 100 |
+
|
| 101 |
+
if elliptical:
|
| 102 |
+
sequence_intrinsic_reward = batch.non_tensor_batch["intrinsic_reward"].sum(-1)
|
| 103 |
+
sequence_beta_scaled_intrinsic_reward = batch.non_tensor_batch["beta_scaled_intrinsic_reward"].sum(-1)
|
| 104 |
+
sequence_extrinsic_reward = batch.non_tensor_batch["extrinsic_reward"].sum(-1)
|
| 105 |
+
sequence_total_reward = batch.non_tensor_batch["total_reward"].sum(-1)
|
| 106 |
+
sequence_raw_bonuses = batch.non_tensor_batch["raw_bonuses"].sum(-1)
|
| 107 |
+
|
| 108 |
+
three_case_stats = _compute_three_case_stats(batch, batch.non_tensor_batch["extrinsic_reward"])
|
| 109 |
+
|
| 110 |
+
advantages = batch.batch["advantages"]
|
| 111 |
+
returns = batch.batch["returns"]
|
| 112 |
+
|
| 113 |
+
max_response_length = batch.batch["responses"].shape[-1]
|
| 114 |
+
|
| 115 |
+
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
|
| 116 |
+
response_mask = batch.batch["response_mask"].bool()
|
| 117 |
+
|
| 118 |
+
max_prompt_length = prompt_mask.size(-1)
|
| 119 |
+
|
| 120 |
+
response_info = _compute_response_info(batch)
|
| 121 |
+
prompt_length = response_info["prompt_length"]
|
| 122 |
+
response_length = response_info["response_length"]
|
| 123 |
+
|
| 124 |
+
aborted_mask = (response_length == 0).bool()
|
| 125 |
+
non_aborted_mask = ~aborted_mask
|
| 126 |
+
|
| 127 |
+
non_aborted_sequence_score = sequence_score[non_aborted_mask]
|
| 128 |
+
non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
|
| 129 |
+
|
| 130 |
+
score_mean = torch.mean(non_aborted_sequence_score).detach().item()
|
| 131 |
+
score_max = torch.max(non_aborted_sequence_score).detach().item()
|
| 132 |
+
score_min = torch.min(non_aborted_sequence_score).detach().item()
|
| 133 |
+
|
| 134 |
+
reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
|
| 135 |
+
reward_max = torch.max(non_aborted_sequence_reward).detach().item()
|
| 136 |
+
reward_min = torch.min(non_aborted_sequence_reward).detach().item()
|
| 137 |
+
|
| 138 |
+
valid_adv = torch.masked_select(advantages, response_mask)
|
| 139 |
+
valid_returns = torch.masked_select(returns, response_mask)
|
| 140 |
+
|
| 141 |
+
if use_critic:
|
| 142 |
+
values = batch.batch["values"]
|
| 143 |
+
valid_values = torch.masked_select(values, response_mask)
|
| 144 |
+
return_diff_var = torch.var(valid_returns - valid_values)
|
| 145 |
+
return_var = torch.var(valid_returns)
|
| 146 |
+
|
| 147 |
+
# Aborted samples and non-aborted response length statistics
|
| 148 |
+
# response_length_non_aborted/*: statistics computed on non-aborted samples only
|
| 149 |
+
aborted_ratio = torch.mean(aborted_mask.float()).detach().item()
|
| 150 |
+
|
| 151 |
+
non_aborted_response_length = response_length[non_aborted_mask]
|
| 152 |
+
if non_aborted_response_length.numel() > 0:
|
| 153 |
+
non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()
|
| 154 |
+
non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()
|
| 155 |
+
non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()
|
| 156 |
+
non_aborted_response_length_clip_ratio = (
|
| 157 |
+
torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
|
| 158 |
+
)
|
| 159 |
+
else:
|
| 160 |
+
raise ValueError("All samples are aborted, this should not happen.")
|
| 161 |
+
|
| 162 |
+
metrics = {
|
| 163 |
+
# score
|
| 164 |
+
"critic/score/mean": score_mean,
|
| 165 |
+
"critic/score/max": score_max,
|
| 166 |
+
"critic/score/min": score_min,
|
| 167 |
+
# reward
|
| 168 |
+
"critic/rewards/mean": reward_mean,
|
| 169 |
+
"critic/rewards/max": reward_max,
|
| 170 |
+
"critic/rewards/min": reward_min,
|
| 171 |
+
# adv
|
| 172 |
+
"critic/advantages/mean": torch.mean(valid_adv).detach().item(),
|
| 173 |
+
"critic/advantages/max": torch.max(valid_adv).detach().item(),
|
| 174 |
+
"critic/advantages/min": torch.min(valid_adv).detach().item(),
|
| 175 |
+
# returns
|
| 176 |
+
"critic/returns/mean": torch.mean(valid_returns).detach().item(),
|
| 177 |
+
"critic/returns/max": torch.max(valid_returns).detach().item(),
|
| 178 |
+
"critic/returns/min": torch.min(valid_returns).detach().item(),
|
| 179 |
+
**(
|
| 180 |
+
{
|
| 181 |
+
# values
|
| 182 |
+
"critic/values/mean": torch.mean(valid_values).detach().item(),
|
| 183 |
+
"critic/values/max": torch.max(valid_values).detach().item(),
|
| 184 |
+
"critic/values/min": torch.min(valid_values).detach().item(),
|
| 185 |
+
# vf explained var
|
| 186 |
+
"critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
|
| 187 |
+
}
|
| 188 |
+
if use_critic
|
| 189 |
+
else {}
|
| 190 |
+
),
|
| 191 |
+
**(
|
| 192 |
+
{
|
| 193 |
+
# raw bonuses
|
| 194 |
+
"critic/raw_bonuses/mean": np.mean(sequence_raw_bonuses).item(),
|
| 195 |
+
"critic/raw_bonuses/max": np.max(sequence_raw_bonuses).item(),
|
| 196 |
+
"critic/raw_bonuses/min": np.min(sequence_raw_bonuses).item(),
|
| 197 |
+
"critic/raw_bonuses/std": np.std(sequence_raw_bonuses).item(),
|
| 198 |
+
# intrinsic_reward
|
| 199 |
+
"critic/intrinsic_reward/mean": np.mean(sequence_intrinsic_reward).item(),
|
| 200 |
+
"critic/intrinsic_reward/max": np.max(sequence_intrinsic_reward).item(),
|
| 201 |
+
"critic/intrinsic_reward/min": np.min(sequence_intrinsic_reward).item(),
|
| 202 |
+
"critic/intrinsic_reward/std": np.std(sequence_intrinsic_reward).item(),
|
| 203 |
+
# beta_scaled_intrinsic_reward
|
| 204 |
+
"critic/beta_scaled_intrinsic_reward/mean": np.mean(sequence_beta_scaled_intrinsic_reward).item(),
|
| 205 |
+
"critic/beta_scaled_intrinsic_reward/max": np.max(sequence_beta_scaled_intrinsic_reward).item(),
|
| 206 |
+
"critic/beta_scaled_intrinsic_reward/min": np.min(sequence_beta_scaled_intrinsic_reward).item(),
|
| 207 |
+
"critic/beta_scaled_intrinsic_reward/std": np.std(sequence_beta_scaled_intrinsic_reward).item(),
|
| 208 |
+
# extrinsic_reward
|
| 209 |
+
"critic/extrinsic_reward/mean": np.mean(sequence_extrinsic_reward).item(),
|
| 210 |
+
"critic/extrinsic_reward/max": np.max(sequence_extrinsic_reward).item(),
|
| 211 |
+
"critic/extrinsic_reward/min": np.min(sequence_extrinsic_reward).item(),
|
| 212 |
+
"critic/extrinsic_reward/std": np.std(sequence_extrinsic_reward).item(),
|
| 213 |
+
# three_case_stats
|
| 214 |
+
"critic/extrinsic_reward/no_rollouts_correct_frac": three_case_stats["no_rollouts_correct_frac"],
|
| 215 |
+
"critic/extrinsic_reward/some_rollouts_correct_frac": three_case_stats["some_rollouts_correct_frac"],
|
| 216 |
+
"critic/extrinsic_reward/all_rollouts_correct_frac": three_case_stats["all_rollouts_correct_frac"],
|
| 217 |
+
# total_reward
|
| 218 |
+
"critic/total_reward/mean": np.mean(sequence_total_reward).item(),
|
| 219 |
+
"critic/total_reward/max": np.max(sequence_total_reward).item(),
|
| 220 |
+
"critic/total_reward/min": np.min(sequence_total_reward).item(),
|
| 221 |
+
"critic/total_reward/std": np.std(sequence_total_reward).item(),
|
| 222 |
+
}
|
| 223 |
+
if elliptical
|
| 224 |
+
else {}
|
| 225 |
+
),
|
| 226 |
+
# response length
|
| 227 |
+
"response_length/mean": torch.mean(response_length).detach().item(),
|
| 228 |
+
"response_length/max": torch.max(response_length).detach().item(),
|
| 229 |
+
"response_length/min": torch.min(response_length).detach().item(),
|
| 230 |
+
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
|
| 231 |
+
.detach()
|
| 232 |
+
.item(),
|
| 233 |
+
# response length (non-aborted only)
|
| 234 |
+
# These statistics exclude aborted samples to avoid skew from zeros
|
| 235 |
+
"response_length_non_aborted/mean": non_aborted_response_length_mean,
|
| 236 |
+
"response_length_non_aborted/max": non_aborted_response_length_max,
|
| 237 |
+
"response_length_non_aborted/min": non_aborted_response_length_min,
|
| 238 |
+
"response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio,
|
| 239 |
+
# aborted ratio
|
| 240 |
+
# Fraction of samples whose response length is zero
|
| 241 |
+
"response/aborted_ratio": aborted_ratio,
|
| 242 |
+
# prompt length
|
| 243 |
+
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
|
| 244 |
+
"prompt_length/max": torch.max(prompt_length).detach().item(),
|
| 245 |
+
"prompt_length/min": torch.min(prompt_length).detach().item(),
|
| 246 |
+
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
# multi-turn conversation
|
| 250 |
+
if "__num_turns__" in batch.non_tensor_batch:
|
| 251 |
+
num_turns = batch.non_tensor_batch["__num_turns__"]
|
| 252 |
+
metrics["num_turns/min"] = num_turns.min()
|
| 253 |
+
metrics["num_turns/max"] = num_turns.max()
|
| 254 |
+
metrics["num_turns/mean"] = num_turns.mean()
|
| 255 |
+
|
| 256 |
+
if "tool_call_counts" in batch.non_tensor_batch:
|
| 257 |
+
tool_call_counts = batch.non_tensor_batch["tool_call_counts"]
|
| 258 |
+
metrics["tool_call_counts/min"] = tool_call_counts.min()
|
| 259 |
+
metrics["tool_call_counts/max"] = tool_call_counts.max()
|
| 260 |
+
metrics["tool_call_counts/mean"] = tool_call_counts.mean()
|
| 261 |
+
|
| 262 |
+
return metrics
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def comb_estimator(n: int, c: int, k: int) -> float:
|
| 266 |
+
"""Calculates 1 - comb(n - c, k) / comb(n, k)."""
|
| 267 |
+
if n - c < k:
|
| 268 |
+
return 1.0
|
| 269 |
+
return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def process_validation_metrics(
|
| 273 |
+
data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
|
| 274 |
+
) -> dict[str, dict[str, dict[str, float]]]:
|
| 275 |
+
"""
|
| 276 |
+
Process validation metrics into a structured format with statistical analysis.
|
| 277 |
+
|
| 278 |
+
This function organizes validation metrics by data source and prompt, then computes
|
| 279 |
+
various statistical measures including means, standard deviations, best/worst values,
|
| 280 |
+
and majority voting results. It also performs bootstrap sampling to estimate statistics
|
| 281 |
+
for different sample sizes.
|
| 282 |
+
|
| 283 |
+
Args:
|
| 284 |
+
data_sources: List of data source identifiers for each sample.
|
| 285 |
+
sample_uids: List of sample uids corresponding to each sample.
|
| 286 |
+
infos_dict: Dictionary mapping variable names to lists of values for each sample.
|
| 287 |
+
seed: Random seed for bootstrap sampling. Defaults to 42.
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
A nested dictionary with the structure:
|
| 291 |
+
{
|
| 292 |
+
data_source: {
|
| 293 |
+
variable_name: {
|
| 294 |
+
metric_name: value
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
Where metric_name includes:
|
| 300 |
+
- "mean@N": Mean value across N samples
|
| 301 |
+
- "std@N": Standard deviation across N samples
|
| 302 |
+
- "best@N/mean": Mean of the best values in bootstrap samples of size N
|
| 303 |
+
- "best@N/std": Standard deviation of the best values in bootstrap samples
|
| 304 |
+
- "worst@N/mean": Mean of the worst values in bootstrap samples
|
| 305 |
+
- "worst@N/std": Standard deviation of the worst values in bootstrap samples
|
| 306 |
+
- "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
|
| 307 |
+
- "maj@N/std": Standard deviation of majority voting results (if "pred" exists)
|
| 308 |
+
|
| 309 |
+
Example:
|
| 310 |
+
>>> data_sources = ["source1", "source1", "source2"]
|
| 311 |
+
>>> sample_uids = ["uid1", "uid1", "uid2"]
|
| 312 |
+
>>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
|
| 313 |
+
>>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
|
| 314 |
+
>>> # result will contain statistics for each data source and variable
|
| 315 |
+
"""
|
| 316 |
+
# Group metrics by data source, prompt and variable
|
| 317 |
+
data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
| 318 |
+
for sample_idx, data_source in enumerate(data_sources):
|
| 319 |
+
uid = sample_uids[sample_idx]
|
| 320 |
+
var2vals = data_src2uid2var2vals[data_source][uid]
|
| 321 |
+
for var_name, var_vals in infos_dict.items():
|
| 322 |
+
var2vals[var_name].append(var_vals[sample_idx])
|
| 323 |
+
|
| 324 |
+
# Calculate metrics for each group
|
| 325 |
+
data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
|
| 326 |
+
for data_source, uid2var2vals in data_src2uid2var2vals.items():
|
| 327 |
+
for uid, var2vals in uid2var2vals.items():
|
| 328 |
+
for var_name, var_vals in var2vals.items():
|
| 329 |
+
if isinstance(var_vals[0], str):
|
| 330 |
+
continue
|
| 331 |
+
|
| 332 |
+
metric = {}
|
| 333 |
+
n_resps = len(var_vals)
|
| 334 |
+
metric[f"mean@{n_resps}"] = np.mean(var_vals)
|
| 335 |
+
metric["pass@1/mean"] = comb_estimator(n_resps, np.sum(var_vals), 1)
|
| 336 |
+
|
| 337 |
+
if n_resps > 1:
|
| 338 |
+
metric[f"std@{n_resps}"] = np.std(var_vals)
|
| 339 |
+
|
| 340 |
+
ns = []
|
| 341 |
+
n = 2
|
| 342 |
+
while n < n_resps:
|
| 343 |
+
ns.append(n)
|
| 344 |
+
n *= 2
|
| 345 |
+
ns.append(n_resps)
|
| 346 |
+
|
| 347 |
+
for n in ns:
|
| 348 |
+
# [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
|
| 349 |
+
# data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed
|
| 350 |
+
# )
|
| 351 |
+
# metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
|
| 352 |
+
# metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
|
| 353 |
+
metric[f"pass@{n}/mean"] = comb_estimator(n_resps, np.sum(var_vals), n)
|
| 354 |
+
if var2vals.get("pred", None) is not None:
|
| 355 |
+
vote_data = [
|
| 356 |
+
{"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True)
|
| 357 |
+
]
|
| 358 |
+
[(maj_n_mean, maj_n_std)] = bootstrap_metric(
|
| 359 |
+
data=vote_data,
|
| 360 |
+
subset_size=n,
|
| 361 |
+
reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
|
| 362 |
+
seed=seed,
|
| 363 |
+
)
|
| 364 |
+
metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
|
| 365 |
+
|
| 366 |
+
data_src2uid2var2metric[data_source][uid][var_name] = metric
|
| 367 |
+
|
| 368 |
+
# Aggregate metrics across uids
|
| 369 |
+
data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
| 370 |
+
for data_source, uid2var2metric in data_src2uid2var2metric.items():
|
| 371 |
+
for uid, var2metric in uid2var2metric.items():
|
| 372 |
+
for var_name, metric in var2metric.items():
|
| 373 |
+
for metric_name, metric_val in metric.items():
|
| 374 |
+
data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)
|
| 375 |
+
|
| 376 |
+
data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
|
| 377 |
+
for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
|
| 378 |
+
for var_name, metric2uid_vals in var2metric2uid_vals.items():
|
| 379 |
+
for metric_name, uid_vals in metric2uid_vals.items():
|
| 380 |
+
data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)
|
| 381 |
+
|
| 382 |
+
return data_src2var2metric2val
|
ICL/DAPO/verl-recipe/rep_exp/model_merge.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CHECKPOINT_PATH=${1} # /path/to/global_step_X/actor, where X is the global step of the checkpoint with the best pass@1 on dev
|
| 2 |
+
|
| 3 |
+
python3 -m verl.model_merger merge \
|
| 4 |
+
--backend fsdp \
|
| 5 |
+
--local_dir $CHECKPOINT_PATH \
|
| 6 |
+
--target_dir $CHECKPOINT_PATH/hf
|
ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Code to plot the pass@k results for the RepExp RL training results.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import numpy as np
|
| 24 |
+
import scipy.stats as stats
|
| 25 |
+
import seaborn as sns
|
| 26 |
+
from matplotlib.lines import Line2D
|
| 27 |
+
|
| 28 |
+
# Content configuration
|
| 29 |
+
EVAL_FOLDER = "./eval"
|
| 30 |
+
TASKS = ["math"] # ["math", "gsm8k", "dapo-with-aime24"]
|
| 31 |
+
SEEDS = [41, 42, 43]
|
| 32 |
+
ALGORITHMS = ["elliptical"] # ["grpo", "elliptical", "untrained", "unlikely"]
|
| 33 |
+
LOG_AXES = True
|
| 34 |
+
|
| 35 |
+
# Plot configuration
|
| 36 |
+
FACE_COLOR = "#F7F7FF"
|
| 37 |
+
MARKER = "o"
|
| 38 |
+
LINEWIDTH = 1.275
|
| 39 |
+
MARKERSIZE = 6
|
| 40 |
+
MARKEREDGEWIDTH = 0.9
|
| 41 |
+
LABEL_FONT_SIZE = 10
|
| 42 |
+
TITLE_FONT_SIZE = 11
|
| 43 |
+
TICK_LABEL_FONT_SIZE = 8
|
| 44 |
+
LEGEND_FONT_SIZE = 8
|
| 45 |
+
|
| 46 |
+
TASK_TO_NICE_NAME = {
|
| 47 |
+
"math": "MATH",
|
| 48 |
+
"gsm8k": "GSM8K",
|
| 49 |
+
"dapo-with-aime24": "AIME 2024",
|
| 50 |
+
"countdown-4": "Countdown",
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
ALGO_TO_COLOR = {
|
| 54 |
+
"grpo": sns.color_palette("deep")[-1],
|
| 55 |
+
"untrained": sns.color_palette("deep")[7],
|
| 56 |
+
"elliptical": sns.color_palette("colorblind")[2],
|
| 57 |
+
"unlikely": sns.color_palette("deep")[1],
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
ALGO_TO_NICE_NAME = {
|
| 61 |
+
"grpo": "GRPO",
|
| 62 |
+
"untrained": "Base Model",
|
| 63 |
+
"elliptical": r"RepExp (ours)",
|
| 64 |
+
"unlikely": "Unlikeliness",
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def process_data(data: list[dict[str, float]], algorithm: str) -> tuple[dict[int, float], dict[int, float]]:
|
| 69 |
+
"""
|
| 70 |
+
Process the pass@k data generated by a given algorithm.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
data (List[Dict]): The data to process.
|
| 74 |
+
algorithm (str): Algorithm that generated the data.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Tuple[Dict[int, float], Dict[int, float]]:
|
| 78 |
+
pass_at_k - The mean pass@k values.
|
| 79 |
+
pass_at_k_sem - The standard error of the pass@k values.
|
| 80 |
+
"""
|
| 81 |
+
pass_at_k = defaultdict(list)
|
| 82 |
+
for d in data:
|
| 83 |
+
for key, v in d.items():
|
| 84 |
+
for k in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
|
| 85 |
+
if key.endswith(f"reward/pass@{k}/mean"):
|
| 86 |
+
pass_at_k[k].append(v)
|
| 87 |
+
|
| 88 |
+
# NOTE: we only use a single seed for untrained since there is only one checkpoint for it
|
| 89 |
+
if algorithm != "untrained":
|
| 90 |
+
for k in pass_at_k.keys():
|
| 91 |
+
assert len(pass_at_k[k]) == len(SEEDS)
|
| 92 |
+
|
| 93 |
+
pass_at_k_sem = {k: stats.sem(v) for k, v in pass_at_k.items()} if algorithm != "untrained" else None
|
| 94 |
+
pass_at_k = {k: np.mean(v) for k, v in pass_at_k.items()}
|
| 95 |
+
|
| 96 |
+
return pass_at_k, pass_at_k_sem
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main():
|
| 100 |
+
# Get all top-level folders in EVAL_FOLDER
|
| 101 |
+
eval_folders = os.listdir(EVAL_FOLDER)
|
| 102 |
+
|
| 103 |
+
# Figure setup
|
| 104 |
+
sns.set_style("whitegrid")
|
| 105 |
+
fig, axs = plt.subplots(1, len(TASKS), figsize=(3 * len(TASKS), 3))
|
| 106 |
+
|
| 107 |
+
for i, task in enumerate(TASKS):
|
| 108 |
+
ax = axs[i] if len(TASKS) > 1 else axs
|
| 109 |
+
algo_to_xs = {}
|
| 110 |
+
algo_to_ys = {}
|
| 111 |
+
|
| 112 |
+
for algorithm in ALGORITHMS:
|
| 113 |
+
# Get all eval folders for the current task and algorithm
|
| 114 |
+
folders = [f for f in eval_folders if f.startswith(f"{task}_{algorithm}")]
|
| 115 |
+
if len(folders) == 0:
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
data = []
|
| 119 |
+
for folder in folders:
|
| 120 |
+
if algorithm == "untrained":
|
| 121 |
+
with open(os.path.join(EVAL_FOLDER, folder, "eval.json")) as f:
|
| 122 |
+
data.append(json.load(f))
|
| 123 |
+
else:
|
| 124 |
+
# walk all files recursively in folder
|
| 125 |
+
for root, dirs, files in os.walk(os.path.join(EVAL_FOLDER, folder)):
|
| 126 |
+
for file in files:
|
| 127 |
+
if file.endswith("eval.json"):
|
| 128 |
+
with open(os.path.join(root, file)) as f:
|
| 129 |
+
data.append(json.load(f))
|
| 130 |
+
break
|
| 131 |
+
|
| 132 |
+
pass_at_k, pass_at_k_sem = process_data(data, algorithm)
|
| 133 |
+
|
| 134 |
+
xs = np.array(list(pass_at_k.keys()))
|
| 135 |
+
ys = np.array([pass_at_k[k] for k in xs])
|
| 136 |
+
algo_to_xs[algorithm] = xs
|
| 137 |
+
algo_to_ys[algorithm] = ys
|
| 138 |
+
|
| 139 |
+
# Plot the current task - algorithm data
|
| 140 |
+
ax.plot(
|
| 141 |
+
xs,
|
| 142 |
+
ys,
|
| 143 |
+
color=ALGO_TO_COLOR[algorithm],
|
| 144 |
+
label=algorithm,
|
| 145 |
+
markeredgecolor=FACE_COLOR,
|
| 146 |
+
marker=MARKER,
|
| 147 |
+
linewidth=LINEWIDTH,
|
| 148 |
+
markersize=MARKERSIZE,
|
| 149 |
+
markeredgewidth=MARKEREDGEWIDTH,
|
| 150 |
+
alpha=1.0 if algorithm != "untrained" else 0.8,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# Plot the standard error in shaded bands
|
| 154 |
+
if algorithm != "untrained":
|
| 155 |
+
sems = np.array([pass_at_k_sem[k] for k in xs])
|
| 156 |
+
ax.fill_between(xs, ys - sems, ys + sems, alpha=0.2, color=ALGO_TO_COLOR[algorithm])
|
| 157 |
+
|
| 158 |
+
# Set y-axis limits
|
| 159 |
+
if task == "math":
|
| 160 |
+
y_min = 0.7
|
| 161 |
+
ax.set_ylim(top=0.95, bottom=y_min)
|
| 162 |
+
elif task == "gsm8k":
|
| 163 |
+
y_min = 0.925
|
| 164 |
+
ax.set_ylim(top=0.995, bottom=y_min)
|
| 165 |
+
elif task == "dapo-with-aime24":
|
| 166 |
+
y_min = 0.1
|
| 167 |
+
ax.set_ylim(bottom=y_min, top=0.63)
|
| 168 |
+
|
| 169 |
+
# Set x-axis limits
|
| 170 |
+
if LOG_AXES:
|
| 171 |
+
ax.set_xlim(left=2 ** (-0.2), right=2 ** (8.2))
|
| 172 |
+
else:
|
| 173 |
+
ax.set_xlim(left=-10, right=266)
|
| 174 |
+
|
| 175 |
+
# Set x-axis scale and ticks
|
| 176 |
+
if LOG_AXES:
|
| 177 |
+
ax.set_xscale("log", base=2)
|
| 178 |
+
x_ticks = [2**i for i in range(int(np.log2(max(xs))) + 1)]
|
| 179 |
+
x_tick_labels = [f"$2^{{{i}}}$" for i in range(int(np.log2(max(xs))) + 1)]
|
| 180 |
+
else:
|
| 181 |
+
# set every 64
|
| 182 |
+
x_ticks = [1, 32, 64, 96, 128, 160, 192, 224, 256]
|
| 183 |
+
x_tick_labels = ["1", "32", "64", "96", "128", "160", "192", "224", "256"]
|
| 184 |
+
ax.set_xticks(x_ticks, x_tick_labels)
|
| 185 |
+
|
| 186 |
+
# Set axes labels
|
| 187 |
+
ax.set_xlabel("k", fontsize=LABEL_FONT_SIZE)
|
| 188 |
+
if i == 0:
|
| 189 |
+
ax.set_ylabel("Pass@k", fontsize=LABEL_FONT_SIZE)
|
| 190 |
+
|
| 191 |
+
# Set title
|
| 192 |
+
ax.set_title(f"{TASK_TO_NICE_NAME[task]}", fontsize=TITLE_FONT_SIZE)
|
| 193 |
+
|
| 194 |
+
# Set font size for tick labels
|
| 195 |
+
for _label in ax.get_xticklabels():
|
| 196 |
+
_label.set_fontsize(TICK_LABEL_FONT_SIZE)
|
| 197 |
+
for _label in ax.get_yticklabels():
|
| 198 |
+
_label.set_fontsize(TICK_LABEL_FONT_SIZE)
|
| 199 |
+
|
| 200 |
+
# Create legend handles
|
| 201 |
+
legend_handles = [
|
| 202 |
+
Line2D(
|
| 203 |
+
[0],
|
| 204 |
+
[0],
|
| 205 |
+
color=ALGO_TO_COLOR[algo],
|
| 206 |
+
marker=MARKER,
|
| 207 |
+
linestyle="-",
|
| 208 |
+
linewidth=LINEWIDTH,
|
| 209 |
+
markersize=MARKERSIZE,
|
| 210 |
+
markeredgewidth=MARKEREDGEWIDTH,
|
| 211 |
+
markeredgecolor=FACE_COLOR,
|
| 212 |
+
label=ALGO_TO_NICE_NAME[algo],
|
| 213 |
+
)
|
| 214 |
+
for algo in ALGORITHMS
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
# Create legend
|
| 218 |
+
legend = fig.legend(
|
| 219 |
+
handles=legend_handles,
|
| 220 |
+
loc="lower center",
|
| 221 |
+
ncol=len(ALGORITHMS),
|
| 222 |
+
bbox_to_anchor=(0.5, -0.07),
|
| 223 |
+
fontsize=LEGEND_FONT_SIZE,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
plt.tight_layout()
|
| 227 |
+
|
| 228 |
+
os.makedirs("figures", exist_ok=True)
|
| 229 |
+
# Save figure
|
| 230 |
+
plt.savefig(
|
| 231 |
+
os.path.join("figures", f"rl_pass_at_k_{TASKS}_{'' if LOG_AXES else '_linear_axes'}.pdf"),
|
| 232 |
+
bbox_extra_artists=(legend,),
|
| 233 |
+
bbox_inches="tight",
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# Close figure
|
| 237 |
+
plt.close()
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
|
| 241 |
+
main()
|
ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py
ADDED
|
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2023-2024 SGLang Team
|
| 3 |
+
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
PPO Trainer with Ray-based single controller.
|
| 18 |
+
This trainer supports model-agonistic model initialization with huggingface
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import os
|
| 23 |
+
import uuid
|
| 24 |
+
from collections import defaultdict
|
| 25 |
+
from copy import deepcopy
|
| 26 |
+
from pprint import pprint
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import ray
|
| 30 |
+
import torch
|
| 31 |
+
from omegaconf import OmegaConf
|
| 32 |
+
from tqdm import tqdm
|
| 33 |
+
|
| 34 |
+
from verl import DataProto
|
| 35 |
+
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
|
| 36 |
+
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
|
| 37 |
+
from verl.single_controller.ray import RayClassWithInitArgs
|
| 38 |
+
from verl.single_controller.ray.base import create_colocated_worker_cls
|
| 39 |
+
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
|
| 40 |
+
from verl.trainer.ppo.metric_utils import (
|
| 41 |
+
compute_throughout_metrics,
|
| 42 |
+
compute_timing_metrics,
|
| 43 |
+
)
|
| 44 |
+
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
|
| 45 |
+
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
|
| 46 |
+
from verl.trainer.ppo.utils import Role
|
| 47 |
+
from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
|
| 48 |
+
from verl.utils.config import omega_conf_to_dataclass
|
| 49 |
+
from verl.utils.debug import marked_timer
|
| 50 |
+
from verl.utils.metric import reduce_metrics
|
| 51 |
+
from verl.utils.rollout_skip import RolloutSkip
|
| 52 |
+
|
| 53 |
+
from .metric_utils import compute_data_metrics, process_validation_metrics
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class RayRepExpTrainer(RayPPOTrainer):
|
| 57 |
+
"""Distributed RepExp trainer using Ray for scalable reinforcement learning.
|
| 58 |
+
|
| 59 |
+
See RayPPOTrainer parent class for more details.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def _save_checkpoint(self):
|
| 63 |
+
super()._save_checkpoint()
|
| 64 |
+
|
| 65 |
+
# Write best metric to global steps
|
| 66 |
+
local_best_metric_to_global_step = os.path.join(
|
| 67 |
+
self.config.trainer.default_local_dir, "best_metric_to_global_step.json"
|
| 68 |
+
)
|
| 69 |
+
with open(local_best_metric_to_global_step, "w") as f:
|
| 70 |
+
json.dump(self.best_dev_pass_at_k_to_global_step, f)
|
| 71 |
+
|
| 72 |
+
def _update_best_pass_at(self, val_metrics: dict[str, float], pass_at_k: int) -> bool:
|
| 73 |
+
"""
|
| 74 |
+
Save checkpoint if the validation metrics are the best.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
val_metrics: The validation metrics.
|
| 78 |
+
pass_at_k: The pass@k to use for determining whether to save the checkpoint.
|
| 79 |
+
"""
|
| 80 |
+
for k in val_metrics.keys():
|
| 81 |
+
if k.endswith(f"reward/pass@{pass_at_k}/mean"):
|
| 82 |
+
if val_metrics[k] > self.best_dev_pass_at_k[pass_at_k]:
|
| 83 |
+
self.best_dev_pass_at_k[pass_at_k] = val_metrics[k]
|
| 84 |
+
self.best_dev_pass_at_k_to_global_step[pass_at_k] = self.global_steps
|
| 85 |
+
return True
|
| 86 |
+
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
def _validate(self):
|
| 90 |
+
data_source_lst = []
|
| 91 |
+
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
|
| 92 |
+
|
| 93 |
+
# Lists to collect samples for the table
|
| 94 |
+
sample_inputs = []
|
| 95 |
+
sample_outputs = []
|
| 96 |
+
sample_gts = []
|
| 97 |
+
sample_scores = []
|
| 98 |
+
sample_turns = []
|
| 99 |
+
sample_uids = []
|
| 100 |
+
|
| 101 |
+
for test_data in tqdm(self.val_dataloader, desc="Validating ..."):
|
| 102 |
+
test_batch = DataProto.from_single_dict(test_data)
|
| 103 |
+
|
| 104 |
+
if "uid" not in test_batch.non_tensor_batch:
|
| 105 |
+
test_batch.non_tensor_batch["uid"] = np.array(
|
| 106 |
+
[str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# repeat test batch
|
| 110 |
+
test_batch = test_batch.repeat(
|
| 111 |
+
repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# we only do validation on rule-based rm
|
| 115 |
+
if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
|
| 116 |
+
return {}
|
| 117 |
+
|
| 118 |
+
# Store original inputs
|
| 119 |
+
input_ids = test_batch.batch["input_ids"]
|
| 120 |
+
# TODO: Can we keep special tokens except for padding tokens?
|
| 121 |
+
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
|
| 122 |
+
sample_inputs.extend(input_texts)
|
| 123 |
+
sample_uids.extend(test_batch.non_tensor_batch["uid"])
|
| 124 |
+
|
| 125 |
+
ground_truths = [
|
| 126 |
+
item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
|
| 127 |
+
]
|
| 128 |
+
sample_gts.extend(ground_truths)
|
| 129 |
+
|
| 130 |
+
test_gen_batch = self._get_gen_batch(test_batch)
|
| 131 |
+
test_gen_batch.meta_info = {
|
| 132 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
| 133 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 134 |
+
"recompute_log_prob": False,
|
| 135 |
+
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
|
| 136 |
+
"validate": True,
|
| 137 |
+
"global_steps": self.global_steps,
|
| 138 |
+
}
|
| 139 |
+
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
|
| 140 |
+
|
| 141 |
+
# pad to be divisible by dp_size
|
| 142 |
+
size_divisor = (
|
| 143 |
+
self.actor_rollout_wg.world_size
|
| 144 |
+
if not self.async_rollout_mode
|
| 145 |
+
else self.config.actor_rollout_ref.rollout.agent.num_workers
|
| 146 |
+
)
|
| 147 |
+
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
|
| 148 |
+
if not self.async_rollout_mode:
|
| 149 |
+
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
|
| 150 |
+
else:
|
| 151 |
+
test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
|
| 152 |
+
|
| 153 |
+
# unpad
|
| 154 |
+
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
|
| 155 |
+
|
| 156 |
+
print("validation generation end")
|
| 157 |
+
|
| 158 |
+
# Store generated outputs
|
| 159 |
+
output_ids = test_output_gen_batch.batch["responses"]
|
| 160 |
+
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
|
| 161 |
+
sample_outputs.extend(output_texts)
|
| 162 |
+
|
| 163 |
+
test_batch = test_batch.union(test_output_gen_batch)
|
| 164 |
+
test_batch.meta_info["validate"] = True
|
| 165 |
+
|
| 166 |
+
# evaluate using reward_function
|
| 167 |
+
if self.val_reward_fn is None:
|
| 168 |
+
raise ValueError("val_reward_fn must be provided for validation.")
|
| 169 |
+
result = self.val_reward_fn(test_batch, return_dict=True)
|
| 170 |
+
reward_tensor = result["reward_tensor"]
|
| 171 |
+
scores = reward_tensor.sum(-1).cpu().tolist()
|
| 172 |
+
sample_scores.extend(scores)
|
| 173 |
+
|
| 174 |
+
reward_extra_infos_dict["reward"].extend(scores)
|
| 175 |
+
if "reward_extra_info" in result:
|
| 176 |
+
for key, lst in result["reward_extra_info"].items():
|
| 177 |
+
reward_extra_infos_dict[key].extend(lst)
|
| 178 |
+
|
| 179 |
+
# collect num_turns of each prompt
|
| 180 |
+
if "__num_turns__" in test_batch.non_tensor_batch:
|
| 181 |
+
sample_turns.append(test_batch.non_tensor_batch["__num_turns__"])
|
| 182 |
+
|
| 183 |
+
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
|
| 184 |
+
|
| 185 |
+
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
|
| 186 |
+
|
| 187 |
+
# dump generations
|
| 188 |
+
val_data_dir = self.config.trainer.get("validation_data_dir", None)
|
| 189 |
+
if val_data_dir:
|
| 190 |
+
self._dump_generations(
|
| 191 |
+
inputs=sample_inputs,
|
| 192 |
+
outputs=sample_outputs,
|
| 193 |
+
gts=sample_gts,
|
| 194 |
+
scores=sample_scores,
|
| 195 |
+
reward_extra_infos_dict=reward_extra_infos_dict,
|
| 196 |
+
dump_path=val_data_dir,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
for key_info, lst in reward_extra_infos_dict.items():
|
| 200 |
+
assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
|
| 201 |
+
|
| 202 |
+
data_sources = np.concatenate(data_source_lst, axis=0)
|
| 203 |
+
|
| 204 |
+
data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
|
| 205 |
+
metric_dict = {}
|
| 206 |
+
for data_source, var2metric2val in data_src2var2metric2val.items():
|
| 207 |
+
core_var = "acc" if "acc" in var2metric2val else "reward"
|
| 208 |
+
for var_name, metric2val in var2metric2val.items():
|
| 209 |
+
n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
|
| 210 |
+
for metric_name, metric_val in metric2val.items():
|
| 211 |
+
if (
|
| 212 |
+
(var_name == core_var)
|
| 213 |
+
and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
|
| 214 |
+
and (f"@{n_max}" in metric_name)
|
| 215 |
+
):
|
| 216 |
+
metric_sec = "val-core"
|
| 217 |
+
else:
|
| 218 |
+
metric_sec = "val-aux"
|
| 219 |
+
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
|
| 220 |
+
metric_dict[pfx] = metric_val
|
| 221 |
+
|
| 222 |
+
if len(sample_turns) > 0:
|
| 223 |
+
sample_turns = np.concatenate(sample_turns)
|
| 224 |
+
metric_dict["val-aux/num_turns/min"] = sample_turns.min()
|
| 225 |
+
metric_dict["val-aux/num_turns/max"] = sample_turns.max()
|
| 226 |
+
metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
|
| 227 |
+
|
| 228 |
+
return metric_dict
|
| 229 |
+
|
| 230 |
+
def init_workers(self):
|
| 231 |
+
"""Initialize distributed training workers using Ray backend.
|
| 232 |
+
|
| 233 |
+
Creates:
|
| 234 |
+
1. Ray resource pools from configuration
|
| 235 |
+
2. Worker groups for each role (actor, critic, etc.)
|
| 236 |
+
"""
|
| 237 |
+
self.resource_pool_manager.create_resource_pool()
|
| 238 |
+
|
| 239 |
+
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
|
| 240 |
+
val_only = self.config.trainer.get("val_only", False)
|
| 241 |
+
|
| 242 |
+
# create actor and rollout
|
| 243 |
+
actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout
|
| 244 |
+
if self.hybrid_engine:
|
| 245 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(actor_role)
|
| 246 |
+
actor_rollout_cls = RayClassWithInitArgs(
|
| 247 |
+
cls=self.role_worker_mapping[actor_role],
|
| 248 |
+
config=self.config.actor_rollout_ref,
|
| 249 |
+
role=str(actor_role),
|
| 250 |
+
)
|
| 251 |
+
self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls
|
| 252 |
+
else:
|
| 253 |
+
raise NotImplementedError
|
| 254 |
+
|
| 255 |
+
# create critic
|
| 256 |
+
if self.use_critic and not val_only:
|
| 257 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
|
| 258 |
+
critic_cfg = omega_conf_to_dataclass(self.config.critic)
|
| 259 |
+
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
|
| 260 |
+
self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
|
| 261 |
+
|
| 262 |
+
# create reference policy if needed
|
| 263 |
+
if self.use_reference_policy and not val_only:
|
| 264 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
|
| 265 |
+
ref_policy_cls = RayClassWithInitArgs(
|
| 266 |
+
self.role_worker_mapping[Role.RefPolicy],
|
| 267 |
+
config=self.config.actor_rollout_ref,
|
| 268 |
+
role=str(Role.RefPolicy),
|
| 269 |
+
)
|
| 270 |
+
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
|
| 271 |
+
|
| 272 |
+
# create a reward model if reward_fn is None
|
| 273 |
+
if self.use_rm and not val_only:
|
| 274 |
+
# we create a RM here
|
| 275 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
|
| 276 |
+
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
|
| 277 |
+
self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
|
| 278 |
+
|
| 279 |
+
# initialize WorkerGroup
|
| 280 |
+
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
|
| 281 |
+
# you should not use `create_colocated_worker_cls`.
|
| 282 |
+
# Instead, directly pass different resource pool to different worker groups.
|
| 283 |
+
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
|
| 284 |
+
all_wg = {}
|
| 285 |
+
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
|
| 286 |
+
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
|
| 287 |
+
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
|
| 288 |
+
if OmegaConf.select(self.config.global_profiler, "steps") is not None:
|
| 289 |
+
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
|
| 290 |
+
# Only require nsight worker options when tool is nsys
|
| 291 |
+
if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
|
| 292 |
+
assert (
|
| 293 |
+
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
|
| 294 |
+
is not None
|
| 295 |
+
), "worker_nsight_options must be set when using nsys with profile_steps"
|
| 296 |
+
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
|
| 297 |
+
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
|
| 298 |
+
)
|
| 299 |
+
wg_kwargs["device_name"] = self.device_name
|
| 300 |
+
|
| 301 |
+
for resource_pool, class_dict in self.resource_pool_to_cls.items():
|
| 302 |
+
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
|
| 303 |
+
wg_dict = self.ray_worker_group_cls(
|
| 304 |
+
resource_pool=resource_pool,
|
| 305 |
+
ray_cls_with_init=worker_dict_cls,
|
| 306 |
+
**wg_kwargs,
|
| 307 |
+
)
|
| 308 |
+
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
|
| 309 |
+
all_wg.update(spawn_wg)
|
| 310 |
+
|
| 311 |
+
if self.use_critic:
|
| 312 |
+
self.critic_wg = all_wg[str(Role.Critic)]
|
| 313 |
+
self.critic_wg.init_model()
|
| 314 |
+
|
| 315 |
+
if self.use_reference_policy and not self.ref_in_actor:
|
| 316 |
+
if str(Role.RefPolicy) in all_wg:
|
| 317 |
+
self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
|
| 318 |
+
self.ref_policy_wg.init_model()
|
| 319 |
+
else:
|
| 320 |
+
# Model engine: ActorRolloutRefWorker
|
| 321 |
+
assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}"
|
| 322 |
+
self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)]
|
| 323 |
+
|
| 324 |
+
self.rm_wg = None
|
| 325 |
+
# initalization of rm_wg will be deprecated in the future
|
| 326 |
+
if self.use_rm:
|
| 327 |
+
self.rm_wg = all_wg[str(Role.RewardModel)]
|
| 328 |
+
self.rm_wg.init_model()
|
| 329 |
+
|
| 330 |
+
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
|
| 331 |
+
self.actor_rollout_wg = all_wg[str(actor_role)]
|
| 332 |
+
self.actor_rollout_wg.init_model()
|
| 333 |
+
|
| 334 |
+
# create async rollout manager and request scheduler
|
| 335 |
+
self.async_rollout_mode = False
|
| 336 |
+
if self.config.actor_rollout_ref.rollout.mode == "async":
|
| 337 |
+
from verl.experimental.agent_loop import AgentLoopManager
|
| 338 |
+
|
| 339 |
+
self.async_rollout_mode = True
|
| 340 |
+
self.async_rollout_manager = AgentLoopManager(
|
| 341 |
+
config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
def fit(self):
|
| 345 |
+
"""
|
| 346 |
+
The training loop of PPO.
|
| 347 |
+
The driver process only need to call the compute functions of the worker group through RPC
|
| 348 |
+
to construct the PPO dataflow.
|
| 349 |
+
The light-weight advantage computation is done on the driver process.
|
| 350 |
+
"""
|
| 351 |
+
from omegaconf import OmegaConf
|
| 352 |
+
|
| 353 |
+
from .utils.tracking import Tracking
|
| 354 |
+
|
| 355 |
+
logger = Tracking(
|
| 356 |
+
project_name=self.config.trainer.project_name,
|
| 357 |
+
experiment_name=self.config.trainer.experiment_name,
|
| 358 |
+
default_backend=self.config.trainer.logger,
|
| 359 |
+
config=OmegaConf.to_container(self.config, resolve=True),
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# global vars to track during training
|
| 363 |
+
self.global_steps = 0
|
| 364 |
+
|
| 365 |
+
self.best_dev_pass_at_k = {
|
| 366 |
+
1: 0,
|
| 367 |
+
}
|
| 368 |
+
self.best_dev_pass_at_k_to_global_step = {
|
| 369 |
+
1: 0,
|
| 370 |
+
}
|
| 371 |
+
|
| 372 |
+
# load checkpoint before doing anything
|
| 373 |
+
self._load_checkpoint()
|
| 374 |
+
|
| 375 |
+
current_epoch = self.global_steps // len(self.train_dataloader)
|
| 376 |
+
|
| 377 |
+
# perform validation before training
|
| 378 |
+
# currently, we only support validation using the reward_function.
|
| 379 |
+
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
|
| 380 |
+
val_metrics = self._validate()
|
| 381 |
+
assert val_metrics, f"{val_metrics=}"
|
| 382 |
+
|
| 383 |
+
# Initialize the best validation metrics for pass@k before training
|
| 384 |
+
self._update_best_pass_at(val_metrics, 1)
|
| 385 |
+
val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1]
|
| 386 |
+
|
| 387 |
+
pprint(f"Initial validation metrics: {val_metrics}")
|
| 388 |
+
logger.log(data=val_metrics, step=self.global_steps)
|
| 389 |
+
|
| 390 |
+
if self.config.trainer.get("val_only", False):
|
| 391 |
+
return
|
| 392 |
+
|
| 393 |
+
if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
|
| 394 |
+
rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
|
| 395 |
+
rollout_skip.wrap_generate_sequences()
|
| 396 |
+
|
| 397 |
+
# add tqdm
|
| 398 |
+
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
|
| 399 |
+
|
| 400 |
+
# we start from step 1
|
| 401 |
+
self.global_steps += 1
|
| 402 |
+
last_val_metrics = None
|
| 403 |
+
self.max_steps_duration = 0
|
| 404 |
+
|
| 405 |
+
prev_step_profile = False
|
| 406 |
+
curr_step_profile = (
|
| 407 |
+
self.global_steps in self.config.global_profiler.steps
|
| 408 |
+
if self.config.global_profiler.steps is not None
|
| 409 |
+
else False
|
| 410 |
+
)
|
| 411 |
+
next_step_profile = False
|
| 412 |
+
|
| 413 |
+
for epoch in range(current_epoch, self.config.trainer.total_epochs):
|
| 414 |
+
for batch_dict in self.train_dataloader:
|
| 415 |
+
metrics = {}
|
| 416 |
+
timing_raw = {}
|
| 417 |
+
|
| 418 |
+
with marked_timer("start_profile", timing_raw):
|
| 419 |
+
self._start_profiling(
|
| 420 |
+
not prev_step_profile and curr_step_profile
|
| 421 |
+
if self.config.global_profiler.profile_continuous_steps
|
| 422 |
+
else curr_step_profile
|
| 423 |
+
)
|
| 424 |
+
batch: DataProto = DataProto.from_single_dict(batch_dict)
|
| 425 |
+
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
|
| 426 |
+
|
| 427 |
+
# add uid to batch
|
| 428 |
+
batch.non_tensor_batch["uid"] = np.array(
|
| 429 |
+
[str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
gen_batch = self._get_gen_batch(batch)
|
| 433 |
+
|
| 434 |
+
# pass global_steps to trace
|
| 435 |
+
gen_batch.meta_info["global_steps"] = self.global_steps
|
| 436 |
+
gen_batch_output = gen_batch.repeat(
|
| 437 |
+
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
is_last_step = self.global_steps >= self.total_training_steps
|
| 441 |
+
with marked_timer("step", timing_raw):
|
| 442 |
+
# generate a batch
|
| 443 |
+
with marked_timer("gen", timing_raw, color="red"):
|
| 444 |
+
if not self.async_rollout_mode:
|
| 445 |
+
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
|
| 446 |
+
else:
|
| 447 |
+
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
|
| 448 |
+
|
| 449 |
+
timing_raw.update(gen_batch_output.meta_info["timing"])
|
| 450 |
+
gen_batch_output.meta_info.pop("timing", None)
|
| 451 |
+
|
| 452 |
+
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
|
| 453 |
+
if self.reward_fn is None:
|
| 454 |
+
raise ValueError("A reward_fn is required for REMAX advantage estimation.")
|
| 455 |
+
|
| 456 |
+
with marked_timer("gen_max", timing_raw, color="purple"):
|
| 457 |
+
gen_baseline_batch = deepcopy(gen_batch)
|
| 458 |
+
gen_baseline_batch.meta_info["do_sample"] = False
|
| 459 |
+
if not self.async_rollout_mode:
|
| 460 |
+
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
|
| 461 |
+
else:
|
| 462 |
+
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
|
| 463 |
+
batch = batch.union(gen_baseline_output)
|
| 464 |
+
# compute reward model score on batch
|
| 465 |
+
rm_scores = None
|
| 466 |
+
if self.use_rm and "rm_scores" not in batch.batch.keys():
|
| 467 |
+
rm_scores = self.rm_wg.compute_rm_score(batch)
|
| 468 |
+
batch = batch.union(rm_scores)
|
| 469 |
+
reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
|
| 470 |
+
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
|
| 471 |
+
|
| 472 |
+
keys_to_pop = set(gen_baseline_output.batch.keys())
|
| 473 |
+
if rm_scores is not None:
|
| 474 |
+
keys_to_pop.update(rm_scores.batch.keys())
|
| 475 |
+
batch.pop(batch_keys=list(keys_to_pop))
|
| 476 |
+
|
| 477 |
+
batch.batch["reward_baselines"] = reward_baseline_tensor
|
| 478 |
+
|
| 479 |
+
del rm_scores, gen_baseline_batch, gen_baseline_output
|
| 480 |
+
# repeat to align with repeated responses in rollout
|
| 481 |
+
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
| 482 |
+
batch = batch.union(gen_batch_output)
|
| 483 |
+
|
| 484 |
+
if "response_mask" not in batch.batch.keys():
|
| 485 |
+
batch.batch["response_mask"] = compute_response_mask(batch)
|
| 486 |
+
# Balance the number of valid tokens across DP ranks.
|
| 487 |
+
# NOTE: This usually changes the order of data in the `batch`,
|
| 488 |
+
# which won't affect the advantage calculation (since it's based on uid),
|
| 489 |
+
# but might affect the loss calculation (due to the change of mini-batching).
|
| 490 |
+
if self.config.trainer.balance_batch:
|
| 491 |
+
self._balance_batch(batch, metrics=metrics)
|
| 492 |
+
|
| 493 |
+
# compute global_valid tokens
|
| 494 |
+
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
|
| 495 |
+
|
| 496 |
+
with marked_timer("reward", timing_raw, color="yellow"):
|
| 497 |
+
# compute reward model score
|
| 498 |
+
if self.use_rm and "rm_scores" not in batch.batch.keys():
|
| 499 |
+
if self.config.reward_model.elliptical.enable:
|
| 500 |
+
hidden_states = self.rm_wg.compute_hidden_states(batch)
|
| 501 |
+
batch = batch.union(hidden_states)
|
| 502 |
+
reward_tensor = self.rm_wg.compute_rm_score(batch)
|
| 503 |
+
else:
|
| 504 |
+
reward_tensor = self.rm_wg.compute_rm_score(batch)
|
| 505 |
+
batch = batch.union(reward_tensor)
|
| 506 |
+
|
| 507 |
+
if self.config.reward_model.launch_reward_fn_async:
|
| 508 |
+
future_reward = compute_reward_async.remote(
|
| 509 |
+
data=batch, config=self.config, tokenizer=self.tokenizer
|
| 510 |
+
)
|
| 511 |
+
else:
|
| 512 |
+
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
|
| 513 |
+
|
| 514 |
+
# Operating Mode Selection:
|
| 515 |
+
# - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
|
| 516 |
+
# - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
|
| 517 |
+
# Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
|
| 518 |
+
rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
|
| 519 |
+
bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
|
| 520 |
+
if bypass_recomputing_logprobs: # Use `rollout_log_probs`
|
| 521 |
+
from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction
|
| 522 |
+
|
| 523 |
+
apply_rollout_correction(
|
| 524 |
+
batch=batch,
|
| 525 |
+
rollout_corr_config=rollout_corr_config,
|
| 526 |
+
policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
|
| 527 |
+
)
|
| 528 |
+
else: # Recompute old_log_probs
|
| 529 |
+
with marked_timer("old_log_prob", timing_raw, color="blue"):
|
| 530 |
+
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
|
| 531 |
+
entropys = old_log_prob.batch["entropys"]
|
| 532 |
+
response_masks = batch.batch["response_mask"]
|
| 533 |
+
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
|
| 534 |
+
entropy_agg = agg_loss(
|
| 535 |
+
loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
|
| 536 |
+
)
|
| 537 |
+
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
|
| 538 |
+
metrics.update(old_log_prob_metrics)
|
| 539 |
+
old_log_prob.batch.pop("entropys")
|
| 540 |
+
batch = batch.union(old_log_prob)
|
| 541 |
+
if "rollout_log_probs" in batch.batch.keys():
|
| 542 |
+
# TODO: we may want to add diff of probs too.
|
| 543 |
+
from verl.utils.debug.metrics import calculate_debug_metrics
|
| 544 |
+
|
| 545 |
+
metrics.update(calculate_debug_metrics(batch))
|
| 546 |
+
|
| 547 |
+
assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'
|
| 548 |
+
|
| 549 |
+
if self.use_reference_policy:
|
| 550 |
+
# compute reference log_prob
|
| 551 |
+
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
|
| 552 |
+
if not self.ref_in_actor:
|
| 553 |
+
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
| 554 |
+
else:
|
| 555 |
+
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
|
| 556 |
+
batch = batch.union(ref_log_prob)
|
| 557 |
+
|
| 558 |
+
# compute values
|
| 559 |
+
if self.use_critic:
|
| 560 |
+
with marked_timer("values", timing_raw, color="cyan"):
|
| 561 |
+
values = self.critic_wg.compute_values(batch)
|
| 562 |
+
batch = batch.union(values)
|
| 563 |
+
|
| 564 |
+
with marked_timer("adv", timing_raw, color="brown"):
|
| 565 |
+
# we combine with rule-based rm
|
| 566 |
+
reward_extra_infos_dict: dict[str, list]
|
| 567 |
+
if self.config.reward_model.launch_reward_fn_async:
|
| 568 |
+
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
|
| 569 |
+
batch.batch["token_level_scores"] = reward_tensor
|
| 570 |
+
|
| 571 |
+
if reward_extra_infos_dict:
|
| 572 |
+
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
|
| 573 |
+
|
| 574 |
+
# compute rewards. apply_kl_penalty if available
|
| 575 |
+
if self.config.algorithm.use_kl_in_reward:
|
| 576 |
+
batch, kl_metrics = apply_kl_penalty(
|
| 577 |
+
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
|
| 578 |
+
)
|
| 579 |
+
metrics.update(kl_metrics)
|
| 580 |
+
else:
|
| 581 |
+
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
|
| 582 |
+
|
| 583 |
+
# Compute rollout correction: IS weights, rejection sampling, and metrics
|
| 584 |
+
# Only runs in decoupled mode (computes once per batch using stable π_old)
|
| 585 |
+
# In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
|
| 586 |
+
if (
|
| 587 |
+
rollout_corr_config is not None
|
| 588 |
+
and "rollout_log_probs" in batch.batch
|
| 589 |
+
and not bypass_recomputing_logprobs # Only in decoupled mode
|
| 590 |
+
):
|
| 591 |
+
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
|
| 592 |
+
|
| 593 |
+
# Compute IS weights, apply rejection sampling, compute metrics
|
| 594 |
+
batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
|
| 595 |
+
# IS and off-policy metrics already have rollout_corr/ prefix
|
| 596 |
+
metrics.update(is_metrics)
|
| 597 |
+
|
| 598 |
+
# compute advantages, executed on the driver process
|
| 599 |
+
norm_adv_by_std_in_grpo = self.config.algorithm.get(
|
| 600 |
+
"norm_adv_by_std_in_grpo", True
|
| 601 |
+
) # GRPO adv normalization factor
|
| 602 |
+
|
| 603 |
+
batch = compute_advantage(
|
| 604 |
+
batch,
|
| 605 |
+
adv_estimator=self.config.algorithm.adv_estimator,
|
| 606 |
+
gamma=self.config.algorithm.gamma,
|
| 607 |
+
lam=self.config.algorithm.lam,
|
| 608 |
+
num_repeat=self.config.actor_rollout_ref.rollout.n,
|
| 609 |
+
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
|
| 610 |
+
config=self.config.algorithm,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# update critic
|
| 614 |
+
if self.use_critic:
|
| 615 |
+
with marked_timer("update_critic", timing_raw, color="pink"):
|
| 616 |
+
critic_output = self.critic_wg.update_critic(batch)
|
| 617 |
+
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
|
| 618 |
+
metrics.update(critic_output_metrics)
|
| 619 |
+
|
| 620 |
+
# implement critic warmup
|
| 621 |
+
if self.config.trainer.critic_warmup <= self.global_steps:
|
| 622 |
+
# update actor
|
| 623 |
+
with marked_timer("update_actor", timing_raw, color="red"):
|
| 624 |
+
rollout_config = self.config.actor_rollout_ref.rollout
|
| 625 |
+
batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
|
| 626 |
+
# TODO: Make "temperature" single source of truth from generation.
|
| 627 |
+
batch.meta_info["temperature"] = rollout_config.temperature
|
| 628 |
+
actor_output = self.actor_rollout_wg.update_actor(batch)
|
| 629 |
+
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
|
| 630 |
+
metrics.update(actor_output_metrics)
|
| 631 |
+
|
| 632 |
+
# Log rollout generations if enabled
|
| 633 |
+
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
|
| 634 |
+
if rollout_data_dir:
|
| 635 |
+
self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
|
| 636 |
+
|
| 637 |
+
# validate
|
| 638 |
+
if (
|
| 639 |
+
self.val_reward_fn is not None
|
| 640 |
+
and self.config.trainer.test_freq > 0
|
| 641 |
+
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
|
| 642 |
+
):
|
| 643 |
+
with marked_timer("testing", timing_raw, color="green"):
|
| 644 |
+
val_metrics: dict = self._validate()
|
| 645 |
+
|
| 646 |
+
# Initialize the best validation metrics for pass@k before training
|
| 647 |
+
self._update_best_pass_at(val_metrics, 1)
|
| 648 |
+
val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1]
|
| 649 |
+
|
| 650 |
+
if is_last_step:
|
| 651 |
+
last_val_metrics = val_metrics
|
| 652 |
+
metrics.update(val_metrics)
|
| 653 |
+
|
| 654 |
+
# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
|
| 655 |
+
esi_close_to_expiration = should_save_ckpt_esi(
|
| 656 |
+
max_steps_duration=self.max_steps_duration,
|
| 657 |
+
redundant_time=self.config.trainer.esi_redundant_time,
|
| 658 |
+
)
|
| 659 |
+
# Check if the conditions for saving a checkpoint are met.
|
| 660 |
+
# The conditions include a mandatory condition (1) and
|
| 661 |
+
# one of the following optional conditions (2/3/4):
|
| 662 |
+
# 1. The save frequency is set to a positive value.
|
| 663 |
+
# 2. It's the last training step.
|
| 664 |
+
# 3. The current step number is a multiple of the save frequency.
|
| 665 |
+
# 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
|
| 666 |
+
if self.config.trainer.save_freq > 0 and (
|
| 667 |
+
is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
|
| 668 |
+
):
|
| 669 |
+
if esi_close_to_expiration:
|
| 670 |
+
print("Force saving checkpoint: ESI instance expiration approaching.")
|
| 671 |
+
with marked_timer("save_checkpoint", timing_raw, color="green"):
|
| 672 |
+
self._save_checkpoint()
|
| 673 |
+
|
| 674 |
+
with marked_timer("stop_profile", timing_raw):
|
| 675 |
+
next_step_profile = (
|
| 676 |
+
self.global_steps + 1 in self.config.global_profiler.steps
|
| 677 |
+
if self.config.global_profiler.steps is not None
|
| 678 |
+
else False
|
| 679 |
+
)
|
| 680 |
+
self._stop_profiling(
|
| 681 |
+
curr_step_profile and not next_step_profile
|
| 682 |
+
if self.config.global_profiler.profile_continuous_steps
|
| 683 |
+
else curr_step_profile
|
| 684 |
+
)
|
| 685 |
+
prev_step_profile = curr_step_profile
|
| 686 |
+
curr_step_profile = next_step_profile
|
| 687 |
+
|
| 688 |
+
steps_duration = timing_raw["step"]
|
| 689 |
+
self.max_steps_duration = max(self.max_steps_duration, steps_duration)
|
| 690 |
+
|
| 691 |
+
# training metrics
|
| 692 |
+
metrics.update(
|
| 693 |
+
{
|
| 694 |
+
"training/global_step": self.global_steps,
|
| 695 |
+
"training/epoch": epoch,
|
| 696 |
+
}
|
| 697 |
+
)
|
| 698 |
+
# collect metrics
|
| 699 |
+
metrics.update(
|
| 700 |
+
compute_data_metrics(
|
| 701 |
+
batch=batch,
|
| 702 |
+
use_critic=self.use_critic,
|
| 703 |
+
elliptical=self.config.reward_model.elliptical.enable,
|
| 704 |
+
)
|
| 705 |
+
)
|
| 706 |
+
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
|
| 707 |
+
# TODO: implement actual tflpo and theoretical tflpo
|
| 708 |
+
n_gpus = self.resource_pool_manager.get_n_gpus()
|
| 709 |
+
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
|
| 710 |
+
# Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
|
| 711 |
+
|
| 712 |
+
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
|
| 713 |
+
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
|
| 714 |
+
self.train_dataloader.sampler.update(batch=batch)
|
| 715 |
+
|
| 716 |
+
# TODO: make a canonical logger that supports various backend
|
| 717 |
+
logger.log(data=metrics, step=self.global_steps)
|
| 718 |
+
|
| 719 |
+
progress_bar.update(1)
|
| 720 |
+
self.global_steps += 1
|
| 721 |
+
|
| 722 |
+
if (
|
| 723 |
+
hasattr(self.config.actor_rollout_ref.actor, "profiler")
|
| 724 |
+
and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
|
| 725 |
+
):
|
| 726 |
+
self.actor_rollout_wg.dump_memory_snapshot(
|
| 727 |
+
tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
if is_last_step:
|
| 731 |
+
pprint(f"Final validation metrics: {last_val_metrics}")
|
| 732 |
+
progress_bar.close()
|
| 733 |
+
return
|
| 734 |
+
|
| 735 |
+
# this is experimental and may be changed/removed in the future
|
| 736 |
+
# in favor of a general-purpose data buffer pool
|
| 737 |
+
if hasattr(self.train_dataset, "on_batch_end"):
|
| 738 |
+
# The dataset may be changed after each training batch
|
| 739 |
+
self.train_dataset.on_batch_end(batch=batch)
|
ICL/DAPO/verl-recipe/spin/core_algos.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2023-2024 SGLang Team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AdaptiveKLController:
|
| 22 |
+
"""
|
| 23 |
+
Adaptive KL controller described in the paper:
|
| 24 |
+
https://arxiv.org/pdf/1909.08593.pdf
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, init_kl_coef, target_kl, horizon):
|
| 28 |
+
self.value = init_kl_coef
|
| 29 |
+
self.target = target_kl
|
| 30 |
+
self.horizon = horizon
|
| 31 |
+
|
| 32 |
+
def update(self, current_kl, n_steps):
|
| 33 |
+
target = self.target
|
| 34 |
+
proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
|
| 35 |
+
mult = 1 + proportional_error * n_steps / self.horizon
|
| 36 |
+
self.value *= mult
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class FixedKLController:
|
| 40 |
+
"""Fixed KL controller."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, kl_coef):
|
| 43 |
+
self.value = kl_coef
|
| 44 |
+
|
| 45 |
+
def update(self, current_kl, n_steps):
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_kl_controller(kl_ctrl):
|
| 50 |
+
if kl_ctrl.type == "fixed":
|
| 51 |
+
return FixedKLController(kl_coef=kl_ctrl.kl_coef)
|
| 52 |
+
elif kl_ctrl.type == "adaptive":
|
| 53 |
+
assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
|
| 54 |
+
return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
|
| 55 |
+
else:
|
| 56 |
+
raise NotImplementedError
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def compute_onlinedpo_pref(
|
| 60 |
+
token_level_rewards: torch.Tensor,
|
| 61 |
+
response_mask: torch.Tensor,
|
| 62 |
+
) -> torch.Tensor:
|
| 63 |
+
"""
|
| 64 |
+
Computes preferences between pairs of sequences based on summed rewards
|
| 65 |
+
and returns a mask aligned with the interleaved batch.
|
| 66 |
+
|
| 67 |
+
Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
token_level_rewards: Tensor of shape [batch_size * 2, seq_len]
|
| 71 |
+
response_mask: Tensor of shape [batch_size * 2, seq_len]
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates
|
| 75 |
+
the corresponding entry is the chosen response for its pair.
|
| 76 |
+
Example: [True, False, False, True, ...] means for prompt 0,
|
| 77 |
+
response 1 was chosen; for prompt 1, response 2 was chosen.
|
| 78 |
+
"""
|
| 79 |
+
# print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----")
|
| 80 |
+
if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:
|
| 81 |
+
raise ValueError(
|
| 82 |
+
f"Input tensor batch dimension must be even for pair comparison, got shapes: "
|
| 83 |
+
f"{token_level_rewards.shape}, {response_mask.shape}"
|
| 84 |
+
)
|
| 85 |
+
if token_level_rewards.shape != response_mask.shape:
|
| 86 |
+
raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}")
|
| 87 |
+
|
| 88 |
+
# 1. Calculate Sequence Scores
|
| 89 |
+
scores = (token_level_rewards * response_mask).sum(dim=-1)
|
| 90 |
+
# print(f" Calculated sequence scores shape: {scores.shape}") # [batch_size * 2]
|
| 91 |
+
|
| 92 |
+
# 2. Reshape scores to group pairs: [batch_size, 2]
|
| 93 |
+
try:
|
| 94 |
+
score_pairs = scores.view(-1, 2)
|
| 95 |
+
except RuntimeError as e:
|
| 96 |
+
print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}")
|
| 97 |
+
raise e
|
| 98 |
+
print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2]
|
| 99 |
+
|
| 100 |
+
# 3. Compare scores to find which index (0 or 1) is the winner within each pair
|
| 101 |
+
# winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1
|
| 102 |
+
winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max
|
| 103 |
+
# Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max)
|
| 104 |
+
# Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1]
|
| 105 |
+
# print(f" Winner indices shape: {winner_indices.shape}") # [batch_size]
|
| 106 |
+
# print(f" Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}") # Counts number of 1s
|
| 107 |
+
|
| 108 |
+
# 4. Create the final [batch_size * 2] mask
|
| 109 |
+
num_pairs = score_pairs.shape[0]
|
| 110 |
+
full_batch_size = num_pairs * 2
|
| 111 |
+
# Create indices for the full batch [0, 1, 2, 3, ..., N*2-1]
|
| 112 |
+
# full_indices = torch.arange(full_batch_size, device=scores.device)
|
| 113 |
+
# Create indices corresponding to the winner within each pair's original index
|
| 114 |
+
# E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2]
|
| 115 |
+
# winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4]
|
| 116 |
+
pair_indices = torch.arange(num_pairs, device=scores.device)
|
| 117 |
+
winner_global_indices = (pair_indices * 2) + winner_indices
|
| 118 |
+
|
| 119 |
+
# Create boolean mask - True at the winner's position
|
| 120 |
+
output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device)
|
| 121 |
+
output_preference_mask[winner_global_indices] = True
|
| 122 |
+
|
| 123 |
+
# print(f" Output preference mask shape: {output_preference_mask.shape}") # Should be [batch_size * 2]
|
| 124 |
+
# print(f" Output mask True count (Chosen): {output_preference_mask.sum().item()}") # Should be batch_size
|
| 125 |
+
# print(f" Output mask False count (Rejected): {(~output_preference_mask).sum().item()}") # Should be batch_size
|
| 126 |
+
# print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----")
|
| 127 |
+
|
| 128 |
+
return output_preference_mask
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def compute_online_dpo_loss(
|
| 132 |
+
policy_chosen_logps: torch.Tensor,
|
| 133 |
+
policy_rejected_logps: torch.Tensor,
|
| 134 |
+
reference_chosen_logps: torch.Tensor,
|
| 135 |
+
reference_rejected_logps: torch.Tensor,
|
| 136 |
+
beta: float,
|
| 137 |
+
label_smoothing: float = 0.0,
|
| 138 |
+
loss_type: str = "sigmoid",
|
| 139 |
+
reference_free: bool = False,
|
| 140 |
+
) -> torch.Tensor:
|
| 141 |
+
import torch.nn.functional as F
|
| 142 |
+
|
| 143 |
+
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
| 144 |
+
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
| 145 |
+
|
| 146 |
+
if reference_free:
|
| 147 |
+
ref_logratios = torch.zeros_like(pi_logratios)
|
| 148 |
+
|
| 149 |
+
logits = pi_logratios - ref_logratios
|
| 150 |
+
|
| 151 |
+
if loss_type == "sigmoid":
|
| 152 |
+
losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
|
| 153 |
+
elif loss_type == "ipo":
|
| 154 |
+
losses = (logits - 1 / (2 * beta)) ** 2
|
| 155 |
+
else:
|
| 156 |
+
raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.")
|
| 157 |
+
|
| 158 |
+
return losses.mean()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_batch_logps(
|
| 162 |
+
logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False
|
| 163 |
+
) -> torch.FloatTensor:
|
| 164 |
+
"""
|
| 165 |
+
Compute the log probabilities of the given labels under the given logits.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`).
|
| 169 |
+
Shape: (batch_size, sequence_length, vocab_size)
|
| 170 |
+
labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length)
|
| 171 |
+
average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences.
|
| 175 |
+
"""
|
| 176 |
+
if logits.shape[:-1] != labels.shape:
|
| 177 |
+
raise ValueError("Logits and labels must have the same shape[:-1]")
|
| 178 |
+
|
| 179 |
+
# Ensure labels are contiguous and on the same device as logits
|
| 180 |
+
labels = labels.contiguous().to(logits.device)
|
| 181 |
+
# Shift so that tokens < n predict n
|
| 182 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 183 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 184 |
+
|
| 185 |
+
# Calculate per token log probability
|
| 186 |
+
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
|
| 187 |
+
per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
| 188 |
+
per_token_logps = per_token_logps.view(
|
| 189 |
+
shift_logits.size(0), shift_logits.size(1)
|
| 190 |
+
) # Reshape back to (batch_size, seq_len-1)
|
| 191 |
+
|
| 192 |
+
# Create a mask for the labels that are not -100
|
| 193 |
+
loss_mask = shift_labels != -100
|
| 194 |
+
|
| 195 |
+
# Apply the mask to the per token log probabilities
|
| 196 |
+
masked_logps = per_token_logps * loss_mask
|
| 197 |
+
|
| 198 |
+
# Calculate the sum or average log probability per sequence
|
| 199 |
+
sequence_logps = masked_logps.sum(dim=-1)
|
| 200 |
+
|
| 201 |
+
if average_log_prob:
|
| 202 |
+
# Avoid division by zero for sequences with no valid tokens
|
| 203 |
+
num_valid_tokens = loss_mask.sum(dim=-1)
|
| 204 |
+
return sequence_logps / torch.clamp(num_valid_tokens, min=1)
|
| 205 |
+
else:
|
| 206 |
+
return sequence_logps
|
ICL/DAPO/verl-recipe/spin/main_spin.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2023-2024 SGLang Team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
import hydra
|
| 19 |
+
import ray
|
| 20 |
+
from recipe.spin.spin_trainer import RaySPINTrainer
|
| 21 |
+
from recipe.spin.utils import validate_config
|
| 22 |
+
|
| 23 |
+
from verl.trainer.ppo.reward import get_custom_reward_fn
|
| 24 |
+
from verl.trainer.ppo.utils import need_reference_policy
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@hydra.main(config_path="config", config_name="spin_trainer", version_base=None)
|
| 28 |
+
def main(config):
|
| 29 |
+
run_ppo(config)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def run_ppo(config) -> None:
|
| 33 |
+
# TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
|
| 34 |
+
# isolation, will solve in the future
|
| 35 |
+
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
|
| 36 |
+
if not ray.is_initialized():
|
| 37 |
+
# this is for local ray cluster
|
| 38 |
+
ray.init(
|
| 39 |
+
runtime_env={
|
| 40 |
+
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
|
| 41 |
+
}
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
runner = TaskRunner.remote()
|
| 45 |
+
ray.get(runner.run.remote(config))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
|
| 49 |
+
class TaskRunner:
|
| 50 |
+
def run(self, config):
|
| 51 |
+
# print initial config
|
| 52 |
+
from pprint import pprint
|
| 53 |
+
|
| 54 |
+
from omegaconf import OmegaConf
|
| 55 |
+
|
| 56 |
+
from verl.utils.fs import copy_to_local
|
| 57 |
+
|
| 58 |
+
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
|
| 59 |
+
OmegaConf.resolve(config)
|
| 60 |
+
|
| 61 |
+
# define worker classes
|
| 62 |
+
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
|
| 63 |
+
assert config.critic.strategy in {"fsdp", "fsdp2"}
|
| 64 |
+
# from recipe.spin.fsdp_workers import ActorRolloutRefWorker
|
| 65 |
+
from recipe.spin.fsdp_workers import SPINRolloutRefWorker
|
| 66 |
+
|
| 67 |
+
from verl.single_controller.ray import RayWorkerGroup
|
| 68 |
+
|
| 69 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 70 |
+
|
| 71 |
+
elif config.actor_rollout_ref.actor.strategy == "megatron":
|
| 72 |
+
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
|
| 73 |
+
from verl.single_controller.ray import RayWorkerGroup
|
| 74 |
+
|
| 75 |
+
ray_worker_group_cls = RayWorkerGroup
|
| 76 |
+
|
| 77 |
+
else:
|
| 78 |
+
raise NotImplementedError
|
| 79 |
+
|
| 80 |
+
from recipe.spin.spin_trainer import ResourcePoolManager, Role
|
| 81 |
+
|
| 82 |
+
role_worker_mapping = {
|
| 83 |
+
# Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
|
| 84 |
+
Role.ActorRollout: ray.remote(SPINRolloutRefWorker),
|
| 85 |
+
# Role.Critic: ray.remote(CriticWorker),
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
global_pool_id = "global_pool"
|
| 89 |
+
resource_pool_spec = {
|
| 90 |
+
global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
|
| 91 |
+
}
|
| 92 |
+
mapping = {
|
| 93 |
+
Role.ActorRollout: global_pool_id,
|
| 94 |
+
# Role.Critic: global_pool_id,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
if config.reward_model.enable:
|
| 98 |
+
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
|
| 99 |
+
from recipe.spin.fsdp_workers import RewardModelWorker
|
| 100 |
+
elif config.reward_model.strategy == "megatron":
|
| 101 |
+
from verl.workers.megatron_workers import RewardModelWorker
|
| 102 |
+
else:
|
| 103 |
+
raise NotImplementedError
|
| 104 |
+
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
|
| 105 |
+
mapping[Role.RewardModel] = global_pool_id
|
| 106 |
+
|
| 107 |
+
# use reference model
|
| 108 |
+
# if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
|
| 109 |
+
# role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
|
| 110 |
+
role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)
|
| 111 |
+
mapping[Role.RefPolicy] = global_pool_id
|
| 112 |
+
|
| 113 |
+
# validate config
|
| 114 |
+
validate_config(
|
| 115 |
+
config=config,
|
| 116 |
+
use_reference_policy=need_reference_policy(role_worker_mapping),
|
| 117 |
+
use_critic=False,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# download the checkpoint from hdfs
|
| 121 |
+
local_path = copy_to_local(config.actor_rollout_ref.model.path)
|
| 122 |
+
|
| 123 |
+
# instantiate tokenizer
|
| 124 |
+
from verl.utils import hf_processor, hf_tokenizer
|
| 125 |
+
|
| 126 |
+
trust_remote_code = config.data.get("trust_remote_code", False)
|
| 127 |
+
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
|
| 128 |
+
processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none
|
| 129 |
+
|
| 130 |
+
from verl.workers.reward_manager import get_reward_manager_cls
|
| 131 |
+
|
| 132 |
+
# Note(haibin.lin): please make sure custom reward managers are imported and
|
| 133 |
+
# registered via `verl.workers.reward_manager.register`
|
| 134 |
+
reward_manager_name = config.reward_model.get("reward_manager", "naive")
|
| 135 |
+
reward_manager_cls = get_reward_manager_cls(reward_manager_name)
|
| 136 |
+
|
| 137 |
+
compute_score = get_custom_reward_fn(config)
|
| 138 |
+
reward_kwargs = dict(config.reward_model.get("reward_kwargs", {}))
|
| 139 |
+
reward_fn = reward_manager_cls(
|
| 140 |
+
tokenizer=tokenizer,
|
| 141 |
+
num_examine=0,
|
| 142 |
+
compute_score=compute_score,
|
| 143 |
+
reward_fn_key=config.data.reward_fn_key,
|
| 144 |
+
**reward_kwargs,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Note that we always use function-based RM for validation
|
| 148 |
+
val_reward_fn = reward_manager_cls(
|
| 149 |
+
tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key
|
| 150 |
+
)
|
| 151 |
+
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
|
| 152 |
+
|
| 153 |
+
trainer = RaySPINTrainer(
|
| 154 |
+
config=config,
|
| 155 |
+
tokenizer=tokenizer,
|
| 156 |
+
processor=processor,
|
| 157 |
+
role_worker_mapping=role_worker_mapping,
|
| 158 |
+
resource_pool_manager=resource_pool_manager,
|
| 159 |
+
ray_worker_group_cls=ray_worker_group_cls,
|
| 160 |
+
reward_fn=reward_fn,
|
| 161 |
+
val_reward_fn=val_reward_fn,
|
| 162 |
+
)
|
| 163 |
+
trainer.init_workers()
|
| 164 |
+
trainer.fit_dpo()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
main()
|
ICL/DAPO/verl-recipe/spin/spin_trainer.py
ADDED
|
@@ -0,0 +1,1312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2023-2024 SGLang Team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import traceback
|
| 18 |
+
import uuid
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from contextlib import contextmanager
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from pprint import pprint
|
| 23 |
+
from typing import Any, Optional
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import ray
|
| 27 |
+
import torch
|
| 28 |
+
from codetiming import Timer
|
| 29 |
+
from omegaconf import OmegaConf, open_dict
|
| 30 |
+
from recipe.spin import core_algos
|
| 31 |
+
from torch.utils.data import Dataset, Sampler
|
| 32 |
+
from torchdata.stateful_dataloader import StatefulDataLoader
|
| 33 |
+
from tqdm import tqdm
|
| 34 |
+
|
| 35 |
+
from verl import DataProto
|
| 36 |
+
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
|
| 37 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 38 |
+
from verl.single_controller.ray.base import create_colocated_worker_cls
|
| 39 |
+
from verl.trainer.ppo.metric_utils import compute_throughout_metrics, compute_timing_metrics, process_validation_metrics
|
| 40 |
+
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
|
| 41 |
+
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
|
| 42 |
+
from verl.utils.metric import reduce_metrics
|
| 43 |
+
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
|
| 44 |
+
from verl.utils.torch_functional import masked_mean
|
| 45 |
+
from verl.utils.tracking import ValidationGenerationsLogger
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class ResourcePoolManager:
|
| 50 |
+
"""
|
| 51 |
+
Define a resource pool specification. Resource pool will be initialized first.
|
| 52 |
+
Mapping
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
resource_pool_spec: dict[str, list[int]]
|
| 56 |
+
mapping: dict[Role, str]
|
| 57 |
+
resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
|
| 58 |
+
|
| 59 |
+
def create_resource_pool(self):
|
| 60 |
+
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
|
| 61 |
+
# max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
|
| 62 |
+
# For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one.
|
| 63 |
+
# For Megatron backend, we recommend using max_colocate_count>1 that can utilize different
|
| 64 |
+
# WorkerGroup for different models
|
| 65 |
+
resource_pool = RayResourcePool(
|
| 66 |
+
process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
|
| 67 |
+
)
|
| 68 |
+
self.resource_pool_dict[resource_pool_name] = resource_pool
|
| 69 |
+
|
| 70 |
+
self._check_resource_available()
|
| 71 |
+
|
| 72 |
+
def get_resource_pool(self, role: Role) -> RayResourcePool:
|
| 73 |
+
"""Get the resource pool of the worker_cls"""
|
| 74 |
+
return self.resource_pool_dict[self.mapping[role]]
|
| 75 |
+
|
| 76 |
+
def get_n_gpus(self) -> int:
|
| 77 |
+
"""Get the number of gpus in this cluster."""
|
| 78 |
+
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
|
| 79 |
+
|
| 80 |
+
def _check_resource_available(self):
|
| 81 |
+
"""Check if the resource pool can be satisfied in this ray cluster."""
|
| 82 |
+
node_available_resources = ray._private.state.available_resources_per_node()
|
| 83 |
+
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
|
| 84 |
+
|
| 85 |
+
# check total required gpus can be satisfied
|
| 86 |
+
total_available_gpus = sum(node_available_gpus.values())
|
| 87 |
+
total_required_gpus = sum(
|
| 88 |
+
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
|
| 89 |
+
)
|
| 90 |
+
if total_available_gpus < total_required_gpus:
|
| 91 |
+
raise ValueError(
|
| 92 |
+
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# check each resource pool can be satisfied, O(#resource_pools * #nodes)
|
| 96 |
+
for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
|
| 97 |
+
num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
|
| 98 |
+
for node, available_gpus in node_available_gpus.items():
|
| 99 |
+
if available_gpus >= num_gpus:
|
| 100 |
+
node_available_gpus[node] -= num_gpus
|
| 101 |
+
num_nodes -= 1
|
| 102 |
+
if num_nodes == 0:
|
| 103 |
+
break
|
| 104 |
+
if num_nodes > 0:
|
| 105 |
+
raise ValueError(
|
| 106 |
+
f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this "
|
| 107 |
+
f"ray cluster"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _compute_response_info(batch: DataProto) -> dict[str, Any]:
|
| 112 |
+
"""Placeholder: Computes prompt and response lengths."""
|
| 113 |
+
try:
|
| 114 |
+
# Assuming 'prompts' and 'responses' keys exist after generation/union
|
| 115 |
+
prompt_len = batch.batch["prompts"].shape[1]
|
| 116 |
+
resp_len = batch.batch["responses"].shape[1]
|
| 117 |
+
# This is simplified - real implementation might use attention masks
|
| 118 |
+
# to get actual lengths per sample.
|
| 119 |
+
batch_size = batch.batch.batch_size[0]
|
| 120 |
+
prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device)
|
| 121 |
+
response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device)
|
| 122 |
+
|
| 123 |
+
# Try getting actual lengths from attention mask if possible (more accurate)
|
| 124 |
+
if "response_mask" in batch.batch:
|
| 125 |
+
response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float()
|
| 126 |
+
# if "attention_mask" in batch.batch and "response_mask" in batch.batch:
|
| 127 |
+
# full_mask = batch.batch["attention_mask"]
|
| 128 |
+
# resp_mask = batch.batch["response_mask"]
|
| 129 |
+
# Infer prompt mask length based on where response mask starts or total length
|
| 130 |
+
# This logic depends heavily on how your masks are constructed.
|
| 131 |
+
# Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor
|
| 132 |
+
# Fallback to using prompt shape if mask logic is complex:
|
| 133 |
+
prompt_lengths_tensor = torch.tensor(
|
| 134 |
+
[batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
return {
|
| 138 |
+
"prompt_length": prompt_lengths_tensor,
|
| 139 |
+
"response_length": response_lengths_tensor,
|
| 140 |
+
"max_response_length": resp_len,
|
| 141 |
+
"max_prompt_length": prompt_len, # Or from config if fixed padding
|
| 142 |
+
}
|
| 143 |
+
except KeyError as e:
|
| 144 |
+
print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.")
|
| 145 |
+
# Return default/dummy values if keys are missing
|
| 146 |
+
b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1
|
| 147 |
+
max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0
|
| 148 |
+
max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0
|
| 149 |
+
return {
|
| 150 |
+
"prompt_length": torch.zeros(b_size),
|
| 151 |
+
"response_length": torch.zeros(b_size),
|
| 152 |
+
"max_response_length": max_resp,
|
| 153 |
+
"max_prompt_length": max_prompt,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# --- Modified Metric Function ---
|
| 158 |
+
def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
|
| 159 |
+
"""
|
| 160 |
+
Computes and returns metrics relevant for the DPO-like process.
|
| 161 |
+
Assumes 'batch' contains results after generation and preference marking,
|
| 162 |
+
potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc.
|
| 163 |
+
Removes PPO-specific advantage/return/critic metrics.
|
| 164 |
+
"""
|
| 165 |
+
print("---- [DEBUG] Computing DPO Data Metrics ----")
|
| 166 |
+
metrics = {}
|
| 167 |
+
try:
|
| 168 |
+
# --- Scores and Rewards (from reward_fn) ---
|
| 169 |
+
if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None:
|
| 170 |
+
sequence_score = batch.batch["token_level_scores"].sum(-1)
|
| 171 |
+
metrics.update(
|
| 172 |
+
{
|
| 173 |
+
"reward/score/mean": torch.mean(sequence_score).item(),
|
| 174 |
+
"reward/score/max": torch.max(sequence_score).item(),
|
| 175 |
+
"reward/score/min": torch.min(sequence_score).item(),
|
| 176 |
+
}
|
| 177 |
+
)
|
| 178 |
+
else:
|
| 179 |
+
print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.")
|
| 180 |
+
|
| 181 |
+
if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None:
|
| 182 |
+
sequence_reward = batch.batch["token_level_rewards"].sum(-1)
|
| 183 |
+
metrics.update(
|
| 184 |
+
{
|
| 185 |
+
"reward/rewards/mean": torch.mean(sequence_reward).item(),
|
| 186 |
+
"reward/rewards/max": torch.max(sequence_reward).item(),
|
| 187 |
+
"reward/rewards/min": torch.min(sequence_reward).item(),
|
| 188 |
+
}
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.")
|
| 192 |
+
|
| 193 |
+
# --- DPO Specific Metrics (if stored previously) ---
|
| 194 |
+
if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None:
|
| 195 |
+
metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item()
|
| 196 |
+
else:
|
| 197 |
+
print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.")
|
| 198 |
+
|
| 199 |
+
if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None:
|
| 200 |
+
metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item()
|
| 201 |
+
else:
|
| 202 |
+
print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.")
|
| 203 |
+
|
| 204 |
+
if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None:
|
| 205 |
+
metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item()
|
| 206 |
+
else:
|
| 207 |
+
print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.")
|
| 208 |
+
|
| 209 |
+
# Add metrics based on the 'preferences' mask if available
|
| 210 |
+
# if "preferences" in batch.batch and batch.batch["preferences"] is not None:
|
| 211 |
+
# prefs_mask = batch.batch["preferences"] # Shape [batch_size * n]
|
| 212 |
+
# Calculate accuracy based on RM scores (assuming higher score -> True in mask)
|
| 213 |
+
# Requires chosen/rejected scores to be available or recalculated
|
| 214 |
+
# This is complex here, better calculated in the main loop or update function
|
| 215 |
+
|
| 216 |
+
# --- Length Metrics ---
|
| 217 |
+
response_info = _compute_response_info(batch)
|
| 218 |
+
prompt_length = response_info["prompt_length"]
|
| 219 |
+
response_length = response_info["response_length"]
|
| 220 |
+
max_response_length = response_info["max_response_length"]
|
| 221 |
+
max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config
|
| 222 |
+
|
| 223 |
+
metrics.update(
|
| 224 |
+
{
|
| 225 |
+
"response_length/mean": torch.mean(response_length).item(),
|
| 226 |
+
"response_length/max": torch.max(response_length).item(),
|
| 227 |
+
"response_length/min": torch.min(response_length).item(),
|
| 228 |
+
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(),
|
| 229 |
+
"prompt_length/mean": torch.mean(prompt_length).item(),
|
| 230 |
+
"prompt_length/max": torch.max(prompt_length).item(),
|
| 231 |
+
"prompt_length/min": torch.min(prompt_length).item(),
|
| 232 |
+
# Prompt clip ratio might need adjustment based on how max_prompt_length is defined
|
| 233 |
+
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(),
|
| 234 |
+
}
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
except KeyError as e:
|
| 238 |
+
print(f"ERROR in compute_dpo_data_metrics: Missing key {e}")
|
| 239 |
+
except Exception as e:
|
| 240 |
+
print(f"ERROR in compute_dpo_data_metrics: {e}")
|
| 241 |
+
traceback.print_exc()
|
| 242 |
+
|
| 243 |
+
print(f"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----")
|
| 244 |
+
return metrics
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
|
| 248 |
+
responses = data.batch["responses"]
|
| 249 |
+
response_length = responses.size(1)
|
| 250 |
+
token_level_scores = data.batch["token_level_scores"]
|
| 251 |
+
batch_size = data.batch.batch_size[0]
|
| 252 |
+
attention_mask = data.batch["attention_mask"]
|
| 253 |
+
response_mask = attention_mask[:, -response_length:]
|
| 254 |
+
|
| 255 |
+
# compute kl between ref_policy and current policy
|
| 256 |
+
# When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
|
| 257 |
+
kld = core_algos.kl_penalty(
|
| 258 |
+
data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
|
| 259 |
+
) # (batch_size, response_length)
|
| 260 |
+
kld = kld * response_mask
|
| 261 |
+
beta = kl_ctrl.value
|
| 262 |
+
|
| 263 |
+
token_level_rewards = token_level_scores - beta * kld
|
| 264 |
+
|
| 265 |
+
current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
|
| 266 |
+
current_kl = torch.mean(current_kl, dim=0).item()
|
| 267 |
+
|
| 268 |
+
# according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
|
| 269 |
+
kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
|
| 270 |
+
data.batch["token_level_rewards"] = token_level_rewards
|
| 271 |
+
|
| 272 |
+
metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}
|
| 273 |
+
|
| 274 |
+
return data, metrics
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def compute_response_mask(data: DataProto):
|
| 278 |
+
responses = data.batch["responses"]
|
| 279 |
+
response_length = responses.size(1)
|
| 280 |
+
attention_mask = data.batch["attention_mask"]
|
| 281 |
+
return attention_mask[:, -response_length:]
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def compute_onlineDPO_pref(data: DataProto):
|
| 285 |
+
"""
|
| 286 |
+
Wrapper to compute DPO preference and add it to the DataProto batch.
|
| 287 |
+
Includes debugging prints.
|
| 288 |
+
"""
|
| 289 |
+
# print(f"\n---- [DEBUG] Entering compute_onlineDPO_pref ----")
|
| 290 |
+
# print(f" Input batch keys: {list(data.batch.keys())}")
|
| 291 |
+
|
| 292 |
+
# Check inputs
|
| 293 |
+
rewards_tensor = data.batch.get("token_level_rewards")
|
| 294 |
+
mask_tensor = data.batch.get("response_mask")
|
| 295 |
+
|
| 296 |
+
if rewards_tensor is None or mask_tensor is None:
|
| 297 |
+
print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!")
|
| 298 |
+
# Handle error case - maybe return original data or raise?
|
| 299 |
+
# Returning original data for now to potentially allow skipping
|
| 300 |
+
return data
|
| 301 |
+
|
| 302 |
+
try:
|
| 303 |
+
preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor)
|
| 304 |
+
# Store the result
|
| 305 |
+
data.batch["preferences"] = preferences
|
| 306 |
+
|
| 307 |
+
except AttributeError:
|
| 308 |
+
print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!")
|
| 309 |
+
# Assign dummy value or raise error
|
| 310 |
+
data.batch["preferences"] = None # Indicate failure
|
| 311 |
+
except Exception as e_pref:
|
| 312 |
+
print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}")
|
| 313 |
+
import traceback
|
| 314 |
+
|
| 315 |
+
traceback.print_exc()
|
| 316 |
+
data.batch["preferences"] = None # Indicate failure
|
| 317 |
+
|
| 318 |
+
# print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----")
|
| 319 |
+
return data
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@contextmanager
|
| 323 |
+
def _timer(name: str, timing_raw: dict[str, float]):
|
| 324 |
+
with Timer(name=name, logger=None) as timer:
|
| 325 |
+
yield
|
| 326 |
+
timing_raw[name] = timer.last
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class RaySPINTrainer:
|
| 330 |
+
"""
|
| 331 |
+
Note that this trainer runs on the driver process on a single CPU/GPU node.
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
# TODO: support each role have individual ray_worker_group_cls,
|
| 335 |
+
# i.e., support different backend of different role
|
| 336 |
+
def __init__(
|
| 337 |
+
self,
|
| 338 |
+
config,
|
| 339 |
+
tokenizer,
|
| 340 |
+
role_worker_mapping: dict[Role, WorkerType],
|
| 341 |
+
resource_pool_manager: ResourcePoolManager,
|
| 342 |
+
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
|
| 343 |
+
processor=None,
|
| 344 |
+
reward_fn=None,
|
| 345 |
+
val_reward_fn=None,
|
| 346 |
+
train_dataset: Optional[Dataset] = None,
|
| 347 |
+
val_dataset: Optional[Dataset] = None,
|
| 348 |
+
collate_fn=None,
|
| 349 |
+
train_sampler: Optional[Sampler] = None,
|
| 350 |
+
device_name=None,
|
| 351 |
+
):
|
| 352 |
+
# assert get_torch_device().is_available(), 'cuda must be available on driver'
|
| 353 |
+
|
| 354 |
+
self.tokenizer = tokenizer
|
| 355 |
+
self.processor = processor
|
| 356 |
+
self.config = config
|
| 357 |
+
self.reward_fn = reward_fn
|
| 358 |
+
self.val_reward_fn = val_reward_fn
|
| 359 |
+
|
| 360 |
+
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
|
| 361 |
+
assert self.hybrid_engine, "Currently, only support hybrid engine"
|
| 362 |
+
|
| 363 |
+
if self.hybrid_engine:
|
| 364 |
+
assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
|
| 365 |
+
|
| 366 |
+
self.role_worker_mapping = role_worker_mapping
|
| 367 |
+
self.resource_pool_manager = resource_pool_manager
|
| 368 |
+
self.use_reference_policy = need_reference_policy(role_worker_mapping)
|
| 369 |
+
self.use_rm = need_reward_model(role_worker_mapping)
|
| 370 |
+
self.use_critic = False
|
| 371 |
+
self.ray_worker_group_cls = ray_worker_group_cls
|
| 372 |
+
self.validation_generations_logger = ValidationGenerationsLogger()
|
| 373 |
+
self.async_rollout_mode = False
|
| 374 |
+
self.device_name = device_name if device_name else self.config.trainer.device
|
| 375 |
+
|
| 376 |
+
# define in-reward KL control
|
| 377 |
+
# kl loss control currently not suppoorted
|
| 378 |
+
if config.algorithm.use_kl_in_reward:
|
| 379 |
+
self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
|
| 380 |
+
|
| 381 |
+
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
|
| 382 |
+
|
| 383 |
+
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
|
| 384 |
+
"""
|
| 385 |
+
Creates the train and validation dataloaders.
|
| 386 |
+
"""
|
| 387 |
+
# TODO: we have to make sure the batch size is divisible by the dp size
|
| 388 |
+
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
|
| 389 |
+
|
| 390 |
+
if train_dataset is None:
|
| 391 |
+
train_dataset = create_rl_dataset(
|
| 392 |
+
self.config.data.train_files,
|
| 393 |
+
self.config.data,
|
| 394 |
+
self.tokenizer,
|
| 395 |
+
self.processor,
|
| 396 |
+
max_samples=self.config.data.get("train_max_samples", -1),
|
| 397 |
+
)
|
| 398 |
+
if val_dataset is None:
|
| 399 |
+
val_dataset = create_rl_dataset(
|
| 400 |
+
self.config.data.val_files,
|
| 401 |
+
self.config.data,
|
| 402 |
+
self.tokenizer,
|
| 403 |
+
self.processor,
|
| 404 |
+
max_samples=self.config.data.get("val_max_samples", -1),
|
| 405 |
+
)
|
| 406 |
+
self.train_dataset, self.val_dataset = train_dataset, val_dataset
|
| 407 |
+
|
| 408 |
+
if train_sampler is None:
|
| 409 |
+
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
|
| 410 |
+
if collate_fn is None:
|
| 411 |
+
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
|
| 412 |
+
|
| 413 |
+
collate_fn = default_collate_fn
|
| 414 |
+
|
| 415 |
+
self.train_dataloader = StatefulDataLoader(
|
| 416 |
+
dataset=self.train_dataset,
|
| 417 |
+
batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
|
| 418 |
+
num_workers=self.config.data.get("dataloader_num_workers", 8),
|
| 419 |
+
drop_last=True,
|
| 420 |
+
collate_fn=collate_fn,
|
| 421 |
+
sampler=train_sampler,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
val_batch_size = self.config.data.val_batch_size # Prefer config value if set
|
| 425 |
+
if val_batch_size is None:
|
| 426 |
+
val_batch_size = len(self.val_dataset)
|
| 427 |
+
|
| 428 |
+
self.val_dataloader = StatefulDataLoader(
|
| 429 |
+
dataset=self.val_dataset,
|
| 430 |
+
batch_size=val_batch_size,
|
| 431 |
+
num_workers=self.config.data.get("dataloader_num_workers", 8),
|
| 432 |
+
shuffle=False,
|
| 433 |
+
drop_last=False,
|
| 434 |
+
collate_fn=collate_fn,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
|
| 438 |
+
assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
|
| 439 |
+
|
| 440 |
+
print(
|
| 441 |
+
f"Size of train dataloader: {len(self.train_dataloader)}, "
|
| 442 |
+
f"Size of val dataloader: {len(self.val_dataloader)}"
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
|
| 446 |
+
|
| 447 |
+
if self.config.trainer.total_training_steps is not None:
|
| 448 |
+
total_training_steps = self.config.trainer.total_training_steps
|
| 449 |
+
|
| 450 |
+
self.total_training_steps = total_training_steps
|
| 451 |
+
print(f"Total training steps: {self.total_training_steps}")
|
| 452 |
+
|
| 453 |
+
try:
|
| 454 |
+
OmegaConf.set_struct(self.config, True)
|
| 455 |
+
with open_dict(self.config):
|
| 456 |
+
if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
|
| 457 |
+
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
|
| 458 |
+
if OmegaConf.select(self.config, "critic.optim"):
|
| 459 |
+
self.config.critic.optim.total_training_steps = total_training_steps
|
| 460 |
+
except Exception as e:
|
| 461 |
+
print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
|
| 462 |
+
|
| 463 |
+
def _maybe_log_val_generations(self, inputs, outputs, scores):
|
| 464 |
+
"""Log a table of validation samples to the configured logger (wandb or swanlab)"""
|
| 465 |
+
|
| 466 |
+
generations_to_log = self.config.trainer.log_val_generations
|
| 467 |
+
|
| 468 |
+
if generations_to_log == 0:
|
| 469 |
+
return
|
| 470 |
+
|
| 471 |
+
import numpy as np
|
| 472 |
+
|
| 473 |
+
# Create tuples of (input, output, score) and sort by input text
|
| 474 |
+
samples = list(zip(inputs, outputs, scores, strict=True))
|
| 475 |
+
samples.sort(key=lambda x: x[0]) # Sort by input text
|
| 476 |
+
|
| 477 |
+
# Use fixed random seed for deterministic shuffling
|
| 478 |
+
rng = np.random.RandomState(42)
|
| 479 |
+
rng.shuffle(samples)
|
| 480 |
+
|
| 481 |
+
# Take first N samples after shuffling
|
| 482 |
+
samples = samples[:generations_to_log]
|
| 483 |
+
|
| 484 |
+
# Log to each configured logger
|
| 485 |
+
self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
|
| 486 |
+
|
| 487 |
+
def _validate(self):
|
| 488 |
+
data_source_lst = []
|
| 489 |
+
reward_extra_infos_dict: dict[str, list] = defaultdict(list)
|
| 490 |
+
|
| 491 |
+
# Lists to collect samples for the table
|
| 492 |
+
sample_inputs = []
|
| 493 |
+
sample_outputs = []
|
| 494 |
+
sample_scores = []
|
| 495 |
+
|
| 496 |
+
for test_data in self.val_dataloader:
|
| 497 |
+
test_batch = DataProto.from_single_dict(test_data)
|
| 498 |
+
|
| 499 |
+
# repeat test batch
|
| 500 |
+
test_batch = test_batch.repeat(
|
| 501 |
+
repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
# we only do validation on rule-based rm
|
| 505 |
+
if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
|
| 506 |
+
return {}
|
| 507 |
+
|
| 508 |
+
# Store original inputs
|
| 509 |
+
input_ids = test_batch.batch["input_ids"]
|
| 510 |
+
# TODO: Can we keep special tokens except for padding tokens?
|
| 511 |
+
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
|
| 512 |
+
sample_inputs.extend(input_texts)
|
| 513 |
+
|
| 514 |
+
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
|
| 515 |
+
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
|
| 516 |
+
if "multi_modal_inputs" in test_batch.non_tensor_batch:
|
| 517 |
+
non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
|
| 518 |
+
if "raw_prompt" in test_batch.non_tensor_batch:
|
| 519 |
+
non_tensor_batch_keys_to_pop.append("raw_prompt")
|
| 520 |
+
if "tools_kwargs" in test_batch.non_tensor_batch:
|
| 521 |
+
non_tensor_batch_keys_to_pop.append("tools_kwargs")
|
| 522 |
+
test_gen_batch = test_batch.pop(
|
| 523 |
+
batch_keys=batch_keys_to_pop,
|
| 524 |
+
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
test_gen_batch.meta_info = {
|
| 528 |
+
"eos_token_id": self.tokenizer.eos_token_id,
|
| 529 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 530 |
+
"recompute_log_prob": False,
|
| 531 |
+
"do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
|
| 532 |
+
"validate": True,
|
| 533 |
+
}
|
| 534 |
+
print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
|
| 535 |
+
|
| 536 |
+
# pad to be divisible by dp_size
|
| 537 |
+
test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
|
| 538 |
+
if not self.async_rollout_mode:
|
| 539 |
+
test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
|
| 540 |
+
else:
|
| 541 |
+
test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
|
| 542 |
+
|
| 543 |
+
# unpad
|
| 544 |
+
test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
|
| 545 |
+
print("validation generation end")
|
| 546 |
+
|
| 547 |
+
# Store generated outputs
|
| 548 |
+
output_ids = test_output_gen_batch.batch["responses"]
|
| 549 |
+
output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
|
| 550 |
+
sample_outputs.extend(output_texts)
|
| 551 |
+
|
| 552 |
+
test_batch = test_batch.union(test_output_gen_batch)
|
| 553 |
+
|
| 554 |
+
# evaluate using reward_function
|
| 555 |
+
result = self.val_reward_fn(test_batch, return_dict=True)
|
| 556 |
+
reward_tensor = result["reward_tensor"]
|
| 557 |
+
scores = reward_tensor.sum(-1).cpu().tolist()
|
| 558 |
+
sample_scores.extend(scores)
|
| 559 |
+
|
| 560 |
+
reward_extra_infos_dict["reward"].extend(scores)
|
| 561 |
+
if "reward_extra_info" in result:
|
| 562 |
+
for key, lst in result["reward_extra_info"].items():
|
| 563 |
+
reward_extra_infos_dict[key].extend(lst)
|
| 564 |
+
|
| 565 |
+
data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
|
| 566 |
+
|
| 567 |
+
self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
|
| 568 |
+
|
| 569 |
+
# dump generations
|
| 570 |
+
val_data_dir = self.config.trainer.get("validation_data_dir", None)
|
| 571 |
+
if val_data_dir:
|
| 572 |
+
sample_gts = [
|
| 573 |
+
item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
|
| 574 |
+
]
|
| 575 |
+
self._dump_generations(
|
| 576 |
+
inputs=sample_inputs,
|
| 577 |
+
outputs=sample_outputs,
|
| 578 |
+
gts=sample_gts,
|
| 579 |
+
scores=sample_scores,
|
| 580 |
+
reward_extra_infos_dict=reward_extra_infos_dict,
|
| 581 |
+
dump_path=val_data_dir,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
for key_info, lst in reward_extra_infos_dict.items():
|
| 585 |
+
assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
|
| 586 |
+
|
| 587 |
+
data_sources = np.concatenate(data_source_lst, axis=0)
|
| 588 |
+
print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print
|
| 589 |
+
print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print
|
| 590 |
+
|
| 591 |
+
data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
|
| 592 |
+
print(
|
| 593 |
+
f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}"
|
| 594 |
+
) # Added Print
|
| 595 |
+
metric_dict = {}
|
| 596 |
+
for data_source, var2metric2val in data_src2var2metric2val.items():
|
| 597 |
+
core_var = "acc" if "acc" in var2metric2val else "reward"
|
| 598 |
+
for var_name, metric2val in var2metric2val.items():
|
| 599 |
+
n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
|
| 600 |
+
for metric_name, metric_val in metric2val.items():
|
| 601 |
+
if (
|
| 602 |
+
(var_name == core_var)
|
| 603 |
+
and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
|
| 604 |
+
and (f"@{n_max}" in metric_name)
|
| 605 |
+
):
|
| 606 |
+
metric_sec = "val-core"
|
| 607 |
+
else:
|
| 608 |
+
metric_sec = "val-aux"
|
| 609 |
+
pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
|
| 610 |
+
metric_dict[pfx] = metric_val
|
| 611 |
+
|
| 612 |
+
return metric_dict
|
| 613 |
+
|
| 614 |
+
def init_workers(self):
|
| 615 |
+
"""Init resource pool and worker group"""
|
| 616 |
+
self.resource_pool_manager.create_resource_pool()
|
| 617 |
+
|
| 618 |
+
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
|
| 619 |
+
|
| 620 |
+
# create actor and rollout
|
| 621 |
+
if self.hybrid_engine:
|
| 622 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
|
| 623 |
+
actor_rollout_cls = RayClassWithInitArgs(
|
| 624 |
+
cls=self.role_worker_mapping[Role.ActorRollout],
|
| 625 |
+
config=self.config.actor_rollout_ref,
|
| 626 |
+
role="actor_rollout",
|
| 627 |
+
)
|
| 628 |
+
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
|
| 629 |
+
else:
|
| 630 |
+
raise NotImplementedError
|
| 631 |
+
|
| 632 |
+
# create critic
|
| 633 |
+
if self.use_critic:
|
| 634 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
|
| 635 |
+
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
|
| 636 |
+
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
|
| 637 |
+
|
| 638 |
+
# create reference policy if needed
|
| 639 |
+
if self.use_reference_policy:
|
| 640 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
|
| 641 |
+
ref_policy_cls = RayClassWithInitArgs(
|
| 642 |
+
self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref"
|
| 643 |
+
)
|
| 644 |
+
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
|
| 645 |
+
|
| 646 |
+
# create a reward model if reward_fn is None
|
| 647 |
+
if self.use_rm:
|
| 648 |
+
# we create a RM here
|
| 649 |
+
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
|
| 650 |
+
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
|
| 651 |
+
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
|
| 652 |
+
|
| 653 |
+
# initialize WorkerGroup
|
| 654 |
+
# NOTE: if you want to use a different resource pool for each role, which can support different
|
| 655 |
+
# parallel size,
|
| 656 |
+
# you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to
|
| 657 |
+
# different worker groups.
|
| 658 |
+
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
|
| 659 |
+
all_wg = {}
|
| 660 |
+
self.wg_dicts = []
|
| 661 |
+
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
|
| 662 |
+
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
|
| 663 |
+
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
|
| 664 |
+
wg_kwargs["device_name"] = self.device_name
|
| 665 |
+
|
| 666 |
+
for resource_pool, class_dict in self.resource_pool_to_cls.items():
|
| 667 |
+
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
|
| 668 |
+
wg_dict = self.ray_worker_group_cls(
|
| 669 |
+
resource_pool=resource_pool,
|
| 670 |
+
ray_cls_with_init=worker_dict_cls,
|
| 671 |
+
**wg_kwargs,
|
| 672 |
+
)
|
| 673 |
+
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
|
| 674 |
+
all_wg.update(spawn_wg)
|
| 675 |
+
# keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
|
| 676 |
+
self.wg_dicts.append(wg_dict)
|
| 677 |
+
|
| 678 |
+
if self.use_critic:
|
| 679 |
+
self.critic_wg = all_wg["critic"]
|
| 680 |
+
self.critic_wg.init_model()
|
| 681 |
+
|
| 682 |
+
if self.use_reference_policy:
|
| 683 |
+
self.ref_policy_wg = all_wg["ref"]
|
| 684 |
+
self.ref_policy_wg.init_model()
|
| 685 |
+
|
| 686 |
+
if self.use_rm:
|
| 687 |
+
self.rm_wg = all_wg["rm"]
|
| 688 |
+
self.rm_wg.init_model()
|
| 689 |
+
|
| 690 |
+
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
|
| 691 |
+
self.actor_rollout_wg = all_wg["actor_rollout"]
|
| 692 |
+
self.actor_rollout_wg.init_model()
|
| 693 |
+
|
| 694 |
+
def _save_checkpoint(self):
|
| 695 |
+
# path: given_path + `/global_step_{global_steps}` + `/actor`
|
| 696 |
+
local_global_step_folder = os.path.join(
|
| 697 |
+
self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
print(f"local_global_step_folder: {local_global_step_folder}")
|
| 701 |
+
actor_local_path = os.path.join(local_global_step_folder, "actor")
|
| 702 |
+
|
| 703 |
+
actor_remote_path = (
|
| 704 |
+
None
|
| 705 |
+
if self.config.trainer.default_hdfs_dir is None
|
| 706 |
+
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
|
| 710 |
+
if remove_previous_ckpt_in_save:
|
| 711 |
+
print(
|
| 712 |
+
"Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and "
|
| 713 |
+
"max_critic_ckpt_to_keep=1 instead"
|
| 714 |
+
)
|
| 715 |
+
max_actor_ckpt_to_keep = (
|
| 716 |
+
self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
|
| 717 |
+
)
|
| 718 |
+
max_critic_ckpt_to_keep = (
|
| 719 |
+
self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
self.actor_rollout_wg.save_checkpoint(
|
| 723 |
+
actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
if self.use_critic:
|
| 727 |
+
critic_local_path = os.path.join(local_global_step_folder, "critic")
|
| 728 |
+
critic_remote_path = (
|
| 729 |
+
None
|
| 730 |
+
if self.config.trainer.default_hdfs_dir is None
|
| 731 |
+
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
|
| 732 |
+
)
|
| 733 |
+
self.critic_wg.save_checkpoint(
|
| 734 |
+
critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
# save dataloader
|
| 738 |
+
dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
|
| 739 |
+
dataloader_state_dict = self.train_dataloader.state_dict()
|
| 740 |
+
torch.save(dataloader_state_dict, dataloader_local_path)
|
| 741 |
+
|
| 742 |
+
# latest checkpointed iteration tracker (for atomic usage)
|
| 743 |
+
local_latest_checkpointed_iteration = os.path.join(
|
| 744 |
+
self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
|
| 745 |
+
)
|
| 746 |
+
with open(local_latest_checkpointed_iteration, "w") as f:
|
| 747 |
+
f.write(str(self.global_steps))
|
| 748 |
+
|
| 749 |
+
def _load_checkpoint(self):
|
| 750 |
+
if self.config.trainer.resume_mode == "disable":
|
| 751 |
+
return 0
|
| 752 |
+
|
| 753 |
+
# load from hdfs
|
| 754 |
+
if self.config.trainer.default_hdfs_dir is not None:
|
| 755 |
+
raise NotImplementedError("load from hdfs is not implemented yet")
|
| 756 |
+
else:
|
| 757 |
+
checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
|
| 758 |
+
if not os.path.isabs(checkpoint_folder):
|
| 759 |
+
working_dir = os.getcwd()
|
| 760 |
+
checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
|
| 761 |
+
global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
|
| 762 |
+
|
| 763 |
+
# find global_step_folder
|
| 764 |
+
if self.config.trainer.resume_mode == "auto":
|
| 765 |
+
if global_step_folder is None:
|
| 766 |
+
print("Training from scratch")
|
| 767 |
+
return 0
|
| 768 |
+
else:
|
| 769 |
+
if self.config.trainer.resume_mode == "resume_path":
|
| 770 |
+
assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
|
| 771 |
+
assert "global_step_" in self.config.trainer.resume_from_path, (
|
| 772 |
+
"resume ckpt must specify the global_steps"
|
| 773 |
+
)
|
| 774 |
+
global_step_folder = self.config.trainer.resume_from_path
|
| 775 |
+
if not os.path.isabs(global_step_folder):
|
| 776 |
+
working_dir = os.getcwd()
|
| 777 |
+
global_step_folder = os.path.join(working_dir, global_step_folder)
|
| 778 |
+
print(f"Load from checkpoint folder: {global_step_folder}")
|
| 779 |
+
# set global step
|
| 780 |
+
self.global_steps = int(global_step_folder.split("global_step_")[-1])
|
| 781 |
+
|
| 782 |
+
print(f"Setting global step to {self.global_steps}")
|
| 783 |
+
print(f"Resuming from {global_step_folder}")
|
| 784 |
+
|
| 785 |
+
actor_path = os.path.join(global_step_folder, "actor")
|
| 786 |
+
critic_path = os.path.join(global_step_folder, "critic")
|
| 787 |
+
# load actor
|
| 788 |
+
self.actor_rollout_wg.load_checkpoint(
|
| 789 |
+
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
|
| 790 |
+
)
|
| 791 |
+
# load critic
|
| 792 |
+
if self.use_critic:
|
| 793 |
+
self.critic_wg.load_checkpoint(
|
| 794 |
+
critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
# load dataloader,
|
| 798 |
+
# TODO: from remote not implemented yet
|
| 799 |
+
dataloader_local_path = os.path.join(global_step_folder, "data.pt")
|
| 800 |
+
if os.path.exists(dataloader_local_path):
|
| 801 |
+
dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
|
| 802 |
+
self.train_dataloader.load_state_dict(dataloader_state_dict)
|
| 803 |
+
else:
|
| 804 |
+
print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
|
| 805 |
+
|
| 806 |
+
def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
|
| 807 |
+
"""Reorder the data on single controller such that each dp rank gets similar total tokens"""
|
| 808 |
+
attention_mask = batch.batch["attention_mask"]
|
| 809 |
+
batch_size = attention_mask.shape[0]
|
| 810 |
+
global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
|
| 811 |
+
world_size = self.actor_rollout_wg.world_size
|
| 812 |
+
global_partition_lst = get_seqlen_balanced_partitions(
|
| 813 |
+
global_seqlen_lst, k_partitions=world_size, equal_size=True
|
| 814 |
+
)
|
| 815 |
+
# reorder based on index. The data will be automatically equally partitioned by dispatch function
|
| 816 |
+
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
|
| 817 |
+
batch.reorder(global_idx)
|
| 818 |
+
global_balance_stats = log_seqlen_unbalance(
|
| 819 |
+
seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
|
| 820 |
+
)
|
| 821 |
+
metrics.update(global_balance_stats)
|
| 822 |
+
|
| 823 |
+
def fit_dpo(self): # Renamed for clarity as standard PPO loop
|
| 824 |
+
"""
|
| 825 |
+
The training loop of Online DPO using a periodically updated reference model.
|
| 826 |
+
The driver process calls worker groups for computation.
|
| 827 |
+
Advantage computation is replaced by DPO logic.
|
| 828 |
+
"""
|
| 829 |
+
import traceback # Ensure traceback is imported
|
| 830 |
+
|
| 831 |
+
from omegaconf import OmegaConf
|
| 832 |
+
|
| 833 |
+
from verl.utils.tracking import Tracking
|
| 834 |
+
|
| 835 |
+
# Initialize logger
|
| 836 |
+
logger = None
|
| 837 |
+
try:
|
| 838 |
+
logger = Tracking(
|
| 839 |
+
project_name=self.config.trainer.project_name,
|
| 840 |
+
experiment_name=self.config.trainer.experiment_name,
|
| 841 |
+
default_backend=self.config.trainer.logger,
|
| 842 |
+
config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False),
|
| 843 |
+
)
|
| 844 |
+
except Exception as e:
|
| 845 |
+
print(f"Warning: Failed to initialize logger: {e}")
|
| 846 |
+
|
| 847 |
+
self.global_steps = 0
|
| 848 |
+
# Load checkpoint before doing anything
|
| 849 |
+
loaded_step = self._load_checkpoint()
|
| 850 |
+
self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1
|
| 851 |
+
print(
|
| 852 |
+
f"Starting Online DPO training from global step {self.global_steps}. "
|
| 853 |
+
f"Total steps: {self.total_training_steps}"
|
| 854 |
+
)
|
| 855 |
+
print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}")
|
| 856 |
+
|
| 857 |
+
# Check if reference policy is configured correctly for this mode
|
| 858 |
+
if not self.use_reference_policy:
|
| 859 |
+
print(
|
| 860 |
+
"WARNING: 'use_reference_policy' is False. Periodic reference model update requires a "
|
| 861 |
+
"reference policy worker. DPO updates might fail or use incorrect logic."
|
| 862 |
+
)
|
| 863 |
+
# Consider raising an error if strict adherence is required:
|
| 864 |
+
# raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True "
|
| 865 |
+
# "and a configured reference worker.")
|
| 866 |
+
|
| 867 |
+
# Perform validation before training
|
| 868 |
+
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
|
| 869 |
+
print("Running validation before Online DPO training...")
|
| 870 |
+
val_metrics = self._validate()
|
| 871 |
+
pprint(f"Initial validation metrics: {val_metrics}")
|
| 872 |
+
if logger and val_metrics:
|
| 873 |
+
logger.log(data=val_metrics, step=max(0, self.global_steps - 1))
|
| 874 |
+
if self.config.trainer.get("val_only", False):
|
| 875 |
+
print("Validation only mode enabled. Exiting training.")
|
| 876 |
+
if logger and hasattr(logger, "finish"):
|
| 877 |
+
logger.finish()
|
| 878 |
+
return
|
| 879 |
+
|
| 880 |
+
# Add tqdm progress bar
|
| 881 |
+
progress_bar = tqdm(
|
| 882 |
+
total=self.total_training_steps,
|
| 883 |
+
initial=self.global_steps,
|
| 884 |
+
desc="Online DPO Training Progress",
|
| 885 |
+
position=0,
|
| 886 |
+
leave=True,
|
| 887 |
+
)
|
| 888 |
+
|
| 889 |
+
last_val_metrics = None
|
| 890 |
+
should_stop = False
|
| 891 |
+
|
| 892 |
+
for epoch in range(self.config.trainer.total_epochs):
|
| 893 |
+
if should_stop:
|
| 894 |
+
break
|
| 895 |
+
print(f"--- Starting Online DPO Epoch {epoch} ---")
|
| 896 |
+
try:
|
| 897 |
+
train_iterator = iter(self.train_dataloader)
|
| 898 |
+
except TypeError:
|
| 899 |
+
print("Warning: Dataloader is not iterable.")
|
| 900 |
+
train_iterator = self.train_dataloader # Fallback attempt
|
| 901 |
+
|
| 902 |
+
for batch_idx, batch_dict in enumerate(train_iterator):
|
| 903 |
+
if self.global_steps > self.total_training_steps:
|
| 904 |
+
should_stop = True
|
| 905 |
+
break
|
| 906 |
+
|
| 907 |
+
metrics = {}
|
| 908 |
+
timing_raw = {}
|
| 909 |
+
step_timer = Timer(logger=None)
|
| 910 |
+
ref_log_prob_computed = False # Flag to track if ref log probs were computed
|
| 911 |
+
|
| 912 |
+
try: # Outer try-except for the whole step
|
| 913 |
+
step_timer.start()
|
| 914 |
+
with _timer("step", timing_raw):
|
| 915 |
+
batch: DataProto = DataProto.from_single_dict(batch_dict)
|
| 916 |
+
current_batch_size = batch.batch.batch_size[0]
|
| 917 |
+
print(
|
| 918 |
+
f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: "
|
| 919 |
+
f"{current_batch_size}"
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# --- Reference Model Update ---
|
| 923 |
+
ref_update_freq = self.config.trainer.get("ref_update_freq", -1)
|
| 924 |
+
if (
|
| 925 |
+
self.use_reference_policy
|
| 926 |
+
and ref_update_freq > 0
|
| 927 |
+
and self.global_steps % ref_update_freq == 0
|
| 928 |
+
):
|
| 929 |
+
print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...")
|
| 930 |
+
try:
|
| 931 |
+
# --- This requires careful implementation with FSDP ---
|
| 932 |
+
# 1. Save actor state dict (potentially to CPU memory or disk)
|
| 933 |
+
# This needs to be done collectively across actor worker ranks.
|
| 934 |
+
# The checkpoint_manager might be adaptable, or use FSDP APIs directly.
|
| 935 |
+
# Example placeholder using a conceptual save/load mechanism:
|
| 936 |
+
actor_state_path = "/tmp/actor_state_mid" # Temporary path
|
| 937 |
+
self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic
|
| 938 |
+
|
| 939 |
+
# 2. Load the state dict onto the reference model worker group
|
| 940 |
+
# This also needs collective loading on the ref worker ranks.
|
| 941 |
+
self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic
|
| 942 |
+
|
| 943 |
+
print(f"[Step {self.global_steps}] Reference Model Weights Updated.")
|
| 944 |
+
# Optionally remove the temporary state file
|
| 945 |
+
# os.remove(actor_state_path) # Needs rank-aware removal or shared storage
|
| 946 |
+
|
| 947 |
+
except Exception as sync_e:
|
| 948 |
+
print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}")
|
| 949 |
+
traceback.print_exc()
|
| 950 |
+
|
| 951 |
+
# Pop keys for generation
|
| 952 |
+
pop_batch_keys = ["input_ids", "attention_mask"]
|
| 953 |
+
if "position_ids" in batch.batch:
|
| 954 |
+
pop_batch_keys.append("position_ids")
|
| 955 |
+
pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else []
|
| 956 |
+
if "multi_modal_inputs" in batch.non_tensor_batch.keys():
|
| 957 |
+
pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"])
|
| 958 |
+
original_non_tensor_data = batch.non_tensor_batch
|
| 959 |
+
gen_batch = batch.pop(
|
| 960 |
+
batch_keys=pop_batch_keys,
|
| 961 |
+
non_tensor_batch_keys=pop_non_tensor_keys,
|
| 962 |
+
)
|
| 963 |
+
gen_batch = gen_batch.repeat(
|
| 964 |
+
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
|
| 965 |
+
)
|
| 966 |
+
# (Add Debug prints for gen_batch if needed)
|
| 967 |
+
|
| 968 |
+
# Generate sequences (chosen/rejected pairs)
|
| 969 |
+
with _timer("gen", timing_raw):
|
| 970 |
+
try:
|
| 971 |
+
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
|
| 972 |
+
# (Add Debug prints for gen_batch_output if needed)
|
| 973 |
+
except Exception as gen_e:
|
| 974 |
+
print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!")
|
| 975 |
+
print(gen_e)
|
| 976 |
+
traceback.print_exc()
|
| 977 |
+
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
| 978 |
+
step_timer.stop()
|
| 979 |
+
continue
|
| 980 |
+
|
| 981 |
+
# Combine original prompts with generated sequences
|
| 982 |
+
batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data
|
| 983 |
+
batch.non_tensor_batch["uid"] = np.array(
|
| 984 |
+
[str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object
|
| 985 |
+
)
|
| 986 |
+
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
| 987 |
+
batch = batch.union(gen_batch_output)
|
| 988 |
+
# (Add Debug prints after union if needed)
|
| 989 |
+
|
| 990 |
+
# Compute response mask (needed for ref logprob calc and DPO prep)
|
| 991 |
+
batch.batch["response_mask"] = compute_response_mask(batch)
|
| 992 |
+
|
| 993 |
+
if self.config.trainer.balance_batch:
|
| 994 |
+
self._balance_batch(batch, metrics=metrics)
|
| 995 |
+
|
| 996 |
+
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
|
| 997 |
+
|
| 998 |
+
# --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef
|
| 999 |
+
# fallback) ---
|
| 1000 |
+
# Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed
|
| 1001 |
+
# unless used for other metrics or a fallback. Keep it for now.
|
| 1002 |
+
with _timer("policy_log_prob", timing_raw):
|
| 1003 |
+
policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch)
|
| 1004 |
+
batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs'
|
| 1005 |
+
# (Debug prints for old_log_probs)
|
| 1006 |
+
|
| 1007 |
+
# --- Compute Log Probs using the EXTERNAL Reference Model ---
|
| 1008 |
+
if self.use_reference_policy:
|
| 1009 |
+
with _timer("ref_log_prob_dpo", timing_raw):
|
| 1010 |
+
# print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----")
|
| 1011 |
+
try:
|
| 1012 |
+
# 'batch' contains interleaved chosen/rejected sequences
|
| 1013 |
+
ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(
|
| 1014 |
+
batch
|
| 1015 |
+
) # Returns DataProto with 'ref_log_prob'
|
| 1016 |
+
batch = batch.union(
|
| 1017 |
+
ref_log_prob_output
|
| 1018 |
+
) # Adds 'ref_log_prob' key [batch_size * n, seq_len]
|
| 1019 |
+
ref_log_prob_computed = True # Mark success
|
| 1020 |
+
# print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: "
|
| 1021 |
+
# f"{batch.batch['ref_log_prob'].shape} ----")
|
| 1022 |
+
except Exception as ref_e:
|
| 1023 |
+
print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}")
|
| 1024 |
+
traceback.print_exc()
|
| 1025 |
+
batch.batch["ref_log_prob"] = None # Mark as failed
|
| 1026 |
+
ref_log_prob_computed = False
|
| 1027 |
+
else:
|
| 1028 |
+
print(
|
| 1029 |
+
"Warning: Skipping external reference log prob calculation as use_reference_policy "
|
| 1030 |
+
"is False."
|
| 1031 |
+
)
|
| 1032 |
+
# DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor
|
| 1033 |
+
|
| 1034 |
+
# --- Compute Rewards/Scores (used to determine preference) ---
|
| 1035 |
+
with _timer("reward_calc", timing_raw):
|
| 1036 |
+
# (Reward calculation logic using RM or reward_fn as before)
|
| 1037 |
+
# ... Ensure this calculates 'token_level_rewards' or similar ...
|
| 1038 |
+
if self.use_rm:
|
| 1039 |
+
reward_tensor_rm = self.rm_wg.compute_rm_score(batch)
|
| 1040 |
+
batch = batch.union(reward_tensor_rm) # Adds 'rm_scores'
|
| 1041 |
+
|
| 1042 |
+
reward_extra_infos_dict = {}
|
| 1043 |
+
try:
|
| 1044 |
+
if self.reward_fn is None:
|
| 1045 |
+
# print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! "
|
| 1046 |
+
# f"Using dummy rewards. ----")
|
| 1047 |
+
# Use rm_scores if available, otherwise zeros
|
| 1048 |
+
reward_tensor = batch.batch.get(
|
| 1049 |
+
"rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
|
| 1050 |
+
)
|
| 1051 |
+
else:
|
| 1052 |
+
reward_result = self.reward_fn(batch, return_dict=True)
|
| 1053 |
+
reward_tensor = reward_result["reward_tensor"] # Final combined reward
|
| 1054 |
+
reward_extra_infos_dict = reward_result.get("reward_extra_info", {})
|
| 1055 |
+
|
| 1056 |
+
except Exception:
|
| 1057 |
+
# print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. '
|
| 1058 |
+
# f'Using dummy rewards. ----')
|
| 1059 |
+
traceback.print_exc()
|
| 1060 |
+
reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
|
| 1061 |
+
reward_extra_infos_dict = {}
|
| 1062 |
+
|
| 1063 |
+
# Use 'token_level_rewards' as the key for preference calculation
|
| 1064 |
+
batch.batch["token_level_rewards"] = reward_tensor
|
| 1065 |
+
if reward_extra_infos_dict:
|
| 1066 |
+
batch.non_tensor_batch.update(
|
| 1067 |
+
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
|
| 1068 |
+
)
|
| 1069 |
+
|
| 1070 |
+
# --- Determine Preferences ---
|
| 1071 |
+
# Uses 'token_level_rewards' to determine chosen/rejected based on score
|
| 1072 |
+
batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key
|
| 1073 |
+
|
| 1074 |
+
# --- Prepare DPO Batch ---
|
| 1075 |
+
dpo_update_batch_proto = None # Initialize
|
| 1076 |
+
with _timer("prepare_dpo_batch", timing_raw):
|
| 1077 |
+
try:
|
| 1078 |
+
if "preferences" not in batch.batch or batch.batch["preferences"] is None:
|
| 1079 |
+
raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.")
|
| 1080 |
+
|
| 1081 |
+
# Check if reference log probs were computed successfully (if needed)
|
| 1082 |
+
if self.use_reference_policy and not ref_log_prob_computed:
|
| 1083 |
+
raise ValueError("Reference log probs required but failed to compute.")
|
| 1084 |
+
|
| 1085 |
+
# Check required base keys
|
| 1086 |
+
required_keys = ["input_ids", "attention_mask", "response_mask"]
|
| 1087 |
+
for rk in required_keys:
|
| 1088 |
+
if rk not in batch.batch or batch.batch[rk] is None:
|
| 1089 |
+
raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.")
|
| 1090 |
+
|
| 1091 |
+
preferences_mask = batch.batch["preferences"] # Shape [batch_size * n]
|
| 1092 |
+
not_preferences_mask = ~preferences_mask
|
| 1093 |
+
|
| 1094 |
+
# Gather Chosen/Rejected Base Tensors
|
| 1095 |
+
chosen_input_ids = batch.batch["input_ids"][preferences_mask]
|
| 1096 |
+
chosen_attention_mask = batch.batch["attention_mask"][preferences_mask]
|
| 1097 |
+
rejected_input_ids = batch.batch["input_ids"][not_preferences_mask]
|
| 1098 |
+
rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask]
|
| 1099 |
+
chosen_position_ids = (
|
| 1100 |
+
batch.batch.get("position_ids")[preferences_mask]
|
| 1101 |
+
if "position_ids" in batch.batch
|
| 1102 |
+
else None
|
| 1103 |
+
)
|
| 1104 |
+
rejected_position_ids = (
|
| 1105 |
+
batch.batch.get("position_ids")[not_preferences_mask]
|
| 1106 |
+
if "position_ids" in batch.batch
|
| 1107 |
+
else None
|
| 1108 |
+
)
|
| 1109 |
+
|
| 1110 |
+
# Create Labels
|
| 1111 |
+
print("WARNING: Creating DPO labels using configured max_prompt_length...")
|
| 1112 |
+
prompt_len = self.config.data.max_prompt_length
|
| 1113 |
+
chosen_labels = chosen_input_ids.clone()
|
| 1114 |
+
chosen_labels[:, :prompt_len] = -100
|
| 1115 |
+
rejected_labels = rejected_input_ids.clone()
|
| 1116 |
+
rejected_labels[:, :prompt_len] = -100
|
| 1117 |
+
|
| 1118 |
+
# Calculate and Gather Reference Log Probs (Sequence Level)
|
| 1119 |
+
if self.use_reference_policy:
|
| 1120 |
+
ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len]
|
| 1121 |
+
response_mask_full = batch.batch[
|
| 1122 |
+
"response_mask"
|
| 1123 |
+
] # Response mask [bsz * n, seq_len]
|
| 1124 |
+
ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(
|
| 1125 |
+
dim=-1
|
| 1126 |
+
) # Sequence level [bsz * n]
|
| 1127 |
+
reference_chosen_logps = ref_sequence_logps[preferences_mask]
|
| 1128 |
+
reference_rejected_logps = ref_sequence_logps[not_preferences_mask]
|
| 1129 |
+
else:
|
| 1130 |
+
# If not using external ref, DPO needs ActorAsRef logic in dp_actor
|
| 1131 |
+
# We won't add the keys here, dp_actor will handle it (or fail if not modified)
|
| 1132 |
+
print(
|
| 1133 |
+
"Info: Not adding explicit reference logps to DPO batch "
|
| 1134 |
+
"(use_reference_policy=False)."
|
| 1135 |
+
)
|
| 1136 |
+
reference_chosen_logps = None # Explicitly None
|
| 1137 |
+
reference_rejected_logps = None
|
| 1138 |
+
|
| 1139 |
+
# Package Tensors
|
| 1140 |
+
dpo_tensors = {
|
| 1141 |
+
"chosen_input_ids": chosen_input_ids,
|
| 1142 |
+
"chosen_attention_mask": chosen_attention_mask,
|
| 1143 |
+
"chosen_labels": chosen_labels,
|
| 1144 |
+
"rejected_input_ids": rejected_input_ids,
|
| 1145 |
+
"rejected_attention_mask": rejected_attention_mask,
|
| 1146 |
+
"rejected_labels": rejected_labels,
|
| 1147 |
+
}
|
| 1148 |
+
# Conditionally add reference logps if computed
|
| 1149 |
+
if reference_chosen_logps is not None:
|
| 1150 |
+
dpo_tensors["reference_chosen_logps"] = reference_chosen_logps
|
| 1151 |
+
if reference_rejected_logps is not None:
|
| 1152 |
+
dpo_tensors["reference_rejected_logps"] = reference_rejected_logps
|
| 1153 |
+
# Add position ids if they exist
|
| 1154 |
+
if chosen_position_ids is not None:
|
| 1155 |
+
dpo_tensors["chosen_position_ids"] = chosen_position_ids
|
| 1156 |
+
if rejected_position_ids is not None:
|
| 1157 |
+
dpo_tensors["rejected_position_ids"] = rejected_position_ids
|
| 1158 |
+
|
| 1159 |
+
# Prepare Meta Info
|
| 1160 |
+
dpo_meta = {
|
| 1161 |
+
"dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1),
|
| 1162 |
+
"dpo_loss_type": OmegaConf.select(
|
| 1163 |
+
self.config.algorithm, "dpo_loss_type", default="sigmoid"
|
| 1164 |
+
),
|
| 1165 |
+
"dpo_label_smoothing": OmegaConf.select(
|
| 1166 |
+
self.config.algorithm, "dpo_label_smoothing", default=0.0
|
| 1167 |
+
),
|
| 1168 |
+
"use_reference_policy": self.use_reference_policy,
|
| 1169 |
+
"reference_free": not self.use_reference_policy, # False if using external ref
|
| 1170 |
+
"global_step": self.global_steps,
|
| 1171 |
+
}
|
| 1172 |
+
|
| 1173 |
+
dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta)
|
| 1174 |
+
# print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----")
|
| 1175 |
+
# print(f" Keys: {list(dpo_update_batch_proto.batch.keys())}")
|
| 1176 |
+
# print(f" Meta Info: {dpo_meta}")
|
| 1177 |
+
|
| 1178 |
+
except Exception as e_prep:
|
| 1179 |
+
print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}")
|
| 1180 |
+
traceback.print_exc()
|
| 1181 |
+
dpo_update_batch_proto = None # Skip update on error
|
| 1182 |
+
|
| 1183 |
+
# --- Actor Update Step ---
|
| 1184 |
+
actor_output = None
|
| 1185 |
+
if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto:
|
| 1186 |
+
with _timer("update_actor", timing_raw):
|
| 1187 |
+
# Pass the batch containing reference log probs (if computed)
|
| 1188 |
+
# The modified update_actor_dpo expects them if reference_free=False
|
| 1189 |
+
actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto)
|
| 1190 |
+
if actor_output and "metrics" in actor_output.meta_info:
|
| 1191 |
+
metrics.update(reduce_metrics(actor_output.meta_info["metrics"]))
|
| 1192 |
+
elif dpo_update_batch_proto is None:
|
| 1193 |
+
print(
|
| 1194 |
+
f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error."
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
# --- Validation and Saving ---
|
| 1198 |
+
test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1)
|
| 1199 |
+
is_last_step = self.global_steps >= self.total_training_steps
|
| 1200 |
+
if (
|
| 1201 |
+
self.val_reward_fn is not None
|
| 1202 |
+
and test_freq > 0
|
| 1203 |
+
and (is_last_step or self.global_steps % test_freq == 0)
|
| 1204 |
+
):
|
| 1205 |
+
print(f"\nRunning DPO validation at step {self.global_steps}...")
|
| 1206 |
+
val_timing_raw = {}
|
| 1207 |
+
with _timer("testing", val_timing_raw):
|
| 1208 |
+
val_metrics: dict = self._validate()
|
| 1209 |
+
if is_last_step:
|
| 1210 |
+
last_val_metrics = val_metrics
|
| 1211 |
+
if val_metrics:
|
| 1212 |
+
metrics["time/validation_run"] = val_timing_raw.get("testing", 0)
|
| 1213 |
+
metrics.update(val_metrics)
|
| 1214 |
+
else:
|
| 1215 |
+
print("Validation skipped or returned no metrics.")
|
| 1216 |
+
|
| 1217 |
+
save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
|
| 1218 |
+
if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0):
|
| 1219 |
+
print(f"\nSaving DPO checkpoint at step {self.global_steps}...")
|
| 1220 |
+
with _timer("save_checkpoint", timing_raw):
|
| 1221 |
+
self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere)
|
| 1222 |
+
metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0)
|
| 1223 |
+
|
| 1224 |
+
# --- End main step timer context ---
|
| 1225 |
+
|
| 1226 |
+
# --- Metrics calculation AFTER the 'step' timer block ---
|
| 1227 |
+
metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics
|
| 1228 |
+
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
|
| 1229 |
+
n_gpus = self.resource_pool_manager.get_n_gpus()
|
| 1230 |
+
if "step" in timing_raw:
|
| 1231 |
+
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
|
| 1232 |
+
else:
|
| 1233 |
+
print(
|
| 1234 |
+
f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. "
|
| 1235 |
+
f"Skipping throughput."
|
| 1236 |
+
)
|
| 1237 |
+
|
| 1238 |
+
step_timer.stop()
|
| 1239 |
+
metrics["time/step"] = step_timer.last
|
| 1240 |
+
|
| 1241 |
+
# Log metrics
|
| 1242 |
+
log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1)
|
| 1243 |
+
if logger and self.global_steps % log_freq == 0:
|
| 1244 |
+
log_payload = metrics.copy()
|
| 1245 |
+
# Add learning rate to log payload
|
| 1246 |
+
if actor_output and "actor/lr" in metrics:
|
| 1247 |
+
log_payload["actor/lr"] = metrics["actor/lr"]
|
| 1248 |
+
|
| 1249 |
+
print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}")
|
| 1250 |
+
try:
|
| 1251 |
+
logger.log(data=log_payload, step=self.global_steps)
|
| 1252 |
+
except Exception as e:
|
| 1253 |
+
print(f"Logging failed at step {self.global_steps}: {e}")
|
| 1254 |
+
|
| 1255 |
+
# Update progress bar
|
| 1256 |
+
postfix_metrics = {
|
| 1257 |
+
k: f"{v:.3f}" if isinstance(v, float) else v
|
| 1258 |
+
for k, v in metrics.items()
|
| 1259 |
+
if isinstance(v, int | float)
|
| 1260 |
+
}
|
| 1261 |
+
progress_bar.set_postfix(postfix_metrics)
|
| 1262 |
+
|
| 1263 |
+
except Exception as step_e:
|
| 1264 |
+
print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!")
|
| 1265 |
+
print(f"Caught Exception: {step_e}")
|
| 1266 |
+
traceback.print_exc()
|
| 1267 |
+
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
| 1268 |
+
step_timer.stop()
|
| 1269 |
+
should_stop = True
|
| 1270 |
+
break
|
| 1271 |
+
|
| 1272 |
+
if is_last_step or should_stop:
|
| 1273 |
+
print(f"Stopping DPO training at step {self.global_steps}.")
|
| 1274 |
+
break
|
| 1275 |
+
|
| 1276 |
+
self.global_steps += 1
|
| 1277 |
+
progress_bar.update(1)
|
| 1278 |
+
|
| 1279 |
+
# End of epoch handling
|
| 1280 |
+
if hasattr(self.train_dataloader, "reset"):
|
| 1281 |
+
try:
|
| 1282 |
+
self.train_dataloader.reset()
|
| 1283 |
+
except Exception as e:
|
| 1284 |
+
print(f"Warning: Failed to reset train dataloader state: {e}")
|
| 1285 |
+
if should_stop:
|
| 1286 |
+
break
|
| 1287 |
+
|
| 1288 |
+
# --- Final cleanup and logging ---
|
| 1289 |
+
progress_bar.close()
|
| 1290 |
+
final_step = max(0, self.global_steps - 1)
|
| 1291 |
+
print(f"Online DPO Training finished at step {final_step}.")
|
| 1292 |
+
# Save final checkpoint
|
| 1293 |
+
save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
|
| 1294 |
+
if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0):
|
| 1295 |
+
print(f"Saving final DPO checkpoint at step {final_step}...")
|
| 1296 |
+
self._save_checkpoint()
|
| 1297 |
+
|
| 1298 |
+
# Final validation run
|
| 1299 |
+
if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False):
|
| 1300 |
+
print("Running final validation...")
|
| 1301 |
+
last_val_metrics = self._validate()
|
| 1302 |
+
if last_val_metrics and logger:
|
| 1303 |
+
last_val_metrics["final_validation"] = True
|
| 1304 |
+
try:
|
| 1305 |
+
logger.log(data=last_val_metrics, step=final_step)
|
| 1306 |
+
except Exception as e:
|
| 1307 |
+
print(f"[Final Val Metrics Log Error]: {e}")
|
| 1308 |
+
|
| 1309 |
+
pprint(f"Final validation metrics: {last_val_metrics}")
|
| 1310 |
+
if logger and hasattr(logger, "finish"):
|
| 1311 |
+
logger.finish()
|
| 1312 |
+
print("Online DPO Training Run Complete.")
|
ICL/LV/code/README.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Unified Multi-Model VQA Codebase
|
| 2 |
+
|
| 3 |
+
目的
|
| 4 |
+
- 这一套代码是“模型无关”的通用评测/数据/提示构造层;所有模型仅通过“适配器”接入。
|
| 5 |
+
- 通用输入是 OpenAI 扁平内容序列(image→text;示例用 [REQUEST]/[RESPONSE];查询的 [RESPONSE] 为空)。
|
| 6 |
+
|
| 7 |
+
目录
|
| 8 |
+
- core/
|
| 9 |
+
- prompting/openai_segments.py 扁平序列构造与落盘适配
|
| 10 |
+
- datasets/m3it_reader.py M3IT 统一读取 & base64 图片缓存
|
| 11 |
+
- metrics/metrics.py Token‑F1、BERTScore‑F1 等
|
| 12 |
+
- eval/ 与模型无关的评测脚本(调用 adapters)
|
| 13 |
+
- zero_shot_vqa.py / random_k_shot_vqa.py
|
| 14 |
+
- eval_textual_retriever_vqa.py / eval_visual_retriever_vqa.py / eval_multimodal_retriever_vqa.py
|
| 15 |
+
- order 评测(统一缓存 + 独立指标脚本):
|
| 16 |
+
- order_eval_core.py(内部调用) / _modal_order.py(内部调用)
|
| 17 |
+
- eval_order_caption_bertscore.py / eval_order_caption_cider.py
|
| 18 |
+
- eval_order_classification_accuracy.py / eval_order_classification_f1.py
|
| 19 |
+
- eval_order_reasoning_accuracy.py / eval_order_reasoning_ras.py
|
| 20 |
+
- eval_order_vqa_bertscore.py / eval_order_vqa_tokenf1.py
|
| 21 |
+
- adapters/
|
| 22 |
+
- idefics2_adapter.py
|
| 23 |
+
- qwen_vl_adapter.py
|
| 24 |
+
- qwen3vl_adapter.py
|
| 25 |
+
- gemma3_adapter.py
|
| 26 |
+
|
| 27 |
+
使用
|
| 28 |
+
- 例:零样本(Idefics2)
|
| 29 |
+
python3 -m core.eval.zero_shot_vqa \
|
| 30 |
+
--adapter idefics2 \
|
| 31 |
+
--model-path /path/to/idefics2-8b \
|
| 32 |
+
--dataset-root /path/to/M3IT \
|
| 33 |
+
--split test --total-samples 500 \
|
| 34 |
+
--instruction-image "C:\\Users\\you\\instruction.png" --dump-first 2
|
| 35 |
+
|
| 36 |
+
- 例:随机 few‑shot(Qwen‑VL)
|
| 37 |
+
python3 -m core.eval.random_k_shot_vqa \
|
| 38 |
+
--adapter qwen-vl \
|
| 39 |
+
--model-path /path/to/Qwen-VL \
|
| 40 |
+
--dataset-root /path/to/M3IT \
|
| 41 |
+
--split test --k-shots 3 --total-samples 500 \
|
| 42 |
+
--use-paper-instruction --instruction-image "C:\\Users\\you\\instruction.png"
|
| 43 |
+
|
| 44 |
+
- 例:模态顺序评测(以 VQA Token-F1 为例)
|
| 45 |
+
python3 -m core.eval.eval_order_vqa_tokenf1 \
|
| 46 |
+
--adapter idefics2 \
|
| 47 |
+
--model-path /path/to/idefics2-8b \
|
| 48 |
+
--dataset-root /path/to/M3IT \
|
| 49 |
+
--retriever-model-path /path/to/BridgeTower-or-CLIP \
|
| 50 |
+
--orders image-text,text-image,text-image-text \
|
| 51 |
+
--k-shots 3 --total-samples 500 --split val
|
| 52 |
+
--adapter qwen-vl \
|
| 53 |
+
--model-path /path/to/Qwen-VL \
|
| 54 |
+
--dataset-root /path/to/M3IT \
|
| 55 |
+
--split test --k-shots 3 --total-samples 500 \
|
| 56 |
+
--use-paper-instruction --instruction-image "C:\\Users\\you\\instruction.png"
|
| 57 |
+
|
| 58 |
+
约定
|
| 59 |
+
- 适配器接口见 adapters/*.py:
|
| 60 |
+
- create(model_path: str) -> Adapter
|
| 61 |
+
- Adapter.generate_from_segments(segs: List[dict], temperature: float, top_p: float, max_new_tokens: int) -> str
|
| 62 |
+
- 可选:Adapter.generate_single(image_path: str, prompt: str, ...)
|
| 63 |
+
|
| 64 |
+
说明
|
| 65 |
+
- 适配器与通用源码彻底分离;你可以只替换 adapters/xxx_adapter.py 即可对接新模型。
|
| 66 |
+
- Windows 路径/BASE64/data:URL 的图片由 prompting/openai_segments.py 自动兼容。
|
ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (8.2 kB). View file
|
|
|
ICL/LV/code/SFT/build_icl_eval_sharegpt.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Build a prompt-only ShareGPT-style eval set for deciding <RET> vs <ANS>.
|
| 4 |
+
|
| 5 |
+
Prompt format is aligned with build_icl_dataset.py:
|
| 6 |
+
instruction + <image> + "Question: ...\\nAction:"
|
| 7 |
+
|
| 8 |
+
But for evaluation we keep ONLY the initial human turn in `conversations` to avoid leaking labels.
|
| 9 |
+
Gold labels are stored outside the prompt:
|
| 10 |
+
- expected_first_tag: "<RET>" or "<ANS>" (NOT included in conversations)
|
| 11 |
+
- answer: used for offline checking (NOT included in conversations)
|
| 12 |
+
- shots: for RET samples only, used for the follow-up step after model outputs <RET>
|
| 13 |
+
|
| 14 |
+
Important:
|
| 15 |
+
- Never use train split for eval; recommend val/test/dev.
|
| 16 |
+
- Optionally excludes any uid already present in an existing training jsonl to avoid overlap.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import random
|
| 22 |
+
import sys
|
| 23 |
+
from dataclasses import dataclass
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import Dict, List, Optional, Set
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Add code root to PYTHONPATH for core/ imports
|
| 29 |
+
CODE_ROOT = Path(__file__).resolve().parents[1]
|
| 30 |
+
if str(CODE_ROOT) not in sys.path:
|
| 31 |
+
sys.path.insert(0, str(CODE_ROOT))
|
| 32 |
+
|
| 33 |
+
from core.datasets.m3it_reader import iter_m3it_samples, load_instructions # noqa: E402
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass(frozen=True)
|
| 37 |
+
class PoolItem:
|
| 38 |
+
image_path: str
|
| 39 |
+
description: str
|
| 40 |
+
subdir: str
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass(frozen=True)
|
| 44 |
+
class QueryItem:
|
| 45 |
+
image_path: str
|
| 46 |
+
question: str
|
| 47 |
+
answer: str
|
| 48 |
+
subdir: str
|
| 49 |
+
uid: str
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _extract_uid(raw: Dict, fallback: str) -> str:
|
| 53 |
+
if isinstance(raw, dict):
|
| 54 |
+
for k in ("id", "image_id"):
|
| 55 |
+
v = raw.get(k)
|
| 56 |
+
if isinstance(v, (str, int)):
|
| 57 |
+
return str(v)
|
| 58 |
+
meta = raw.get("meta") if isinstance(raw.get("meta"), dict) else {}
|
| 59 |
+
for k in ("img_id", "id", "image_id"):
|
| 60 |
+
v = meta.get(k)
|
| 61 |
+
if isinstance(v, (str, int)):
|
| 62 |
+
return str(v)
|
| 63 |
+
return fallback
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def discover_subdirs(dataset_root: Path, category: str) -> List[str]:
|
| 67 |
+
base = dataset_root / "data" / category
|
| 68 |
+
if not base.exists():
|
| 69 |
+
return []
|
| 70 |
+
out: List[str] = []
|
| 71 |
+
for p in sorted(base.iterdir()):
|
| 72 |
+
if p.is_dir():
|
| 73 |
+
out.append(f"{category}/{p.name}")
|
| 74 |
+
return out
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def pick_instruction(insts: List[str], rng: random.Random) -> str:
|
| 78 |
+
if insts:
|
| 79 |
+
s = rng.choice(insts)
|
| 80 |
+
if isinstance(s, str) and s.strip():
|
| 81 |
+
return s.strip()
|
| 82 |
+
return "Please answer the question based on the image."
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_exclude_uids(path: Optional[str]) -> Set[str]:
|
| 86 |
+
if not path:
|
| 87 |
+
return set()
|
| 88 |
+
p = Path(path)
|
| 89 |
+
if not p.exists():
|
| 90 |
+
return set()
|
| 91 |
+
|
| 92 |
+
out: Set[str] = set()
|
| 93 |
+
with p.open("r", encoding="utf-8") as f:
|
| 94 |
+
for line in f:
|
| 95 |
+
line = line.strip()
|
| 96 |
+
if not line:
|
| 97 |
+
continue
|
| 98 |
+
try:
|
| 99 |
+
obj = json.loads(line)
|
| 100 |
+
except Exception:
|
| 101 |
+
continue
|
| 102 |
+
if not isinstance(obj, dict):
|
| 103 |
+
continue
|
| 104 |
+
uid = obj.get("uid")
|
| 105 |
+
if isinstance(uid, (str, int)):
|
| 106 |
+
out.add(str(uid))
|
| 107 |
+
continue
|
| 108 |
+
sid = obj.get("id")
|
| 109 |
+
if isinstance(sid, (str, int)):
|
| 110 |
+
out.add(str(sid))
|
| 111 |
+
return out
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def to_rel(path: str, root: Path) -> str:
|
| 115 |
+
try:
|
| 116 |
+
return str(Path(path).relative_to(root))
|
| 117 |
+
except Exception:
|
| 118 |
+
return path
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def build_pool_for_subdir(
|
| 122 |
+
*,
|
| 123 |
+
dataset_root: Path,
|
| 124 |
+
subdir: str,
|
| 125 |
+
split: str,
|
| 126 |
+
cache_dir: Path,
|
| 127 |
+
target_n: int,
|
| 128 |
+
max_samples_scan: int,
|
| 129 |
+
) -> List[PoolItem]:
|
| 130 |
+
items: List[PoolItem] = []
|
| 131 |
+
try:
|
| 132 |
+
iterable = iter_m3it_samples(
|
| 133 |
+
str(dataset_root),
|
| 134 |
+
subdir,
|
| 135 |
+
split=split,
|
| 136 |
+
cache_dir=str(cache_dir),
|
| 137 |
+
max_samples=None,
|
| 138 |
+
)
|
| 139 |
+
except FileNotFoundError:
|
| 140 |
+
return []
|
| 141 |
+
|
| 142 |
+
for idx, smp in enumerate(iterable):
|
| 143 |
+
if max_samples_scan > 0 and idx >= max_samples_scan:
|
| 144 |
+
break
|
| 145 |
+
if not smp.answers:
|
| 146 |
+
continue
|
| 147 |
+
desc = (smp.answers[0] or "").strip()
|
| 148 |
+
if not desc:
|
| 149 |
+
continue
|
| 150 |
+
items.append(PoolItem(smp.image_path, desc, subdir))
|
| 151 |
+
if target_n > 0 and len(items) >= target_n:
|
| 152 |
+
break
|
| 153 |
+
return items
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def collect_query_pool(
|
| 157 |
+
*,
|
| 158 |
+
dataset_root: Path,
|
| 159 |
+
subdirs: List[str],
|
| 160 |
+
split: str,
|
| 161 |
+
cache_dir: Path,
|
| 162 |
+
exclude_uids: Set[str],
|
| 163 |
+
target_n: int,
|
| 164 |
+
seed: int,
|
| 165 |
+
max_samples_per_subdir: int,
|
| 166 |
+
) -> List[QueryItem]:
|
| 167 |
+
rng = random.Random(seed)
|
| 168 |
+
subdirs = list(subdirs)
|
| 169 |
+
rng.shuffle(subdirs)
|
| 170 |
+
|
| 171 |
+
seen: Set[str] = set()
|
| 172 |
+
out: List[QueryItem] = []
|
| 173 |
+
|
| 174 |
+
for subdir in subdirs:
|
| 175 |
+
taken = 0
|
| 176 |
+
try:
|
| 177 |
+
iterable = iter_m3it_samples(
|
| 178 |
+
str(dataset_root),
|
| 179 |
+
subdir,
|
| 180 |
+
split=split,
|
| 181 |
+
cache_dir=str(cache_dir),
|
| 182 |
+
max_samples=None,
|
| 183 |
+
)
|
| 184 |
+
except FileNotFoundError:
|
| 185 |
+
continue
|
| 186 |
+
|
| 187 |
+
for i, smp in enumerate(iterable):
|
| 188 |
+
if max_samples_per_subdir > 0 and taken >= max_samples_per_subdir:
|
| 189 |
+
break
|
| 190 |
+
q = (smp.text or "").strip()
|
| 191 |
+
if not q:
|
| 192 |
+
continue
|
| 193 |
+
if not smp.answers:
|
| 194 |
+
continue
|
| 195 |
+
ans = (smp.answers[0] or "").strip()
|
| 196 |
+
if not ans:
|
| 197 |
+
continue
|
| 198 |
+
uid = _extract_uid(smp.raw, f"{subdir}:{i:08d}")
|
| 199 |
+
if uid in exclude_uids:
|
| 200 |
+
continue
|
| 201 |
+
if uid in seen:
|
| 202 |
+
continue
|
| 203 |
+
seen.add(uid)
|
| 204 |
+
out.append(QueryItem(smp.image_path, q, ans, subdir, uid))
|
| 205 |
+
taken += 1
|
| 206 |
+
if target_n > 0 and len(out) >= target_n:
|
| 207 |
+
return out
|
| 208 |
+
return out
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def select_shots(
|
| 212 |
+
pool: List[PoolItem],
|
| 213 |
+
k: int,
|
| 214 |
+
rng: random.Random,
|
| 215 |
+
exclude_image: Optional[str] = None,
|
| 216 |
+
) -> List[PoolItem]:
|
| 217 |
+
if not pool or k <= 0:
|
| 218 |
+
return []
|
| 219 |
+
cand = [p for p in pool if p.image_path != exclude_image]
|
| 220 |
+
if not cand:
|
| 221 |
+
cand = pool
|
| 222 |
+
if len(cand) >= k:
|
| 223 |
+
return rng.sample(cand, k=k)
|
| 224 |
+
return [rng.choice(cand) for _ in range(k)]
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def write_jsonl(path: Path, records: List[Dict]) -> None:
|
| 228 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 229 |
+
with path.open("w", encoding="utf-8") as f:
|
| 230 |
+
for r in records:
|
| 231 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def build_prompt_only_record(
|
| 235 |
+
*,
|
| 236 |
+
uid: str,
|
| 237 |
+
instruction: str,
|
| 238 |
+
image_rel: str,
|
| 239 |
+
question: str,
|
| 240 |
+
expected_first_tag: str,
|
| 241 |
+
answer: str,
|
| 242 |
+
category: str,
|
| 243 |
+
subdir: str,
|
| 244 |
+
shots: List[Dict],
|
| 245 |
+
k_shot: int,
|
| 246 |
+
) -> Dict:
|
| 247 |
+
human = []
|
| 248 |
+
if instruction:
|
| 249 |
+
human.append(instruction.strip())
|
| 250 |
+
human.append("<image>")
|
| 251 |
+
human.append(f"Question: {question}\nAction:")
|
| 252 |
+
human_value = "\n".join([x for x in human if x]).strip()
|
| 253 |
+
|
| 254 |
+
return {
|
| 255 |
+
"id": uid,
|
| 256 |
+
"images": [image_rel],
|
| 257 |
+
"conversations": [
|
| 258 |
+
{"from": "human", "value": human_value},
|
| 259 |
+
],
|
| 260 |
+
"expected_first_tag": expected_first_tag,
|
| 261 |
+
"answer": answer,
|
| 262 |
+
"k_shot": k_shot,
|
| 263 |
+
"shots": shots,
|
| 264 |
+
"category": category,
|
| 265 |
+
"subdir": subdir,
|
| 266 |
+
"instruction": instruction,
|
| 267 |
+
"query": {"image": image_rel, "question": question},
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def main() -> int:
|
| 272 |
+
ap = argparse.ArgumentParser(description="Build prompt-only eval set (ShareGPT jsonl) for <RET>/<ANS> decision.")
|
| 273 |
+
ap.add_argument("--dataset-root", default="/workspace/M3IT")
|
| 274 |
+
ap.add_argument("--output-dir", default="/workspace/M3IT_new/ICL_eval")
|
| 275 |
+
ap.add_argument("--category", default="vqa")
|
| 276 |
+
ap.add_argument("--split", default="val", help="Never use train; recommend val/test/dev.")
|
| 277 |
+
ap.add_argument("--pool-split", default="val", help="Never use train; recommend val/test/dev.")
|
| 278 |
+
ap.add_argument("--seed", type=int, default=42)
|
| 279 |
+
|
| 280 |
+
ap.add_argument("--total", type=int, default=100)
|
| 281 |
+
ap.add_argument("--ret-ratio", type=float, default=0.5)
|
| 282 |
+
ap.add_argument("--query-pool-size", type=int, default=1000, help="How many queries to collect before sampling.")
|
| 283 |
+
ap.add_argument("--max-samples-per-subdir", type=int, default=2000, help="Scan cap per subdir when collecting queries.")
|
| 284 |
+
ap.add_argument("--pool-size-per-subdir", type=int, default=2000, help="Max pool size to build per subdir (for shots).")
|
| 285 |
+
ap.add_argument("--pool-scan-per-subdir", type=int, default=4000, help="Scan cap per subdir when building pools.")
|
| 286 |
+
ap.add_argument("--shot-k-min", type=int, default=1)
|
| 287 |
+
ap.add_argument("--shot-k-max", type=int, default=3)
|
| 288 |
+
|
| 289 |
+
ap.add_argument(
|
| 290 |
+
"--exclude-uids-from",
|
| 291 |
+
default="/workspace/M3IT_new/ICL/vqa/merged_shuffled_sharegpt.jsonl",
|
| 292 |
+
help="Optional jsonl to exclude uids/ids (to avoid overlap with training).",
|
| 293 |
+
)
|
| 294 |
+
ap.add_argument("--overwrite", action="store_true")
|
| 295 |
+
ap.add_argument("--output", default=None, help="Default: {output_dir}/{category}/eval_sharegpt_{total}.jsonl")
|
| 296 |
+
args = ap.parse_args()
|
| 297 |
+
|
| 298 |
+
if args.split.strip().lower() == "train" or args.pool_split.strip().lower() == "train":
|
| 299 |
+
raise ValueError("split/pool-split=train is not allowed for eval set")
|
| 300 |
+
if args.total <= 0:
|
| 301 |
+
raise ValueError("total must be > 0")
|
| 302 |
+
if not (0.0 <= args.ret_ratio <= 1.0):
|
| 303 |
+
raise ValueError("ret-ratio must be in [0, 1]")
|
| 304 |
+
if args.shot_k_min <= 0 or args.shot_k_max < args.shot_k_min:
|
| 305 |
+
raise ValueError("invalid shot-k range")
|
| 306 |
+
|
| 307 |
+
dataset_root = Path(args.dataset_root)
|
| 308 |
+
output_dir = Path(args.output_dir)
|
| 309 |
+
cache_dir = output_dir / "_image_cache"
|
| 310 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 311 |
+
|
| 312 |
+
out_path = Path(
|
| 313 |
+
args.output
|
| 314 |
+
if args.output
|
| 315 |
+
else str(output_dir / args.category / f"eval_sharegpt_{args.total}.jsonl")
|
| 316 |
+
)
|
| 317 |
+
if out_path.exists() and not args.overwrite:
|
| 318 |
+
raise FileExistsError(f"Output exists: {out_path} (use --overwrite to replace)")
|
| 319 |
+
|
| 320 |
+
subdirs = discover_subdirs(dataset_root, args.category)
|
| 321 |
+
if not subdirs:
|
| 322 |
+
raise FileNotFoundError(f"No subdirs found under {dataset_root}/data/{args.category}")
|
| 323 |
+
|
| 324 |
+
exclude_uids = load_exclude_uids(args.exclude_uids_from)
|
| 325 |
+
|
| 326 |
+
# Load instructions once per subdir.
|
| 327 |
+
inst_map: Dict[str, List[str]] = {sd: load_instructions(dataset_root, sd) for sd in subdirs}
|
| 328 |
+
|
| 329 |
+
query_pool_target = max(args.total, args.query_pool_size)
|
| 330 |
+
queries = collect_query_pool(
|
| 331 |
+
dataset_root=dataset_root,
|
| 332 |
+
subdirs=subdirs,
|
| 333 |
+
split=args.split,
|
| 334 |
+
cache_dir=cache_dir,
|
| 335 |
+
exclude_uids=exclude_uids,
|
| 336 |
+
target_n=query_pool_target,
|
| 337 |
+
seed=args.seed,
|
| 338 |
+
max_samples_per_subdir=args.max_samples_per_subdir,
|
| 339 |
+
)
|
| 340 |
+
if len(queries) < args.total:
|
| 341 |
+
raise RuntimeError(f"Not enough queries collected: got {len(queries)}/{args.total}. "
|
| 342 |
+
f"Try increasing --max-samples-per-subdir or changing --split.")
|
| 343 |
+
|
| 344 |
+
rng = random.Random(args.seed)
|
| 345 |
+
ret_n = int(round(args.total * args.ret_ratio))
|
| 346 |
+
ret_n = max(0, min(args.total, ret_n))
|
| 347 |
+
ans_n = args.total - ret_n
|
| 348 |
+
|
| 349 |
+
labels = ["RET"] * ret_n + ["ANS"] * ans_n
|
| 350 |
+
rng.shuffle(labels)
|
| 351 |
+
|
| 352 |
+
# Pools are built lazily per subdir for RET samples only.
|
| 353 |
+
pool_map: Dict[str, List[PoolItem]] = {}
|
| 354 |
+
|
| 355 |
+
records: List[Dict] = []
|
| 356 |
+
used_uids: Set[str] = set()
|
| 357 |
+
|
| 358 |
+
for label in labels:
|
| 359 |
+
for _try in range(2000):
|
| 360 |
+
q = rng.choice(queries)
|
| 361 |
+
if q.uid in used_uids:
|
| 362 |
+
continue
|
| 363 |
+
|
| 364 |
+
inst = pick_instruction(inst_map.get(q.subdir, []), rng)
|
| 365 |
+
image_rel = to_rel(q.image_path, output_dir)
|
| 366 |
+
|
| 367 |
+
if label == "ANS":
|
| 368 |
+
used_uids.add(q.uid)
|
| 369 |
+
records.append(
|
| 370 |
+
build_prompt_only_record(
|
| 371 |
+
uid=q.uid,
|
| 372 |
+
instruction=inst,
|
| 373 |
+
image_rel=image_rel,
|
| 374 |
+
question=q.question,
|
| 375 |
+
expected_first_tag="<ANS>",
|
| 376 |
+
answer=q.answer,
|
| 377 |
+
category=args.category,
|
| 378 |
+
subdir=q.subdir,
|
| 379 |
+
shots=[],
|
| 380 |
+
k_shot=0,
|
| 381 |
+
)
|
| 382 |
+
)
|
| 383 |
+
break
|
| 384 |
+
|
| 385 |
+
# RET case: attach hidden shots for the follow-up step.
|
| 386 |
+
if q.subdir not in pool_map:
|
| 387 |
+
pool_map[q.subdir] = build_pool_for_subdir(
|
| 388 |
+
dataset_root=dataset_root,
|
| 389 |
+
subdir=q.subdir,
|
| 390 |
+
split=args.pool_split,
|
| 391 |
+
cache_dir=cache_dir,
|
| 392 |
+
target_n=args.pool_size_per_subdir,
|
| 393 |
+
max_samples_scan=args.pool_scan_per_subdir,
|
| 394 |
+
)
|
| 395 |
+
pool = pool_map.get(q.subdir, [])
|
| 396 |
+
if not pool:
|
| 397 |
+
continue
|
| 398 |
+
|
| 399 |
+
k = rng.randint(args.shot_k_min, args.shot_k_max)
|
| 400 |
+
shots_items = select_shots(pool, k, rng, exclude_image=q.image_path)
|
| 401 |
+
shots = [
|
| 402 |
+
{"image": to_rel(s.image_path, output_dir), "description": s.description}
|
| 403 |
+
for s in shots_items
|
| 404 |
+
]
|
| 405 |
+
used_uids.add(q.uid)
|
| 406 |
+
records.append(
|
| 407 |
+
build_prompt_only_record(
|
| 408 |
+
uid=q.uid,
|
| 409 |
+
instruction=inst,
|
| 410 |
+
image_rel=image_rel,
|
| 411 |
+
question=q.question,
|
| 412 |
+
expected_first_tag="<RET>",
|
| 413 |
+
answer=q.answer,
|
| 414 |
+
category=args.category,
|
| 415 |
+
subdir=q.subdir,
|
| 416 |
+
shots=shots,
|
| 417 |
+
k_shot=k,
|
| 418 |
+
)
|
| 419 |
+
)
|
| 420 |
+
break
|
| 421 |
+
else:
|
| 422 |
+
raise RuntimeError(f"Failed to sample enough records for label={label}. "
|
| 423 |
+
f"Try increasing --query-pool-size or relaxing --exclude-uids-from.")
|
| 424 |
+
|
| 425 |
+
rng.shuffle(records)
|
| 426 |
+
write_jsonl(out_path, records)
|
| 427 |
+
|
| 428 |
+
# Lightweight summary to stdout.
|
| 429 |
+
ret_cnt = sum(1 for r in records if r.get("expected_first_tag") == "<RET>")
|
| 430 |
+
ans_cnt = len(records) - ret_cnt
|
| 431 |
+
print(f"[OK] wrote={len(records)} ret={ret_cnt} ans={ans_cnt} -> {out_path}")
|
| 432 |
+
print(f"[INFO] image_root (for eval): {output_dir}")
|
| 433 |
+
return 0
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
if __name__ == "__main__":
|
| 437 |
+
raise SystemExit(main())
|
ICL/LV/code/SFT/check_kshot_ret_ans.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Check whether the model outputs <RET> or <ANS> under different shot settings.
|
| 5 |
+
|
| 6 |
+
0-shot: only query image + question.
|
| 7 |
+
K-shot (K>=1): after the model outputs <RET>, append 1 shot (image+description) and
|
| 8 |
+
ask again. If it still outputs <RET>, append another shot, and so on.
|
| 9 |
+
We record whether it outputs <RET>/<ANS> at each step.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
import re
|
| 17 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from PIL import Image
|
| 21 |
+
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
TAG_RE = re.compile(r"(<ANS>|<RET>)")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _extract_tag(text: str) -> Optional[str]:
|
| 28 |
+
match = TAG_RE.search(text)
|
| 29 |
+
return match.group(1) if match else None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _resolve_path(root: str, maybe_rel: str) -> str:
|
| 33 |
+
if os.path.isabs(maybe_rel):
|
| 34 |
+
return maybe_rel
|
| 35 |
+
return os.path.normpath(os.path.join(root, maybe_rel))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _split_user_text_with_images(
|
| 39 |
+
text: str, image_paths: List[str]
|
| 40 |
+
) -> Tuple[List[Dict[str, str]], List[str]]:
|
| 41 |
+
parts = text.split("<image>")
|
| 42 |
+
content: List[Dict[str, str]] = []
|
| 43 |
+
used: List[str] = []
|
| 44 |
+
for i, part in enumerate(parts):
|
| 45 |
+
part = part.strip()
|
| 46 |
+
if part:
|
| 47 |
+
content.append({"type": "text", "text": part})
|
| 48 |
+
if i < len(parts) - 1:
|
| 49 |
+
if not image_paths:
|
| 50 |
+
raise ValueError("用户文本里 <image> 数量 > images 列表长度")
|
| 51 |
+
img_path = image_paths.pop(0)
|
| 52 |
+
used.append(img_path)
|
| 53 |
+
content.append({"type": "image", "image": img_path})
|
| 54 |
+
return content, used
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _append_human_turn(
|
| 58 |
+
*,
|
| 59 |
+
messages: List[Dict[str, Any]],
|
| 60 |
+
pil_images: List[Image.Image],
|
| 61 |
+
image_root: str,
|
| 62 |
+
text: str,
|
| 63 |
+
images_all: List[str],
|
| 64 |
+
image_cursor: int,
|
| 65 |
+
) -> int:
|
| 66 |
+
n_placeholders = text.count("<image>")
|
| 67 |
+
img_paths = images_all[image_cursor : image_cursor + n_placeholders]
|
| 68 |
+
if len(img_paths) != n_placeholders:
|
| 69 |
+
raise ValueError("images 列表长度 < <image> 占位符数量")
|
| 70 |
+
|
| 71 |
+
user_content, used_paths = _split_user_text_with_images(text, img_paths.copy())
|
| 72 |
+
for p in used_paths:
|
| 73 |
+
p = _resolve_path(image_root, p) if not os.path.isabs(p) else p
|
| 74 |
+
if not os.path.exists(p):
|
| 75 |
+
raise FileNotFoundError(p)
|
| 76 |
+
with Image.open(p) as img:
|
| 77 |
+
pil_images.append(img.convert("RGB"))
|
| 78 |
+
|
| 79 |
+
messages.append({"role": "user", "content": user_content})
|
| 80 |
+
return image_cursor + n_placeholders
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _append_shot_turn(
|
| 84 |
+
*,
|
| 85 |
+
messages: List[Dict[str, Any]],
|
| 86 |
+
pil_images: List[Image.Image],
|
| 87 |
+
image_root: str,
|
| 88 |
+
image_path: str,
|
| 89 |
+
description: str,
|
| 90 |
+
) -> None:
|
| 91 |
+
img_path = _resolve_path(image_root, image_path)
|
| 92 |
+
if not os.path.exists(img_path):
|
| 93 |
+
raise FileNotFoundError(img_path)
|
| 94 |
+
with Image.open(img_path) as im:
|
| 95 |
+
pil_images.append(im.convert("RGB"))
|
| 96 |
+
messages.append(
|
| 97 |
+
{
|
| 98 |
+
"role": "user",
|
| 99 |
+
"content": [
|
| 100 |
+
{"type": "image", "image": img_path},
|
| 101 |
+
{"type": "text", "text": f"Description: {description}"},
|
| 102 |
+
],
|
| 103 |
+
}
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _build_base_messages(obj: Dict[str, Any], image_root: str) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
|
| 108 |
+
conversations = obj.get("conversations")
|
| 109 |
+
if not isinstance(conversations, list) or not conversations:
|
| 110 |
+
raise ValueError("样本缺少 conversations")
|
| 111 |
+
images_rel = obj.get("images") or []
|
| 112 |
+
if not isinstance(images_rel, list):
|
| 113 |
+
raise ValueError("images 字段不是 list")
|
| 114 |
+
|
| 115 |
+
messages: List[Dict[str, Any]] = []
|
| 116 |
+
pil_images: List[Image.Image] = []
|
| 117 |
+
image_cursor = 0
|
| 118 |
+
|
| 119 |
+
# Use the first human turn as query prompt.
|
| 120 |
+
human = None
|
| 121 |
+
for t in conversations:
|
| 122 |
+
if t.get("from") == "human":
|
| 123 |
+
human = t
|
| 124 |
+
break
|
| 125 |
+
if human is None:
|
| 126 |
+
raise ValueError("没有 human turn")
|
| 127 |
+
|
| 128 |
+
image_cursor = _append_human_turn(
|
| 129 |
+
messages=messages,
|
| 130 |
+
pil_images=pil_images,
|
| 131 |
+
image_root=image_root,
|
| 132 |
+
text=str(human.get("value", "")),
|
| 133 |
+
images_all=images_rel,
|
| 134 |
+
image_cursor=image_cursor,
|
| 135 |
+
)
|
| 136 |
+
return messages, pil_images
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _pick_shots_from_pool(
|
| 140 |
+
pool: List[Dict[str, str]],
|
| 141 |
+
k: int,
|
| 142 |
+
rng: random.Random,
|
| 143 |
+
exclude_image: Optional[str],
|
| 144 |
+
) -> List[Dict[str, str]]:
|
| 145 |
+
if k <= 0 or not pool:
|
| 146 |
+
return []
|
| 147 |
+
cand = [p for p in pool if p.get("image") != exclude_image]
|
| 148 |
+
if not cand:
|
| 149 |
+
cand = pool
|
| 150 |
+
if len(cand) >= k:
|
| 151 |
+
return rng.sample(cand, k=k)
|
| 152 |
+
return [rng.choice(cand) for _ in range(k)]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def main() -> int:
|
| 156 |
+
ap = argparse.ArgumentParser(description="Check <RET>/<ANS> outputs under 0/1/2/3-shot settings.")
|
| 157 |
+
ap.add_argument("--model", required=True, help="HF model dir")
|
| 158 |
+
ap.add_argument(
|
| 159 |
+
"--data",
|
| 160 |
+
default="/workspace/M3IT_new/ICL_eval/vqa/eval_sharegpt_100.jsonl",
|
| 161 |
+
help="Prompt-only eval jsonl (has conversations/images/shots).",
|
| 162 |
+
)
|
| 163 |
+
ap.add_argument("--image-root", default="/workspace/M3IT_new/ICL_eval")
|
| 164 |
+
ap.add_argument("--num-samples", type=int, default=20)
|
| 165 |
+
ap.add_argument("--seed", type=int, default=42)
|
| 166 |
+
ap.add_argument("--k-list", default="0,1,2,3", help="Comma-separated shot counts to report.")
|
| 167 |
+
ap.add_argument("--max-new-tokens", type=int, default=128)
|
| 168 |
+
ap.add_argument("--device", default="cuda:0")
|
| 169 |
+
ap.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
|
| 170 |
+
ap.add_argument("--print-samples", action="store_true", help="Print each sample input/output.")
|
| 171 |
+
args = ap.parse_args()
|
| 172 |
+
|
| 173 |
+
rng = random.Random(args.seed)
|
| 174 |
+
k_list = [int(x.strip()) for x in args.k_list.split(",") if x.strip()]
|
| 175 |
+
if not k_list:
|
| 176 |
+
raise ValueError("k-list is empty")
|
| 177 |
+
max_k = max(k_list)
|
| 178 |
+
|
| 179 |
+
# Load dataset
|
| 180 |
+
data: List[Dict[str, Any]] = []
|
| 181 |
+
with open(args.data, "r", encoding="utf-8") as f:
|
| 182 |
+
for line in f:
|
| 183 |
+
line = line.strip()
|
| 184 |
+
if not line:
|
| 185 |
+
continue
|
| 186 |
+
data.append(json.loads(line))
|
| 187 |
+
if not data:
|
| 188 |
+
raise ValueError("empty data")
|
| 189 |
+
|
| 190 |
+
# Build global shot pool
|
| 191 |
+
pool: List[Dict[str, str]] = []
|
| 192 |
+
for obj in data:
|
| 193 |
+
shots = obj.get("shots") or []
|
| 194 |
+
if isinstance(shots, list):
|
| 195 |
+
for s in shots:
|
| 196 |
+
if not isinstance(s, dict):
|
| 197 |
+
continue
|
| 198 |
+
img = s.get("image")
|
| 199 |
+
desc = s.get("description")
|
| 200 |
+
if isinstance(img, str) and isinstance(desc, str) and img and desc:
|
| 201 |
+
pool.append({"image": img, "description": desc})
|
| 202 |
+
if not pool:
|
| 203 |
+
print("[WARN] shot pool is empty, k-shot tests may be skipped")
|
| 204 |
+
|
| 205 |
+
# Sample records
|
| 206 |
+
samples = rng.sample(data, k=min(args.num_samples, len(data)))
|
| 207 |
+
|
| 208 |
+
dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
|
| 209 |
+
device = torch.device(args.device)
|
| 210 |
+
processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
|
| 211 |
+
model = Qwen3VLForConditionalGeneration.from_pretrained(
|
| 212 |
+
args.model, dtype=dtype, trust_remote_code=True
|
| 213 |
+
).to(device)
|
| 214 |
+
model.eval()
|
| 215 |
+
|
| 216 |
+
summary: Dict[int, Dict[str, int]] = {
|
| 217 |
+
k: {"RET": 0, "ANS": 0, "NONE": 0, "REACHED": 0} for k in k_list
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
for obj in samples:
|
| 221 |
+
uid = obj.get("id") or obj.get("uid") or "unknown"
|
| 222 |
+
query = obj.get("query") or {}
|
| 223 |
+
query_image = query.get("image")
|
| 224 |
+
|
| 225 |
+
messages, pil_images = _build_base_messages(obj, args.image_root)
|
| 226 |
+
|
| 227 |
+
# Pre-sample shots for this sample (use first N as we go).
|
| 228 |
+
shots_all = _pick_shots_from_pool(pool, max_k, rng, exclude_image=query_image)
|
| 229 |
+
if max_k > 0 and len(shots_all) < max_k:
|
| 230 |
+
continue
|
| 231 |
+
|
| 232 |
+
step = 0 # step 0 = query only
|
| 233 |
+
while True:
|
| 234 |
+
prompt = processor.apply_chat_template(
|
| 235 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 236 |
+
)
|
| 237 |
+
inputs = processor(text=prompt, images=pil_images, padding=True, return_tensors="pt")
|
| 238 |
+
inputs = {k2: v.to(device) for k2, v in inputs.items()}
|
| 239 |
+
|
| 240 |
+
with torch.inference_mode():
|
| 241 |
+
out_ids = model.generate(
|
| 242 |
+
**inputs, do_sample=False, max_new_tokens=args.max_new_tokens
|
| 243 |
+
)
|
| 244 |
+
in_len = int(inputs["input_ids"].shape[1])
|
| 245 |
+
pred = processor.batch_decode(out_ids[:, in_len:], skip_special_tokens=True)[0].strip()
|
| 246 |
+
tag = _extract_tag(pred)
|
| 247 |
+
|
| 248 |
+
# Append model output to the conversation
|
| 249 |
+
messages.append({"role": "assistant", "content": [{"type": "text", "text": pred}]})
|
| 250 |
+
|
| 251 |
+
if step in summary:
|
| 252 |
+
summary[step]["REACHED"] += 1
|
| 253 |
+
if tag == "<RET>":
|
| 254 |
+
summary[step]["RET"] += 1
|
| 255 |
+
elif tag == "<ANS>":
|
| 256 |
+
summary[step]["ANS"] += 1
|
| 257 |
+
else:
|
| 258 |
+
summary[step]["NONE"] += 1
|
| 259 |
+
|
| 260 |
+
if args.print_samples:
|
| 261 |
+
print("=" * 80)
|
| 262 |
+
print(f"uid={uid} | step={step} | pred_tag={tag}")
|
| 263 |
+
for m in messages:
|
| 264 |
+
role = m.get("role")
|
| 265 |
+
if role == "user":
|
| 266 |
+
parts = []
|
| 267 |
+
for c in m.get("content", []):
|
| 268 |
+
if c.get("type") == "text":
|
| 269 |
+
parts.append(c.get("text", ""))
|
| 270 |
+
elif c.get("type") == "image":
|
| 271 |
+
parts.append(f"<image> {c.get('image','')}")
|
| 272 |
+
print("[输入]")
|
| 273 |
+
print("\n".join([p for p in parts if p]).strip())
|
| 274 |
+
elif role == "assistant":
|
| 275 |
+
parts = []
|
| 276 |
+
for c in m.get("content", []):
|
| 277 |
+
if c.get("type") == "text":
|
| 278 |
+
parts.append(c.get("text", ""))
|
| 279 |
+
print("[输入-助手]")
|
| 280 |
+
print("\n".join([p for p in parts if p]).strip())
|
| 281 |
+
print("[输出]")
|
| 282 |
+
print(pred)
|
| 283 |
+
|
| 284 |
+
# Stop if not <RET> or reached max_k shots
|
| 285 |
+
if tag != "<RET>":
|
| 286 |
+
break
|
| 287 |
+
if step >= max_k:
|
| 288 |
+
break
|
| 289 |
+
|
| 290 |
+
# Append next shot and ask again.
|
| 291 |
+
shot = shots_all[step] if step < len(shots_all) else None
|
| 292 |
+
if not shot:
|
| 293 |
+
break
|
| 294 |
+
_append_shot_turn(
|
| 295 |
+
messages=messages,
|
| 296 |
+
pil_images=pil_images,
|
| 297 |
+
image_root=args.image_root,
|
| 298 |
+
image_path=shot["image"],
|
| 299 |
+
description=shot["description"],
|
| 300 |
+
)
|
| 301 |
+
# Ask for decision again after each shot.
|
| 302 |
+
messages.append({"role": "user", "content": [{"type": "text", "text": "Action:"}]})
|
| 303 |
+
step += 1
|
| 304 |
+
|
| 305 |
+
print("=== summary ===")
|
| 306 |
+
for k in k_list:
|
| 307 |
+
s = summary[k]
|
| 308 |
+
reached = s["REACHED"]
|
| 309 |
+
if reached == 0:
|
| 310 |
+
print(f"k={k}: no samples")
|
| 311 |
+
continue
|
| 312 |
+
print(
|
| 313 |
+
f"k={k} | RET={s['RET']} | ANS={s['ANS']} | NONE={s['NONE']} | reached={reached}"
|
| 314 |
+
)
|
| 315 |
+
return 0
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
if __name__ == "__main__":
|
| 319 |
+
raise SystemExit(main())
|
ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb
ADDED
|
Binary file (4.33 kB). View file
|
|
|
ICL/LV/code/SFT/prepare_dataset.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
预先生成数据集缓存
|
| 4 |
+
运行一次后,训练时直接加载缓存,避免超时
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import pickle
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# 添加当前目录到路径
|
| 13 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 14 |
+
|
| 15 |
+
from config import get_config
|
| 16 |
+
from dataset import SFTDataset
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main():
|
| 20 |
+
config = get_config()
|
| 21 |
+
|
| 22 |
+
cache_path = Path(config.training.output_dir) / ".dataset_cache.pkl"
|
| 23 |
+
ready_flag = Path(config.training.output_dir) / ".dataset_ready"
|
| 24 |
+
|
| 25 |
+
# 创建输出目录
|
| 26 |
+
os.makedirs(config.training.output_dir, exist_ok=True)
|
| 27 |
+
|
| 28 |
+
print("=" * 60)
|
| 29 |
+
print("预生成 SFT 数据集缓存")
|
| 30 |
+
print("=" * 60)
|
| 31 |
+
print(f"数据目录: {config.data.sft_data_dir}")
|
| 32 |
+
print(f"缓存路径: {cache_path}")
|
| 33 |
+
print("=" * 60)
|
| 34 |
+
|
| 35 |
+
# 加载数据集
|
| 36 |
+
print("\n开始加载数据集...")
|
| 37 |
+
dataset = SFTDataset(config.data, split="train")
|
| 38 |
+
|
| 39 |
+
print(f"\n数据集加载完成!共 {len(dataset)} 个样本")
|
| 40 |
+
|
| 41 |
+
# 保存缓存
|
| 42 |
+
print(f"\n保存缓存到: {cache_path}")
|
| 43 |
+
with open(cache_path, "wb") as f:
|
| 44 |
+
pickle.dump(dataset.samples, f)
|
| 45 |
+
|
| 46 |
+
# 创建就绪标记
|
| 47 |
+
ready_flag.touch()
|
| 48 |
+
|
| 49 |
+
print(f"就绪标记: {ready_flag}")
|
| 50 |
+
print("\n" + "=" * 60)
|
| 51 |
+
print("缓存生成完成!现在可以运行 bash run_train.sh")
|
| 52 |
+
print("=" * 60)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
main()
|
ICL/LV/code/adapters/gemma3_adapter.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from adapters._runners.gemma3_infer import Gemma3Runner
|
| 7 |
+
except Exception:
|
| 8 |
+
Gemma3Runner = None # type: ignore
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Adapter:
|
| 12 |
+
def __init__(self, model_path: str):
|
| 13 |
+
if Gemma3Runner is None:
|
| 14 |
+
raise RuntimeError('Gemma3Runner unavailable. Ensure gemma3-code is on PYTHONPATH or install its runner.')
|
| 15 |
+
self.runner = Gemma3Runner(model_path)
|
| 16 |
+
|
| 17 |
+
def generate_from_segments(self, segs: List[Dict[str, str]], *,
|
| 18 |
+
temperature: float, top_p: float, max_new_tokens: int) -> str:
|
| 19 |
+
gen = getattr(self.runner, 'generate_from_qwen_segs', None)
|
| 20 |
+
if gen is None:
|
| 21 |
+
raise RuntimeError('Gemma3Runner missing generate_from_qwen_segs')
|
| 22 |
+
return gen(segs, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create(model_path: str) -> Adapter:
|
| 26 |
+
return Adapter(model_path)
|
| 27 |
+
|
ICL/LV/code/adapters/qwen3vl_adapter.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from adapters._runners.qwen3_vl_infer import Qwen3VLRunner
|
| 7 |
+
except Exception:
|
| 8 |
+
Qwen3VLRunner = None # type: ignore
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Adapter:
|
| 12 |
+
def __init__(self, model_path: str):
|
| 13 |
+
if Qwen3VLRunner is None:
|
| 14 |
+
raise RuntimeError('Qwen3VLRunner unavailable. Ensure QWEN3VL-code is on PYTHONPATH or install its runner.')
|
| 15 |
+
self.runner = Qwen3VLRunner(model_path)
|
| 16 |
+
|
| 17 |
+
def generate_from_segments(self, segs: List[Dict[str, str]], *,
|
| 18 |
+
temperature: float, top_p: float, max_new_tokens: int) -> str:
|
| 19 |
+
gen = getattr(self.runner, 'generate_from_segments', None)
|
| 20 |
+
if gen is None:
|
| 21 |
+
raise RuntimeError('Qwen3VLRunner missing generate_from_segments')
|
| 22 |
+
return gen(segs, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create(model_path: str) -> Adapter:
|
| 26 |
+
return Adapter(model_path)
|
| 27 |
+
|
ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|