Lekr0 committed · Commit 8765573 · verified · Parent(s): 5de3d77

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. ICL/DAPO/verl-recipe/char_count/README.md +59 -0
  2. ICL/DAPO/verl-recipe/char_count/create_dataset.py +198 -0
  3. ICL/DAPO/verl-recipe/char_count/reward_function.py +34 -0
  4. ICL/DAPO/verl-recipe/char_count/train_grpo.sh +45 -0
  5. ICL/DAPO/verl-recipe/char_count/train_sft.sh +97 -0
  6. ICL/DAPO/verl-recipe/collabllm/README.md +74 -0
  7. ICL/DAPO/verl-recipe/collabllm/utils.py +280 -0
  8. ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh +138 -0
  9. ICL/DAPO/verl-recipe/deepeyes/deepeyes.py +408 -0
  10. ICL/DAPO/verl-recipe/fault_recover/async_llm.py +84 -0
  11. ICL/DAPO/verl-recipe/flash_rl_ascend/README.md +121 -0
  12. ICL/DAPO/verl-recipe/flowrl/README.md +182 -0
  13. ICL/DAPO/verl-recipe/flowrl/__init__.py +17 -0
  14. ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py +495 -0
  15. ICL/DAPO/verl-recipe/flowrl/main_flowrl.py +185 -0
  16. ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh +134 -0
  17. ICL/DAPO/verl-recipe/infigui-g1/README.md +56 -0
  18. ICL/DAPO/verl-recipe/langgraph_agent/__init__.py +13 -0
  19. ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py +393 -0
  20. ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py +188 -0
  21. ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py +202 -0
  22. ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py +571 -0
  23. ICL/DAPO/verl-recipe/prime/__init__.py +13 -0
  24. ICL/DAPO/verl-recipe/prime/prime_core_algos.py +147 -0
  25. ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh +61 -0
  26. ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh +33 -0
  27. ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3 +82 -0
  28. ICL/DAPO/verl-recipe/r1_ascend/README.md +119 -0
  29. ICL/DAPO/verl-recipe/r1_ascend/README_zh.md +119 -0
  30. ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh +82 -0
  31. ICL/DAPO/verl-recipe/r1_ascend/vllm_rollout_spmd.py +347 -0
  32. ICL/DAPO/verl-recipe/rep_exp/README.md +71 -0
  33. ICL/DAPO/verl-recipe/rep_exp/eval.sh +83 -0
  34. ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py +483 -0
  35. ICL/DAPO/verl-recipe/rep_exp/metric_utils.py +382 -0
  36. ICL/DAPO/verl-recipe/rep_exp/model_merge.sh +6 -0
  37. ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py +241 -0
  38. ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py +739 -0
  39. ICL/DAPO/verl-recipe/spin/core_algos.py +206 -0
  40. ICL/DAPO/verl-recipe/spin/main_spin.py +168 -0
  41. ICL/DAPO/verl-recipe/spin/spin_trainer.py +1312 -0
  42. ICL/LV/code/README.md +66 -0
  43. ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc +0 -0
  44. ICL/LV/code/SFT/build_icl_eval_sharegpt.py +437 -0
  45. ICL/LV/code/SFT/check_kshot_ret_ans.py +319 -0
  46. ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb +0 -0
  47. ICL/LV/code/SFT/prepare_dataset.py +56 -0
  48. ICL/LV/code/adapters/gemma3_adapter.py +27 -0
  49. ICL/LV/code/adapters/qwen3vl_adapter.py +27 -0
  50. ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc +0 -0
ICL/DAPO/verl-recipe/char_count/README.md ADDED
@@ -0,0 +1,59 @@
+ # Char Count
+ ## Introduction
+ Char count is a simple NLP task. We created it so beginners can grasp the idea of RLVR. The task can be trained with a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB of memory.
+
+ ## Problem formulation
+ The prompt is: "How many {char} are there in {word}?". To help the LLM answer this question, we create an SFT dataset with intermediate steps. For example,
+
+ ```text
+ Question: How many n are there in n-i-n-e?
+ Answer:
+ n = n
+ i != n
+ n = n
+ e != n
+ \boxed{2}
+ ```
+
+ Note that
+ - We add a dash between the individual chars to make the task easier, because each individual char is then tokenized to the same token by most tokenizers.
+ - In the SFT dataset, we create a CoT by listing every individual char and whether it equals the target. At the end, it outputs the final answer inside the box.
+ - The task can be verified.
+ - The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range.
+
+ ## Scripts
+ Installation
+
+ ```bash
+ pip install verl==0.6.1
+ ```
+
+
+ To create the dataset, run
+ ```bash
+ python3 create_dataset.py
+ ```
+ We create a train set and a val set; both are used for SFT and RL. You can specify the total number of examples, the min/max length, and the data path.
+
+ To run SFT
+ ```bash
+ BACKEND=fsdp bash train_sft.sh # use fsdp
+ BACKEND=megatron bash train_sft.sh # use megatron
+ ```
+ We train SFT for 1 epoch. After 1 epoch, the validation score is around 0.435.
+
+ Merge the checkpoint trained from SFT
+ ```bash
+ # sft
+ export CKPT_PATH=$HOME/experiments/char_count/models/sft/fsdp/global_step_140
+ python3 -m verl.model_merger merge --backend fsdp --local_dir $CKPT_PATH --target_dir $CKPT_PATH/huggingface/
+ # megatron
+ export CKPT_PATH=$HOME/experiments/char_count/models/sft/megatron/global_step_140
+ python3 -m verl.model_merger merge --backend megatron --local_dir $CKPT_PATH --target_dir $CKPT_PATH/huggingface/
+ ```
+
+ To run GRPO
+ ```bash
+ bash train_grpo.sh
+ ```
+ We train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.6.
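
To sanity-check the generated data, you can load the SFT split with pandas (a minimal sketch, assuming the default `~/data/char_count` output path from `create_dataset.py`):

```python
import os
import pandas as pd

# Load the SFT train split written by create_dataset.py.
df = pd.read_parquet(os.path.expanduser("~/data/char_count/sft/train.parquet"))

# Each row holds a two-message chat: the user prompt and the assistant CoT answer.
messages = df.iloc[0]["messages"]
print(messages[0]["content"])  # e.g. "How many k are there in word a-k-b-k-c?"
print(messages[1]["content"])  # per-char comparisons ending in \boxed{N}
```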
ICL/DAPO/verl-recipe/char_count/create_dataset.py ADDED
@@ -0,0 +1,198 @@
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Task description:
+ Given a random word and a random char, count the number of occurrences of the char in the word.
+
+ Create a CoT dataset that splits the word into separate chars, then lists each char and counts the occurrences.
+
+ The words are random char sequences rather than natural words.
+ """
+
+ import os.path
+ import random
+
+ prompt_template = "How many {} are there in word {}?"
+
+
+ def generate_random_char():
+     return chr(97 + random.randint(0, 25))
+
+
+ def create_prompt_response(min_length=3, max_length=5):
+     # randomly generate a length
+     word_length = random.randint(min_length, max_length)
+     # randomly generate a target count. This makes the answer uniformly distributed in [1, word_length].
+     target_count_number = random.randint(1, word_length)
+
+     char_lst = []
+     # generate the word
+     # step 1: place the target char target_count_number times
+     target_char = generate_random_char()
+
+     for _ in range(target_count_number):
+         char_lst.append(target_char)
+
+     # step 2: generate the remaining non-target chars
+     for _ in range(word_length - target_count_number):
+         while True:
+             char = generate_random_char()
+             if char != target_char:
+                 char_lst.append(char)
+                 break
+
+     # step 3: randomly permute char_lst
+     random.shuffle(char_lst)
+
+     word = "-".join(char_lst)
+
+     prompt = prompt_template.format(target_char, word)
+     final_answer = []
+
+     # cot
+     number = 0
+     for i, char in enumerate(char_lst):
+         cot = f"{char}"
+         if char != target_char:
+             cot += " != "
+         else:
+             cot += " = "
+             number += 1
+         cot += f"{target_char}."
+
+         final_answer.append(cot)
+
+     conclusion = f"\\boxed{{{number}}} {target_char} in {word}."
+
+     final_answer.append(conclusion)
+
+     final_answer = "\n".join(final_answer)
+
+     return prompt, final_answer
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--total_number", type=int, default=10000)
+     parser.add_argument("--min_length", type=int, default=5)
+     parser.add_argument("--max_length", type=int, default=20)
+     parser.add_argument("--data_path", type=str, default="~/data/char_count")
+
+     args = vars(parser.parse_args())
+
+     total_number = args["total_number"]
+     min_length = args["min_length"]
+     max_length = args["max_length"]
+     data_path = args["data_path"]
+     data_path = os.path.expanduser(data_path)
+
+     full_output = []
+     for _ in range(total_number):
+         output = create_prompt_response(min_length=min_length, max_length=max_length)
+         full_output.append(output)
+
+     # random reorder
+     random.shuffle(full_output)
+
+     # split into train and test
+     train_split_len = int(0.9 * len(full_output))
+     train_outputs = full_output[:train_split_len]
+     test_output = full_output[train_split_len:]
+
+     sft_train_dataset = {"messages": []}
+
+     for o in train_outputs:
+         messages = [
+             {"role": "user", "content": o[0]},
+             {"role": "assistant", "content": o[1]},
+         ]
+
+         sft_train_dataset["messages"].append(messages)
+
+     sft_test_dataset = {"messages": []}
+
+     for o in test_output:
+         messages = [
+             {"role": "user", "content": o[0]},
+             {"role": "assistant", "content": o[1]},
+         ]
+         sft_test_dataset["messages"].append(messages)
+
+     import pandas as pd
+
+     sft_train_dataset = pd.DataFrame(data=sft_train_dataset)
+     sft_test_dataset = pd.DataFrame(data=sft_test_dataset)
+
+     folder = os.path.join(data_path, "sft")
+
+     os.makedirs(folder, exist_ok=True)
+
+     sft_train_dataset.to_parquet(os.path.join(folder, "train.parquet"))
+     sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet"))
+
+     # build RL dataset
+     rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []}
+
+     rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []}
+
+     from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed
+
+     for o in train_outputs:
+         prompt = o[0]
+         response = o[1]
+         prompt_with_template = [
+             {
+                 "role": "user",
+                 "content": prompt,
+             }
+         ]
+
+         rl_train_dataset["prompt"].append(prompt_with_template)
+         rl_train_dataset["data_source"].append("char_count")
+         rl_train_dataset["ability"].append("other")
+         rl_train_dataset["reward_model"].append(
+             {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))}
+         )
+         rl_train_dataset["extra_info"].append({"response": response})
+
+     for o in test_output:
+         prompt = o[0]
+         response = o[1]
+         prompt_with_template = [
+             {
+                 "role": "user",
+                 "content": prompt,
+             }
+         ]
+
+         rl_test_dataset["prompt"].append(prompt_with_template)
+         rl_test_dataset["data_source"].append("char_count")
+         rl_test_dataset["ability"].append("other")
+         rl_test_dataset["reward_model"].append(
+             {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))}
+         )
+         rl_test_dataset["extra_info"].append({"response": response})
+
+     rl_train_dataset = pd.DataFrame(data=rl_train_dataset)
+     rl_test_dataset = pd.DataFrame(data=rl_test_dataset)
+
+     folder = os.path.join(data_path, "rl")
+
+     os.makedirs(folder, exist_ok=True)
+
+     rl_train_dataset.to_parquet(os.path.join(folder, "train.parquet"))
+     rl_test_dataset.to_parquet(os.path.join(folder, "test.parquet"))
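
The README claims the total length and the answer are made uniformly distributed; a quick tally over generated samples shows the shape of the answer distribution (a sketch, assuming it is run from this directory so `create_dataset` is importable):

```python
from collections import Counter

from create_dataset import create_prompt_response

samples = [create_prompt_response(min_length=5, max_length=20) for _ in range(10000)]
# The ground-truth count is the N inside the final \boxed{N}.
answers = [int(ans.split("\\boxed{")[1].split("}")[0]) for _, ans in samples]
print(sorted(Counter(answers).items()))
```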
ICL/DAPO/verl-recipe/char_count/reward_function.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Reward function
+ """
+
+ from verl.utils.reward_score import math_reward
+
+
+ def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None):
+     try:
+         last_boxed_string = math_reward.last_boxed_only_string(solution_str)
+         if last_boxed_string is None:
+             return 0
+         solution = math_reward.remove_boxed(last_boxed_string)
+         if solution == ground_truth:
+             return 1
+         else:
+             return 0
+     except Exception:
+         print(ground_truth, solution_str)
+         return 0
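
For intuition, the reward is 1 only when the last `\boxed{...}` in the completion matches the ground truth exactly, and 0 otherwise (a usage sketch, assuming verl is installed and this file is importable):

```python
from reward_function import char_count_reward_function

print(char_count_reward_function("char_count", "a != b.\nb = b.\n\\boxed{1}", "1"))  # 1
print(char_count_reward_function("char_count", "\\boxed{3}", "1"))                   # 0
print(char_count_reward_function("char_count", "no boxed answer here", "1"))         # 0
```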
ICL/DAPO/verl-recipe/char_count/train_grpo.sh ADDED
@@ -0,0 +1,45 @@
+ set -x
+
+
+ python3 -m verl.trainer.main_ppo \
+     algorithm.adv_estimator=grpo \
+     data.train_files=$HOME/data/char_count/rl/train.parquet \
+     data.val_files=$HOME/data/char_count/rl/test.parquet \
+     data.train_batch_size=128 \
+     data.max_prompt_length=128 \
+     data.max_response_length=128 \
+     data.filter_overlong_prompts=False \
+     data.truncation='error' \
+     actor_rollout_ref.model.path=$HOME/experiments/char_count/models/sft/megatron/global_step_140/huggingface \
+     actor_rollout_ref.actor.optim.lr=1e-6 \
+     actor_rollout_ref.model.use_remove_padding=True \
+     actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+     actor_rollout_ref.actor.use_dynamic_bsz=True \
+     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \
+     actor_rollout_ref.actor.use_kl_loss=False \
+     actor_rollout_ref.actor.kl_loss_coef=0.0 \
+     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+     actor_rollout_ref.actor.entropy_coeff=0 \
+     actor_rollout_ref.model.enable_gradient_checkpointing=True \
+     actor_rollout_ref.actor.fsdp_config.param_offload=True \
+     actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+     actor_rollout_ref.rollout.name=vllm \
+     actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
+     actor_rollout_ref.rollout.n=8 \
+     actor_rollout_ref.rollout.enforce_eager=True \
+     actor_rollout_ref.ref.fsdp_config.param_offload=True \
+     algorithm.use_kl_in_reward=False \
+     trainer.critic_warmup=0 \
+     trainer.logger='["console","tensorboard"]' \
+     trainer.project_name='verl_example' \
+     trainer.experiment_name='smol135m_grpo-1128a1' \
+     trainer.val_before_train=True \
+     trainer.n_gpus_per_node=1 \
+     trainer.nnodes=1 \
+     trainer.save_freq=-1 \
+     trainer.test_freq=5 \
+     trainer.total_epochs=5 \
+     trainer.use_legacy_worker_impl=disable \
+     custom_reward_function.path=./reward_function.py \
+     custom_reward_function.name=char_count_reward_function
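
The script sets `algorithm.adv_estimator=grpo` with `rollout.n=8`, so each prompt is sampled 8 times and the 0/1 rewards are converted to group-relative advantages. A minimal sketch of that estimator (illustrative only, not verl's actual implementation):

```python
import numpy as np

def grpo_advantages(group_rewards, eps=1e-6):
    """Normalize each rollout's reward by its group's mean and std."""
    r = np.asarray(group_rewards, dtype=np.float64)
    return (r - r.mean()) / (r.std() + eps)

# 8 rollouts of one prompt, scored 0/1 by char_count_reward_function:
print(grpo_advantages([1, 0, 0, 1, 1, 0, 0, 0]))  # correct rollouts get positive advantage
```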
ICL/DAPO/verl-recipe/char_count/train_sft.sh ADDED
@@ -0,0 +1,97 @@
+ #!/usr/bin/env bash
+ set -xeuo pipefail
+
+ ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
+
+ TRAIN_FILES=${TRAIN_FILES:-$HOME/data/char_count/sft/train.parquet}
+ TEST_FILES=${TEST_FILES:-$HOME/data/char_count/sft/test.parquet}
+
+ backend=${BACKEND:-fsdp}
+
+ project_name=char_count-sft
+
+ RESUME_MODE=auto
+ MODEL_ID=${MODEL_ID:-HuggingFaceTB/SmolLM2-135M-Instruct}
+
+ SP_SIZE=${SP_SIZE:-1}
+ FSDP_SIZE=${FSDP_SIZE:-1}
+ FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"}
+
+ TP_SIZE=${TP_SIZE:-1}
+ PP_SIZE=${PP_SIZE:-1}
+ VPP_SIZE=${VPP_SIZE:-null}
+ CP_SIZE=${CP_SIZE:-1}
+
+ PAD_MODE=${PAD_MODE:-no_padding}
+
+ USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
+
+ FSDP_ENGINE_CONFIG="\
+     engine=${backend} \
+     optim=${backend} \
+     optim.lr=2e-5 \
+     optim.lr_warmup_steps_ratio=0.01 \
+     optim.weight_decay=0.1 \
+     optim.betas="[0.9,0.95]" \
+     optim.clip_grad=1.0 \
+     optim.min_lr_ratio=0.1 \
+     optim.warmup_style=cosine \
+     engine.ulysses_sequence_parallel_size=${SP_SIZE} \
+     engine.strategy=${FSDP_STRATEGY} \
+     engine.fsdp_size=${FSDP_SIZE}"
+
+
+ MEGATRON_ENGINE_CONFIG="\
+     engine=${backend} \
+     optim=${backend} \
+     optim.lr=2e-5 \
+     optim.lr_warmup_steps_ratio=0.01 \
+     optim.weight_decay=0.1 \
+     optim.betas="[0.9,0.95]" \
+     optim.clip_grad=1.0 \
+     optim.lr_warmup_init=0 \
+     optim.lr_decay_style=cosine \
+     optim.min_lr=2e-6 \
+     engine.tensor_model_parallel_size=${TP_SIZE} \
+     engine.pipeline_model_parallel_size=${PP_SIZE} \
+     engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
+     engine.context_parallel_size=${CP_SIZE} \
+     engine.use_mbridge=False"
+
+ if [ "$backend" = "fsdp" ]; then
+     ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
+     echo "Using fsdp engine"
+     exp_name=char_count-sft-SmolLM2-135M-Instruct-fsdp
+ else
+     ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
+     echo "Using megatron engine"
+     exp_name=char_count-sft-SmolLM2-135M-Instruct-megatron
+ fi
+
+ CKPT_HOME=${CKPT_HOME:-$HOME/experiments/char_count/models/sft/$backend}
+ mkdir -p "${CKPT_HOME}"
+
+ torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-1} \
+     ${ENTRYPOINT} \
+     data.train_files="${TRAIN_FILES}" \
+     data.train_batch_size=64 \
+     data.val_files="${TEST_FILES}" \
+     data.max_length=256 \
+     data.pad_mode=${PAD_MODE} \
+     data.truncation=error \
+     data.use_dynamic_bsz=True \
+     data.max_token_len_per_gpu=1792 \
+     data.messages_key=messages \
+     model.path=$MODEL_ID \
+     model.use_remove_padding=${USE_REMOVE_PADDING} \
+     ${ENGINE_CONFIG} \
+     trainer.test_freq=-1 \
+     trainer.save_freq=70 \
+     trainer.logger=['console'] \
+     trainer.project_name="${project_name}" \
+     trainer.experiment_name="${exp_name}" \
+     trainer.total_epochs=1 \
+     trainer.default_local_dir="${CKPT_HOME}" \
+     trainer.resume_mode=${RESUME_MODE} \
+     trainer.max_ckpt_to_keep=5 \
+     checkpoint.save_contents=[model,optimizer,extra]
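
Both engine configs request linear warmup (`lr_warmup_steps_ratio=0.01`) followed by cosine decay to a floor (`min_lr_ratio=0.1` for FSDP, `min_lr=2e-6` for Megatron). A sketch of that schedule shape (illustrative; the step counts are made up and this is not verl's scheduler code):

```python
import math

def lr_at(step, total_steps, base_lr=2e-5, warmup_ratio=0.01, min_lr_ratio=0.1):
    """Linear warmup, then cosine decay from base_lr down to min_lr_ratio * base_lr."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    min_lr = base_lr * min_lr_ratio
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

print([f"{lr_at(s, 1000):.2e}" for s in (0, 5, 10, 500, 1000)])
```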
ICL/DAPO/verl-recipe/collabllm/README.md ADDED
@@ -0,0 +1,74 @@
+ # CollabLLM
+
+ This repository implements [CollabLLM](https://arxiv.org/pdf/2502.00640) (ICML 2025) using the verl framework. For the original implementation, see the [CollabLLM repository](https://github.com/Wuyxin/collabllm).
+
+
+ CollabLLM is a method for training language models to collaborate effectively in multi-turn conversations. This implementation adapts the original one to work with the verl training framework.
+
+ ## Quick start
+
+ ### 0. Environment
+ Make sure the required packages for `verl` are installed. Additionally, install `litellm` and export the required API keys. The API model will be used for user simulators and, optionally, LLM judges (see the Configuration section below).
+
+ ### 1. Prepare Your Dataset
+
+ First, process your dataset using the provided script:
+
+ ```bash
+ python process_dataset.py --dataset <> ... --dataset_type <sft or rl>
+ ```
+
+
+ **Requirements:**
+ - Input: A Hugging Face multiturn dataset. Existing datasets: `collabllm/collabllm-multiturn-$DATASET`, with `DATASET` one of [`math-hard(-large)`, `medium(-large)`, `bigcodebench(-large)`] (the *-large variants are the datasets used in the CollabLLM paper)
+ - Example format: See [collabllm-multiturn-math-hard](https://huggingface.co/datasets/collabllm/collabllm-multiturn-math-hard)
+ - To generate your own dataset: Use [build_dataset.py](https://github.com/Wuyxin/collabllm/blob/main/scripts/engine/build_dataset.py) from the original CollabLLM repository
+
+ *Note: Check `process_dataset.py` for example commands and usage.*
+
+ ### 2. Train Your Model
+
+ **(Optional) For Supervised Fine-Tuning (SFT):**
+ ```bash
+ bash train_sft_collabllm.sh
+ ```
+
+ **For Reinforcement Learning (RL):**
+
+ ```bash
+ bash train_rl_collabllm.sh
+ ```
+
+ The RL script shows an example of training CollabLLM on `math-hard-large`.
+
+ - The config for sampling future conversations is in `recipe/collabllm/config/collabllm_interaction_config.yaml`.
+ - The Multiturn-aware Reward is aggregated from these three conversation-level rewards:
+
+ ```
+ +reward_model.reward_kwargs.metric_weights.accuracy=1 \
+ +reward_model.reward_kwargs.metric_weights.interactivity=1 \
+ +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \
+ ```
+
+ You can remove, add, or modify the weights depending on your task. A list of already-implemented metrics you can add is under `recipe/collabllm/metrics`. For example, on `medium-large` you can replace `accuracy` with `bleu_score` via
+ ```
+ +reward_model.reward_kwargs.metric_weights.bleu_score=1
+ ```
+ which will instead apply the BLEU score to the sampled future conversations.
+
+ ## Configuration
+ Read the [docs](https://verl.readthedocs.io/en/latest/) for detailed configuration options.
+
+ ## Citation
+ If you find CollabLLM useful in your research, please cite the following:
+
+ ```bibtex
+ @inproceedings{collabllm2025,
+   title={CollabLLM: From Passive Responders to Active Collaborators},
+   author={Shirley Wu and Michel Galley and Baolin Peng and Hao Cheng and
+           Gavin Li and Yao Dou and Weixin Cai and James Zou and
+           Jure Leskovec and Jianfeng Gao},
+   booktitle={International Conference on Machine Learning (ICML)},
+   year={2025}
+ }
+ ```
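
Conceptually, the Multiturn-aware Reward above is a weighted sum of conversation-level metric scores computed on the sampled future conversations. A schematic of what the `metric_weights` express (a sketch with hypothetical scores; the real computation lives in this recipe's reward code):

```python
def multiturn_aware_reward(metric_scores: dict, metric_weights: dict) -> float:
    """Weighted sum over conversation-level metrics."""
    return sum(metric_weights[name] * metric_scores[name] for name in metric_weights)

weights = {"accuracy": 1, "interactivity": 1, "token_amount": -0.0001}
scores = {"accuracy": 1.0, "interactivity": 0.6, "token_amount": 3500}  # hypothetical rollout
print(multiturn_aware_reward(scores, weights))  # 1.0 + 0.6 - 0.35 = 1.25
```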
ICL/DAPO/verl-recipe/collabllm/utils.py ADDED
@@ -0,0 +1,280 @@
+ # Copyright 2025 CollabLLM team and/or its affiliates
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import os
+ import re
+
+ logger = logging.getLogger(__file__)
+ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+ def parse_messages(messages, strip_sys_prompt=True):
+     """
+     Args:
+         messages: List[dict]
+             List of dictionaries with keys 'role' and 'content'
+             Example: messages = [{'role': 'user', 'content': 'Hello!'},
+                                  {'role': 'assistant', 'content': 'Hi!'}, ...]
+     """
+     if messages is None:
+         return ""
+
+     if strip_sys_prompt:
+         messages = strip_system_prompt(messages)
+
+     chat = "\n".join(f"**{m.role.capitalize()}**: {m.content}" for m in messages)
+
+     return chat
+
+
+ def strip_system_prompt(messages):
+     """
+     Args:
+         messages: List[dict]
+             List of dictionaries with keys 'role' and 'content'
+             Example: messages = [{'role': 'user', 'content': 'Hello!'},
+                                  {'role': 'assistant', 'content': 'Hi!'}, ...]
+     """
+     return [msg for msg in messages if msg.role != "system"]
+
+
+ def extract_json(s):
+     def convert_value(value):
+         true_values = {"true": True, "false": False, "null": None}
+         value_lower = value.lower()
+         if value_lower in true_values:
+             return true_values[value_lower]
+         try:
+             if "." in value or "e" in value.lower():
+                 return float(value)
+             else:
+                 return int(value)
+         except ValueError:
+             return value  # Return as string if not a number
+
+     def parse_number(s, pos):
+         start = pos
+         while pos < len(s) and s[pos] in "-+0123456789.eE":
+             pos += 1
+         num_str = s[start:pos]
+         try:
+             if "." in num_str or "e" in num_str.lower():
+                 return float(num_str), pos
+             else:
+                 return int(num_str), pos
+         except ValueError:
+             logger.error(f"Invalid number at position {start}: {num_str}")
+             raise
+
+     def skip_whitespace(s, pos):
+         while pos < len(s) and s[pos] in " \t\n\r":
+             pos += 1
+         return pos
+
+     def parse_string(s, pos):
+         quote_char = s[pos]
+         assert quote_char in ('"', "'")
+         pos += 1
+         result = ""
+         while pos < len(s):
+             c = s[pos]
+             if c == "\\":
+                 pos += 1
+                 if pos >= len(s):
+                     raise ValueError("Invalid escape sequence")
+                 c = s[pos]
+                 escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char}
+                 result += escape_sequences.get(c, c)
+             elif c == quote_char:
+                 pos += 1
+                 # Attempt to convert to a number if possible
+                 converted_value = convert_value(result)
+                 return converted_value, pos
+             else:
+                 result += c
+             pos += 1
+         raise ValueError("Unterminated string")
+
+     def parse_key(s, pos):
+         pos = skip_whitespace(s, pos)
+         if s[pos] in ('"', "'"):
+             key, pos = parse_string(s, pos)
+             return key, pos
+         else:
+             raise ValueError(f"Expected string for key at position {pos}")
+
+     def parse_object(s, pos):
+         obj = {}
+         assert s[pos] == "{"
+         pos += 1
+         pos = skip_whitespace(s, pos)
+         while pos < len(s) and s[pos] != "}":
+             pos = skip_whitespace(s, pos)
+             key, pos = parse_key(s, pos)
+             pos = skip_whitespace(s, pos)
+             if pos >= len(s) or s[pos] != ":":
+                 raise ValueError(f'Expected ":" at position {pos}')
+             pos += 1
+             pos = skip_whitespace(s, pos)
+             value, pos = parse_value(s, pos)
+             obj[key] = value
+             pos = skip_whitespace(s, pos)
+             if pos < len(s) and s[pos] == ",":
+                 pos += 1
+                 pos = skip_whitespace(s, pos)
+             elif pos < len(s) and s[pos] == "}":
+                 break
+             elif pos < len(s) and s[pos] != "}":
+                 raise ValueError(f'Expected "," or "}}" at position {pos}')
+         if pos >= len(s) or s[pos] != "}":
+             raise ValueError(f'Expected "}}" at position {pos}')
+         pos += 1
+         return obj, pos
+
+     def parse_array(s, pos):
+         lst = []
+         assert s[pos] == "["
+         pos += 1
+         pos = skip_whitespace(s, pos)
+         while pos < len(s) and s[pos] != "]":
+             value, pos = parse_value(s, pos)
+             lst.append(value)
+             pos = skip_whitespace(s, pos)
+             if pos < len(s) and s[pos] == ",":
+                 pos += 1
+                 pos = skip_whitespace(s, pos)
+             elif pos < len(s) and s[pos] == "]":
+                 break
+             elif pos < len(s) and s[pos] != "]":
+                 raise ValueError(f'Expected "," or "]" at position {pos}')
+         if pos >= len(s) or s[pos] != "]":
+             raise ValueError(f'Expected "]" at position {pos}')
+         pos += 1
+         return lst, pos
+
+     def parse_triple_quoted_string(s, pos):
+         if s[pos : pos + 3] == "'''":
+             quote_str = "'''"
+         elif s[pos : pos + 3] == '"""':
+             quote_str = '"""'
+         else:
+             raise ValueError(f"Expected triple quotes at position {pos}")
+         pos += 3
+         result = ""
+         while pos < len(s):
+             if s[pos : pos + 3] == quote_str:
+                 pos += 3
+                 # Attempt to convert to a number if possible
+                 converted_value = convert_value(result)
+                 return converted_value, pos
+             else:
+                 result += s[pos]
+                 pos += 1
+         raise ValueError("Unterminated triple-quoted string")
+
+     def parse_value(s, pos):
+         pos = skip_whitespace(s, pos)
+         if pos >= len(s):
+             raise ValueError("Unexpected end of input")
+         if s[pos] == "{":
+             return parse_object(s, pos)
+         elif s[pos] == "[":
+             return parse_array(s, pos)
+         elif s[pos : pos + 3] in ("'''", '"""'):
+             return parse_triple_quoted_string(s, pos)
+         elif s[pos] in ('"', "'"):
+             return parse_string(s, pos)
+         elif s[pos : pos + 4].lower() == "true":
+             return True, pos + 4
+         elif s[pos : pos + 5].lower() == "false":
+             return False, pos + 5
+         elif s[pos : pos + 4].lower() == "null":
+             return None, pos + 4
+         elif s[pos] in "-+0123456789.":
+             return parse_number(s, pos)
+         else:
+             raise ValueError(f"Unexpected character at position {pos}: {s[pos]}")
+
+     json_start = s.index("{")
+     json_end = s.rfind("}")
+     s = s[json_start : json_end + 1]
+
+     s = s.strip()
+     result, pos = parse_value(s, 0)
+     pos = skip_whitespace(s, pos)
+     if pos != len(s):
+         raise ValueError(f"Unexpected content at position {pos}")
+     return result
+
+
+ def remove_think_block(msg: dict):
+     """
+     Remove <think>.*?</think> from content
+     """
+     if "content" in msg and isinstance(msg["content"], str):
+         msg["content"] = re.sub(r"<think>.*?</think>", "", msg["content"], flags=re.DOTALL).strip()
+     return msg
+
+
+ def is_valid_messages(msg: dict) -> bool:
+     """
+     Check whether the message is valid:
+     1. <think> is paired with </think>
+     2. content is not empty inside or outside <think>
+     3. blocks are not nested, and at most one <think> block is allowed
+     4. content is not empty after removing a trailing "<|im_end|>"
+     """
+     content = msg.get("content")
+     if not isinstance(content, str):
+         return True
+
+     # Base case: empty or whitespace-only content is invalid.
+     if not content.strip():
+         return False
+
+     num_think_open = content.count("<think>")
+     num_think_close = content.count("</think>")
+
+     # Rule 1: Check for paired tags.
+     if num_think_open != num_think_close:
+         return False
+
+     # Rule 3: Allow at most one think block.
+     if num_think_open > 1:
+         return False
+
+     # Case 1: No <think> blocks.
+     if num_think_open == 0:
+         visible_content = content
+     # Case 2: Exactly one <think> block.
+     else:
+         # Rule 2: Check for empty content inside the think block.
+         match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
+         if not match or not match.group(1).strip():
+             return False
+
+         # The "visible" content is what's outside the think block.
+         visible_content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
+
+     visible_content = visible_content.strip()
+
+     # Rule 4 & 2 (outside): Check if visible content is empty after handling <|im_end|>.
+     if visible_content.endswith("<|im_end|>"):
+         visible_content = visible_content[: -len("<|im_end|>")]
+
+     if not visible_content.strip():
+         return False
+
+     return True
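
A quick usage sketch of the two main helpers (assuming this module is importable as `utils`): `extract_json` tolerates single quotes and coerces `'true'` and numeric strings, and `is_valid_messages` rejects messages with nothing visible outside the `<think>` block.

```python
from utils import extract_json, is_valid_messages

raw = "The judge replied: {'accuracy': '0.8', 'pass': 'true'}"
print(extract_json(raw))  # {'accuracy': 0.8, 'pass': True}

print(is_valid_messages({"content": "<think>reasoning</think>final answer"}))  # True
print(is_valid_messages({"content": "<think>x</think><|im_end|>"}))            # False: nothing visible
```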
ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_8b_base_npu.sh ADDED
@@ -0,0 +1,138 @@
+ #!/bin/bash
+ project_name='DAPO'
+ exp_name='DAPO-Qwen3-8B-Base'
+
+ adv_estimator=grpo
+
+ use_kl_in_reward=False
+ kl_coef=0.0
+ use_kl_loss=False
+ kl_loss_coef=0.0
+
+ clip_ratio_low=0.2
+ clip_ratio_high=0.28
+
+ max_prompt_length=$((1024 * 2))
+ max_response_length=$((1024 * 20))
+ enable_overlong_buffer=True
+ overlong_buffer_len=$((1024 * 4))
+ overlong_penalty_factor=1.0
+
+ loss_agg_mode="token-mean"
+
+ enable_filter_groups=False
+ filter_groups_metric=acc
+ max_num_gen_batches=10
+ train_prompt_bsz=16
+ gen_prompt_bsz=$((train_prompt_bsz * 3))
+ n_resp_per_prompt=16
+ train_prompt_mini_bsz=1
+
+ # Ray
+ RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+ WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+ RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+ NNODES=${NNODES:-1}
+ # Paths
+ RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+ MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B-Base"}
+ CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+ TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
+ TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
+
+ # Algorithm
+ temperature=1.0
+ top_p=1.0
+ top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+
+ # Performance Related Parameter
+ sp_size=2
+ use_dynamic_bsz=True
+ actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
+ infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size))
+ offload=True
+ gen_tp=2
+
+ ray job submit --runtime-env="${RUNTIME_ENV}" \
+     -- python3 -m recipe.dapo.main_dapo \
+     data.train_files="${TRAIN_FILE}" \
+     data.val_files="${TEST_FILE}" \
+     data.prompt_key=prompt \
+     data.truncation='left' \
+     data.max_prompt_length=${max_prompt_length} \
+     data.max_response_length=${max_response_length} \
+     data.gen_batch_size=${gen_prompt_bsz} \
+     data.train_batch_size=${train_prompt_bsz} \
+     actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+     algorithm.adv_estimator=${adv_estimator} \
+     algorithm.use_kl_in_reward=${use_kl_in_reward} \
+     algorithm.kl_ctrl.kl_coef=${kl_coef} \
+     actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+     actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+     actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+     actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+     actor_rollout_ref.actor.clip_ratio_c=10.0 \
+     algorithm.filter_groups.enable=${enable_filter_groups} \
+     algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
+     algorithm.filter_groups.metric=${filter_groups_metric} \
+     actor_rollout_ref.model.use_remove_padding=True \
+     actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+     actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+     actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+     actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+     actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+     actor_rollout_ref.model.path="${MODEL_PATH}" \
+     +actor_rollout_ref.model.override_config.attention_dropout=0. \
+     +actor_rollout_ref.model.override_config.embd_pdrop=0. \
+     +actor_rollout_ref.model.override_config.resid_pdrop=0. \
+     actor_rollout_ref.model.enable_gradient_checkpointing=True \
+     actor_rollout_ref.actor.optim.lr=1e-6 \
+     actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+     actor_rollout_ref.actor.optim.weight_decay=0.1 \
+     actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+     actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+     actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+     actor_rollout_ref.actor.entropy_coeff=0 \
+     actor_rollout_ref.actor.grad_clip=1.0 \
+     actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+     actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+     actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \
+     actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+     actor_rollout_ref.rollout.enable_chunked_prefill=False \
+     actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+     actor_rollout_ref.rollout.temperature=${temperature} \
+     actor_rollout_ref.rollout.top_p=${top_p} \
+     actor_rollout_ref.rollout.top_k="${top_k}" \
+     actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+     actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
+     actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+     actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+     actor_rollout_ref.rollout.val_kwargs.n=1 \
+     actor_rollout_ref.rollout.name=vllm \
+     actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+     actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+     actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
+     reward_model.reward_manager=dapo \
+     reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
+     reward_model.overlong_buffer.len=${overlong_buffer_len} \
+     reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
+     trainer.logger=['console'] \
+     trainer.project_name="${project_name}" \
+     trainer.experiment_name="${exp_name}" \
+     trainer.n_gpus_per_node=8 \
+     trainer.nnodes="${NNODES}" \
+     trainer.val_before_train=False \
+     trainer.test_freq=10 \
+     trainer.save_freq=20 \
+     trainer.total_epochs=1 \
+     trainer.total_training_steps=100 \
+     trainer.default_local_dir="${CKPTS_DIR}" \
+     trainer.resume_mode=auto \
+     data.shuffle=False \
+     actor_rollout_ref.actor.use_torch_compile=False \
+     actor_rollout_ref.ref.use_torch_compile=False \
+     actor_rollout_ref.actor.entropy_checkpointing=True \
+     actor_rollout_ref.ref.entropy_checkpointing=True \
+     actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
+     actor_rollout_ref.ref.fsdp_config.forward_prefetch=True
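
The `overlong_buffer` settings enable DAPO's soft overlong punishment: with `max_response_length` of 20480 tokens and a 4096-token buffer, responses are penalized linearly once they enter the buffer zone. A sketch of that shaping as described in the DAPO paper (illustrative only; the actual logic is in verl's `dapo` reward manager):

```python
def overlong_penalty(response_len, max_response_len=20480, buffer_len=4096, penalty_factor=1.0):
    """Zero inside the safe zone, linear penalty in the buffer, capped at -penalty_factor."""
    expected_len = max_response_len - buffer_len
    if response_len <= expected_len:
        return 0.0
    exceed = response_len - expected_len
    return -min(exceed / buffer_len, 1.0) * penalty_factor

print(overlong_penalty(10000))  # 0.0  (well inside the safe zone)
print(overlong_penalty(18432))  # -0.5 (halfway into the buffer)
```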
ICL/DAPO/verl-recipe/deepeyes/deepeyes.py ADDED
@@ -0,0 +1,408 @@
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import io
+ import logging
+ import os
+ import random
+ import re
+
+ import requests
+ from openai import OpenAI
+ from PIL import Image
+
+ import verl.utils.torch_functional as verl_F
+ from verl.utils.dataset.rl_dataset import RLHFDataset
+ from verl.utils.model import compute_position_id_with_mask
+
+ logger = logging.getLogger(__name__)
+
+ openai_api_key = "EMPTY"
+ openai_api_base = os.environ.get("LLM_AS_A_JUDGE_BASE", "http://10.1.100.71:18901/v1")
+
+ client = OpenAI(
+     api_key=openai_api_key,
+     base_url=openai_api_base,
+ )
+
+ model_name = ""
+ if openai_api_base:
+     try:
+         response = requests.get(f"{openai_api_base}/models")
+         response.raise_for_status()
+         models = response.json()
+         if models.get("data"):
+             model_name = models["data"][0]["id"]
+         else:
+             logger.warning("No models found at the specified API base for reward scoring.")
+     except (requests.exceptions.RequestException, KeyError, IndexError) as e:
+         logger.warning(f"Failed to get model from {openai_api_base}: {e}. Reward scoring will be disabled.")
+
+
+ class CustomRLHFDataset(RLHFDataset):
+     def __getitem__(self, item):
+         """
+         Note that we also return the raw_input_ids so that they can be combined with other chat templates
+         """
+         row_dict: dict = self.dataframe[item]
+         row_dict[self.prompt_key] = [
+             {
+                 "role": "system",
+                 # We don't need a tool description, because custom_chat_template will add it.
+                 "content": (
+                     "You are a helpful assistant. You can call functions to assist with the user query. "
+                     "Important: You must call only one function at a time. After each function call, "
+                     "wait for the execution result before making the next function call if needed."
+                 ),
+             },
+             {
+                 "role": "user",
+                 "content": row_dict[self.prompt_key][1]["content"],
+             },
+         ]
+
+         images = []
+         row_dict_images = row_dict.get(self.image_key, None)
+         if row_dict_images:
+             images = [Image.open(io.BytesIO(image["bytes"])) for image in row_dict_images]
+         messages = self._build_messages(row_dict)
+
+         if self.processor is not None:
+             raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+             model_inputs = self.processor(text=[raw_prompt], images=images, return_tensors="pt")
+
+             input_ids = model_inputs.pop("input_ids")
+             attention_mask = model_inputs.pop("attention_mask")
+
+             if "second_per_grid_ts" in model_inputs:
+                 model_inputs.pop("second_per_grid_ts")
+
+         else:
+             raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+             model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
+             input_ids = model_inputs.pop("input_ids")
+             attention_mask = model_inputs.pop("attention_mask")
+
+         input_ids, attention_mask = verl_F.postprocess_data(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_length=self.max_prompt_length,
+             pad_token_id=self.tokenizer.pad_token_id,
+             left_pad=True,
+             truncation=self.truncation,
+         )
+
+         if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
+             from verl.models.transformers.qwen2_vl import get_rope_index
+
+             position_ids = [
+                 get_rope_index(
+                     self.processor,
+                     input_ids=input_ids[0],
+                     image_grid_thw=model_inputs.get("image_grid_thw"),
+                     video_grid_thw=model_inputs.get("video_grid_thw"),
+                     second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
+                     attention_mask=attention_mask[0],
+                 )
+             ]  # (1, 3, seq_len)
+
+         else:
+             position_ids = compute_position_id_with_mask(attention_mask)
+
+         row_dict["input_ids"] = input_ids[0]
+         row_dict["attention_mask"] = attention_mask[0]
+         row_dict["position_ids"] = position_ids[0]
+
+         raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
+         if len(raw_prompt_ids) > self.max_prompt_length:
+             if self.truncation == "left":
+                 raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
+             elif self.truncation == "right":
+                 raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
+             elif self.truncation == "middle":
+                 left_half = self.max_prompt_length // 2
+                 right_half = self.max_prompt_length - left_half
+                 raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
+             elif self.truncation == "error":
+                 raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
+
+         row_dict["raw_prompt_ids"] = raw_prompt_ids
+         # encode prompts without chat template
+         if self.return_raw_chat:
+             row_dict["raw_prompt"] = messages
+
+         # get prompts with chat template
+         if self.return_full_prompt:
+             row_dict["full_prompts"] = raw_prompt  # array of strings
+
+         # add index for each prompt
+         index = row_dict.get("extra_info", {}).get("index", 0)
+         tools_kwargs = {
+             "image_zoom_in_tool": {
+                 "create_kwargs": {"image": images[0]},
+                 # "execute_kwargs": {},
+                 # "calc_reward_kwargs": {},
+                 # "release_kwargs": {},
+             }
+         }
+         row_dict["index"] = index
+         row_dict["tools_kwargs"] = tools_kwargs
+         row_dict["agent_name"] = "tool_agent"
+         return row_dict
+
+
+ def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> float:
+     """
+     Compute reward score for model solutions with robust handling of various formats.
+
+     Returns a weighted combination of:
+     - Accuracy reward (0.8 weight): Whether the answer is semantically correct
+     - Format reward (0.2 weight): Whether the output follows the expected format
+     - Tool reward (1.2 weight): Whether tools were used when the answer is correct
+     """
+
+     # Initialize tracking variables
+     is_format_error = False
+
+     # 1. Check <think> tag format
+     count_think_1 = solution_str.count("<think>")
+     count_think_2 = solution_str.count("</think>")
+     if count_think_1 != count_think_2:
+         is_format_error = True
+
+     # 2. Check vision tokens (skip this since the tokenizer removes special tokens)
+     # We'll use <tool_call> and <tool_response> instead to detect tool usage
+
+     # 3. Extract answer text with multiple fallback strategies
+     answer_text = ""
+
+     # Strategy 1: Try to extract from <answer> tags first
+     predict_no_think = (
+         solution_str.split("</think>")[-1].strip() if "</think>" in solution_str else solution_str.strip()
+     )
+
+     # Check <answer> tag format
+     count_answer_1 = predict_no_think.count("<answer>")
+     count_answer_2 = predict_no_think.count("</answer>")
+     if count_answer_1 != count_answer_2:
+         is_format_error = True
+
+     # Try to extract from <answer> tags
+     answer_match = re.search(r"<answer>(.*?)</answer>", predict_no_think, re.DOTALL)
+     if answer_match:
+         answer_text = answer_match.group(1).strip()
+     else:
+         # No proper <answer> tags found - this is a format error
+         is_format_error = True
+
+         # Strategy 2: If no <answer> tags, extract content after tool responses
+         # Look for pattern: <tool_response>...</tool_response>assistant\n[actual_answer]
+         tool_response_match = re.search(
+             r"</tool_response>\s*assistant\s*\n(.*?)$", predict_no_think, re.DOTALL | re.MULTILINE
+         )
+         if tool_response_match:
+             answer_text = tool_response_match.group(1).strip()
+         else:
+             # Strategy 3: If no tool responses, look for content after </think>
+             if "</think>" in solution_str:
+                 # Remove any remaining tool-related tags and extract meaningful content
+                 remaining_content = predict_no_think
+                 # Remove tool calls and responses
+                 remaining_content = re.sub(r"<tool_call>.*?</tool_call>", "", remaining_content, flags=re.DOTALL)
+                 remaining_content = re.sub(
+                     r"<tool_response>.*?</tool_response>", "", remaining_content, flags=re.DOTALL
+                 )
+                 # Remove user/assistant markers
+                 remaining_content = re.sub(r"\b(user|assistant)\b", "", remaining_content)
+                 answer_text = remaining_content.strip()
+             else:
+                 # Strategy 4: Use the entire solution_str as a fallback
+                 answer_text = solution_str.strip()
+
+     # Clean up the answer text
+     answer_text = answer_text.strip()
+
+     # If the answer is still empty after all strategies, mark it as a format error
+     if not answer_text:
+         is_format_error = True
+         answer_text = solution_str.strip()  # Use the full text as a last resort
+
+     # 4. Evaluate correctness using the LLM judge
+     question_text = extra_info.get("question", "") if extra_info else ""
+
+     if not client or not model_name:
+         logger.warning("Reward function client not initialized or model name not found.")
+         return 0.0
+
+     system_prompt = (
+         "You are an expert evaluator. Your task is to determine if a model's answer is semantically equivalent to a "
+         "provided standard answer, given a specific question.\n"
+         "Your evaluation must be strict. The model's answer is only correct if it fully matches the meaning of the "
+         "standard answer.\n"
+         'You must provide your final judgement as a single word: either "CORRECT" or "INCORRECT". Do not provide '
+         "any explanation or other text."
+     )
+
+     user_prompt = (
+         f"I will provide a question, a standard answer, and a model's answer. You must evaluate if the model's "
+         f"answer is correct.\n\n"
+         f"---\n"
+         f"**Example 1:**\n"
+         f"[Question]: Is the countertop tan or blue?\n"
+         f"[Standard Answer]: The countertop is tan.\n"
+         f"[Model's Answer]: tan\n"
+         f"[Your Judgement]: CORRECT\n"
+         f"---\n"
+         f"**Example 2:**\n"
+         f"[Question]: Is the man phone both blue and closed?\n"
+         f"[Standard Answer]: Yes, the man phone is both blue and closed.\n"
+         f"[Model's Answer]: No.\n"
+         f"[Your Judgement]: INCORRECT\n"
+         f"---\n"
+         f"**Task:**\n"
+         f"[Question]: {question_text}\n"
+         f"[Standard Answer]: {ground_truth}\n"
+         f"[Model's Answer]: {answer_text}\n"
+         f"[Your Judgement]:"
+     )
+
+     try:
+         chat_response = client.chat.completions.create(
+             model=model_name,
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt},
+             ],
+             seed=random.randint(0, 1000000),
+             temperature=0.1,  # Lower temperature for more deterministic judgement
+             extra_body={
+                 "chat_template_kwargs": {"enable_thinking": False},
+             },
+         )
+         response = chat_response.choices[0].message.content.strip()
+     except Exception as e:
+         logger.warning(f" [WARNING] Chat completion request failed: {e}")
+         return 0.0
+
+     # Parse the LLM judge response
+     if re.search(r"\bCORRECT\b", response, re.IGNORECASE):
+         acc_reward = 1.0
+     elif re.search(r"\bINCORRECT\b", response, re.IGNORECASE):
+         acc_reward = 0.0
+     else:
+         logger.warning(
+             f" [WARNING] Judgement format error. Expected 'CORRECT' or 'INCORRECT'.\n"
+             f"Response: '{response}'\n"
+             f"Model Answer: '{answer_text}'\n"
+             f"Ground Truth: '{ground_truth}'"
+         )
+         acc_reward = 0.0
+
+     # Penalize excessively long answers (potential judge hacking)
+     if len(answer_text) >= 1000:
+         acc_reward = 0.0
+         is_format_error = True
+
+     # 5. Check tool usage - look for tool_call/tool_response patterns instead of vision tokens
+     has_tool_usage = bool(
+         re.search(r"<tool_call>.*?</tool_call>", solution_str, re.DOTALL)
+         or re.search(r"<tool_response>.*?</tool_response>", solution_str, re.DOTALL)
+     )
+
+     # Tool reward: only given if tools were used AND the answer is correct
+     tool_reward = 1.0 if has_tool_usage and acc_reward > 0.5 else 0.0
+
+     # Format reward: penalty for format errors
+     format_reward = -1.0 if is_format_error else 0.0
+
+     # Log debug information for problematic cases
+     if is_format_error or not answer_text:
+         logger.debug(
+             f"Format issue detected:\n"
+             f"Solution: {solution_str[:200]}...\n"
+             f"Extracted answer: '{answer_text}'\n"
+             f"Format error: {is_format_error}\n"
+             f"Tool usage: {has_tool_usage}"
+         )
+
+     # Final weighted score
+     final_score = 0.8 * acc_reward + 0.2 * format_reward + 1.2 * tool_reward
+
+     return final_score
+
+
+ if __name__ == "__main__":
+     # Test case 1: Original test case
+     predict_str = "The answer is 2 + 2 = 4 </think> <answer> right </answer> <answer> left </answer>"
+     ground_truth = "left"
+     extra_info = {
+         "answer": "The woman is to the left of the man who is holding the camera.",
+         "id": 0,
+         "image": "/cpfs/user/honglingyi/DATA/LLM/Vstar/gqa/images/713270.jpg",
+         "pred_ans": "The woman is to the right of the man who is holding the camera.",
+         "question": "Is the woman to the left or to the right of the man who is holding the camera?",
+     }
+     print("=== Test Case 1: Original test ===")
+     import time
+
+     time_start = time.time()
+     score = compute_score("common_reasoning", predict_str, ground_truth, extra_info)
+     print(f"Score: {score}")
+     time_end = time.time()
+     print(f"Time: {time_end - time_start}")
+
+     # Test case 2: Problematic case mentioned by user
+     problematic_solution = """<tool_call>
+ {"name": "image_zoom_in_tool", "arguments": {"bbox_2d": [226, 399, 265, 464], "label": "white van"}}
+ </tool_call>user
+ <tool_response>
+ Zoomed in on the image to the region [226, 399, 265, 464] with label white van.
+ </tool_response>
+ assistant
+ The white van is visible in the lower section of the image, near the diagonal road."""
+
+     problematic_ground_truth = "Yes, the white van is indeed situated in the bottom part of the picture."
+     problematic_extra_info = {
+         "question": "Is the white van in the bottom part of the picture?",
+     }
+
+     print("\n=== Test Case 2: Problematic case (no answer tags) ===")
+     print(f"Solution: {problematic_solution}")
+     print(f"Ground truth: {problematic_ground_truth}")
+
+     time_start = time.time()
+     score2 = compute_score("common_reasoning", problematic_solution, problematic_ground_truth, problematic_extra_info)
+     print(f"Score: {score2}")
+     time_end = time.time()
+     print(f"Time: {time_end - time_start}")
+
+     # Test case 3: Well-formatted case with tools
+     well_formatted_solution = """<think>
+ I need to use the image zoom tool to get a better look at the specific area.
+ </think>
+ <tool_call>
+ {"name": "image_zoom_in_tool", "arguments": {"bbox_2d": [226, 399, 265, 464], "label": "white van"}}
+ </tool_call>
+ <tool_response>
+ Zoomed in on the image to the region [226, 399, 265, 464] with label white van.
+ </tool_response>
+ <answer>Yes, the white van is indeed situated in the bottom part of the picture.</answer>"""
+
+     print("\n=== Test Case 3: Well-formatted case ===")
+     time_start = time.time()
+     score3 = compute_score(
+         "common_reasoning", well_formatted_solution, problematic_ground_truth, problematic_extra_info
+     )
+     print(f"Score: {score3}")
+     time_end = time.time()
+     print(f"Time: {time_end - time_start}")
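
Given the 0.8/0.2/1.2 weighting, the reachable scores separate cleanly; the arithmetic below just re-evaluates the final formula for the three common outcomes:

```python
def final_score(acc, fmt_err, used_tools):
    tool = 1.0 if used_tools and acc > 0.5 else 0.0
    fmt = -1.0 if fmt_err else 0.0
    return 0.8 * acc + 0.2 * fmt + 1.2 * tool

print(final_score(1.0, False, True))   #  2.0  correct, clean format, tools used
print(final_score(1.0, True, False))   #  0.6  correct but missing <answer> tags, no tools
print(final_score(0.0, True, False))   # -0.2  wrong and malformed
```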
ICL/DAPO/verl-recipe/fault_recover/async_llm.py ADDED
@@ -0,0 +1,84 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ import asyncio
4
+
5
+ import numpy as np
6
+ from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
7
+ from vllm.utils import cdiv
8
+ from vllm.v1.engine.async_llm import AsyncLLM, logger
9
+ from vllm.v1.metrics.stats import IterationStats
10
+
11
+
12
+ class AsyncFaultRecoverLLM(AsyncLLM):
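+ """AsyncLLM variant for fault recovery: in addition to the normal output path,
+ the output handler streams each request's new token ids and finished flag into
+ a shared Ray queue (obtained via get_tokens_queue at the bottom of this file),
+ so in-flight generations can be recovered after a failure."""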
13
+ def _run_output_handler(self):
14
+ """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
15
+
16
+ if self.output_handler is not None:
17
+ return
18
+
19
+ # Ensure that the task doesn't have a circular ref back to the AsyncLLM
20
+ # object, or else it won't be garbage collected and cleaned up properly.
21
+ engine_core = self.engine_core
22
+ output_processor = self.output_processor
23
+ log_stats = self.log_stats
24
+ logger_manager = self.logger_manager
25
+
26
+ async def output_handler(q):
27
+ try:
28
+ while True:
29
+ # 1) Pull EngineCoreOutputs from the EngineCore.
30
+ outputs = await engine_core.get_output_async()
31
+
32
+ if q is not None:
33
+ req_info = {}
34
+ for output in outputs.outputs:
35
+ req_info[output.request_id] = {}
36
+ req_info[output.request_id]["new_token_ids"] = output.new_token_ids
37
+ req_info[output.request_id]["finished"] = output.finished
38
+ await q.put.remote(req_info)
39
+
40
+ num_outputs = len(outputs.outputs)
41
+
42
+ iteration_stats = IterationStats() if (log_stats and num_outputs) else None
43
+
44
+ # Split outputs into chunks of at most
45
+ # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
46
+ # event loop for too long.
47
+ if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
48
+ slices = (outputs.outputs,)
49
+ else:
50
+ slices = np.array_split(outputs.outputs, cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
51
+
52
+ for i, outputs_slice in enumerate(slices):
53
+ # 2) Process EngineCoreOutputs.
54
+ processed_outputs = output_processor.process_outputs(
55
+ outputs_slice, outputs.timestamp, iteration_stats
56
+ )
57
+ # NOTE: RequestOutputs are pushed to their queues.
58
+ assert not processed_outputs.request_outputs
59
+
60
+ # Allow other asyncio tasks to run between chunks
61
+ if i + 1 < len(slices):
62
+ await asyncio.sleep(0)
63
+
64
+ # 3) Abort any reqs that finished due to stop strings.
65
+ await engine_core.abort_requests_async(processed_outputs.reqs_to_abort)
66
+
67
+ # 4) Logging.
68
+ # TODO(rob): make into a coroutine and launch it in
69
+ # background thread once Prometheus overhead is non-trivial.
70
+ if logger_manager:
71
+ logger_manager.record(
72
+ engine_idx=outputs.engine_index,
73
+ scheduler_stats=outputs.scheduler_stats,
74
+ iteration_stats=iteration_stats,
75
+ )
76
+ except Exception as e:
77
+ logger.exception("AsyncLLM output_handler failed.")
78
+ output_processor.propagate_error(e)
79
+
80
+ from recipe.fault_recover.fault_manager import get_tokens_queue
81
+
82
+ tokens_queue = get_tokens_queue()
83
+
84
+ self.output_handler = asyncio.create_task(output_handler(tokens_queue))
ICL/DAPO/verl-recipe/flash_rl_ascend/README.md ADDED
@@ -0,0 +1,121 @@
1
+ ## Online Weight Quantization
2
+
3
+ This recipe describes how to use the [Flash-RL](https://github.com/yaof20/Flash-RL) tool on Ascend devices: it patches the inference backend and performs online quantization of weights and activations by comparing the INT8 model against the BF16 model. The walkthrough below uses Qwen3-30B INT8 as an example to run the workflow end to end on NPU.
4
+
5
+ ### Environment Dependencies
6
+
7
+
8
+ | PyTorch version | torch_npu version | CANN version | Python version |
9
+ | ------------ |-----------| ---------- | ---------- |
10
+ | 2.7.1 | 2.7.1 | 8.5.0 | Python3.10 |
11
+
12
+
13
+ #### 1. Install vLLM and vllm-ascend
14
+ ```bash
15
+ # vllm==0.10.1
16
+ git clone https://github.com/vllm-project/vllm.git
17
+ cd vllm
18
+ git checkout e03940762b43812fccd3c214bda60201cff9d16a
19
+ pip install -r requirements/build.txt
20
+ VLLM_TARGET_DEVICE=empty pip install -v .
21
+ cd ..
22
+
23
+ # vllm-ascend==0.10.1
24
+ git clone https://github.com/vllm-project/vllm-ascend.git
25
+ cd vllm-ascend
26
+ git checkout 7e16b4a7cdb15723c63c1c0efe58672a056eace8
27
+ pip install -r requirements.txt
28
+ export COMPILE_CUSTOM_KERNELS=1
29
+ python setup.py install
30
+ cd ..
31
+
32
+ # Install transformers from source
33
+ git clone -b v4.57.6 https://github.com/huggingface/transformers.git
34
+ cd transformers
35
+ pip install -e .
36
+ ```
37
+
38
+ #### 2. Install MindSpeed and Megatron
39
+ ```bash
40
+ # MindSpeed
41
+ git clone https://gitcode.com/Ascend/MindSpeed.git
42
+ cd MindSpeed
43
+ git checkout 1cdd0abd75e40936ad31721c092f57c695dd72c4
44
+ pip install -e .
45
+ cd ..
46
+
47
+ # Megatron
48
+ pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
49
+ ```
50
+
51
+ #### 3. Install verl
52
+ ```bash
53
+ # verl
54
+ git clone https://github.com/volcengine/verl.git
55
+ cd verl
56
+ pip install -e .
57
+ cd ..
58
+ ```
59
+
60
+ ### Usage
61
+
62
+ #### 1. Install the package
63
+
64
+ ```
65
+ pip install flash-llm-rl  # must be installed on all nodes for multi-node training
66
+ ```
67
+
68
+ #### 2. Apply the patch
69
+
70
+ After installing FlashRL, automatic patching is used by default. We recommend switching to manual patching to reduce errors along the way:
71
+
72
+ 1. Add `import flash_rl` to `verl/verl/__init__.py` (see the snippet below);
73
+ 2. Add `flashrl cleanup` to your shell script, which disables automatic patching.
74
+
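+ For example, the first edit is a single import at the top of `verl/verl/__init__.py` (an illustrative sketch of the manual patch):
+
+ ```python
+ import flash_rl  # manual patch: importing flash_rl installs its hooks into the rollout backend
+ ```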
75
+ #### 3. Generate the profile file
76
+
77
+ Specifically, the profile compares the bf16 model against the int8 model to determine how to perform online quantization of the updated model:
78
+
79
+ ```
80
+ flashrl profile -m Qwen3-30B-A3B -q Qwen3-30B-A3B-w8a8 -o ${PROFILE_PATH:-"$HOME/profile.30b.pt"} --fn int8
81
+ ```
82
+
83
+ `-m` takes the bf16 model path, `-q` the int8 model path, and `-o` the output file path.
84
+ [RedHatAI](https://huggingface.co/RedHatAI/collections) provides a wide range of quantized models.
85
+
86
+ #### 4. Generate the config file
87
+
88
+ Generate the yaml config file used by the patch program with the following command:
89
+
90
+ ```
91
+ flashrl setup -m Qwen3-30B-A3B-w8a8 -p $HOME/profile.30b.pt --fn int8 -o ${CONFIG_PATH:-"$HOME/.flashrl_config.30b.yaml"}
92
+ ```
93
+
94
+ `-m` takes the int8 model path, `-p` the profile file path, and `-o` the output file path.
95
+
96
+ (Optional) To narrow the gap between rollout generation and gradient computation, FlashRL can mix 16-bit and 8-bit rollout generation across DP workers. Concretely, the following command appends a second configuration to the existing yaml config file.
97
+
98
+ ```
99
+ flashrl setup -a --fn bf16 -o ${CONFIG_PATH:-"$HOME/.flashrl_config.30b.yaml"}
100
+ ```
101
+
102
+ #### 5. Start training
103
+
104
+ Add the following environment variables to your script:
105
+
106
+ ```
107
+ # Print verbose logs to verify that the patch was applied successfully:
108
+ export FLASHRL_LOGGING_LEVEL=DEBUG
109
+ # Specify the config file:
110
+ export FLASHRL_CONFIG=$HOME/.flashrl_config.30b.yaml
111
+ # Keep the lm-head in higher precision (fp32) to reduce precision loss:
112
+ export FLASHRL_LMHEAD_FP32=1
113
+ ```
114
+
115
+ The steps above are demonstrated in `test_qwen3-30b_int8_npu.sh`; update the model paths in the script and it will run automatically. If you hit issues, troubleshoot against the steps above.
116
+
117
+ Fill in the machine IPs and network interfaces in `run.sh`, then launch training with:
118
+
119
+ ```
120
+ bash ./run.sh
121
+ ```
ICL/DAPO/verl-recipe/flowrl/README.md ADDED
@@ -0,0 +1,182 @@
1
+ <h1 align="center" style="color:#1976D2; font-size:42px; font-weight:bold; margin-bottom:0;">
2
+ FlowRL
3
+ </h1>
4
+
5
+ <p align="center" style="color:#42A5F5; font-size:16px; margin-top:0;">
6
+ Matching Reward Distributions via Flow Balance
7
+ </p>
8
+ <p align="center" style="color:#42A5F5; font-size:15px; margin-top:4px;">
9
+ <a href="https://arxiv.org/abs/2509.15207" target="_blank">📄 arXiv Paper</a> |
10
+ <a href="https://huggingface.co/papers/2509.15207" target="_blank">🤗 #1 Paper of the Day</a>
11
+ </p>
12
+ <p align="center" style="color:#42A5F5; font-size:14px; margin-top:4px;">
13
+ <a href="https://x.com/RoverHM/status/1969113890878259518" target="_blank">𝕏 Post 1</a> |
14
+ <a href="https://x.com/zdhnarsil/status/1969049940774023428" target="_blank">𝕏 Post 2</a> |
15
+ <a href="https://x.com/_akhaliq/status/1968901977376505929" target="_blank">𝕏 Post 3</a> |
16
+ <a href="https://x.com/zhu_xuekai/status/1968942580197941563" target="_blank">𝕏 Post 4</a>
17
+ </p>
18
+
19
+ <p align="center">
20
+ <img src="figures/flowrl.png" alt="FlowRL Overview" width="95%"/>
21
+ </p>
22
+
23
+ ## Table of Contents
24
+
25
+ - [FlowRL Objective](#flowrl-objective)
26
+ - [Trained Models & Experiment Logs](#trained-models--experiment-logs)
27
+ - [Quick Start](#quick-start)
28
+ - [Option 1: Original Paper Reproduction (verl 0.4.0)](#option-1-original-paper-reproduction-verl-040--recommended)
29
+ - [Step 1: Installation](#step-1-installation)
30
+ - [Step 2: Data Preparation](#step-2-data-preparation)
31
+ - [Step 3: Model Preparation](#step-3-model-preparation)
32
+ - [Step 4: Training Scripts](#step-4-training-scripts)
33
+ - [Option 2: Latest verl Recipe FlowRL](#option-2-latest-verl-recipe-flowrl)
34
+ - [Step 1: Prepare Data and Model](#step-1-prepare-data-and-model)
35
+ - [Step 2: Run Training](#step-2-run-training)
36
+ - [Option 3: Implement FlowRL Yourself](#option-3-implement-flowrl-yourself)
37
+ - [Testing](#testing)
38
+ - [Citation](#citation)
39
+
40
+ ## FlowRL Objective
41
+
42
+ $$
43
+ \mathcal{L}_{\text{FlowRL}} = w \cdot \left( \log Z_{\phi}(x) + \frac{1}{|y|} \log \pi_{\theta}(y \mid x) - \beta \hat{r}(x, y) - \frac{1}{|y|} \log \pi_{\text{ref}}(y \mid x) \right)^2
44
+ $$
45
+
46
+ FlowRL is a flow-balanced reinforcement learning method that matches full reward distributions instead of maximizing rewards, promoting diverse exploration and generalizable reasoning trajectories in LLMs.
47
+
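+ As a minimal sketch of this objective (our own illustration, not this recipe's actual code; the tensor names below are assumptions):
+
+ ```python
+ import torch
+
+ def flowrl_loss(log_z, logp_sum, ref_logp_sum, resp_len, reward, beta, w=1.0):
+     # log_z:        (B,) learned estimate of log Z_phi(x) for each prompt x
+     # logp_sum:     (B,) sum of policy log-probs over the response tokens of y
+     # ref_logp_sum: (B,) sum of reference-policy log-probs over the same tokens
+     # resp_len:     (B,) response lengths |y| (as a float tensor)
+     # reward:       (B,) shaped rewards r_hat(x, y); beta and w as in the formula
+     residual = log_z + logp_sum / resp_len - beta * reward - ref_logp_sum / resp_len
+     return (w * residual.pow(2)).mean()
+ ```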
48
+ ## Trained Models & Experiment Logs
49
+
50
+ | Base Model | Domain | WandB Logs | Hugging Face Model |
51
+ |-------|--------|------------|-------------------|
52
+ | Qwen2.5-7B | Math | [🔗 View Run](https://wandb.ai/xuekaizhu0/FlowRL/runs/pa62rs4x?nw=nwuserxuekaizhu0) | [🤗 Model](https://huggingface.co/xuekai/FlowRL-Qwen2.5-7B-math) |
53
+ | DeepSeek-7B | Code | [🔗 View Run](https://wandb.ai/xuekaizhu0/FlowRL/runs/wbw72gdv?nw=nwuserxuekaizhu0) | [🤗 Model](https://huggingface.co/xuekai/FlowRL-DeepSeek-7B-code) |
54
+ | Qwen2.5-32B | Math | - | [🤗 Model](https://huggingface.co/xuekai/FlowRL-Qwen2.5-32B-math) |
55
+
56
+ ## Quick Start
57
+
58
+ There are three ways to use FlowRL:
59
+
60
+ ---
61
+
62
+ **⭐ We recommend using Option 1 as the default choice.** Since verl updates frequently, the newest versions can introduce instabilities such as mismatches between training and inference. Option 1 uses verl 0.4.0, which is stable and has been thoroughly validated against our paper results.
63
+
64
+ ---
65
+
66
+ ### Option 1: Original Paper Reproduction (verl 0.4.0) ⭐ Recommended
67
+
68
+ For exact reproduction of results from the paper, use the original repository with verl 0.4.0:
69
+
70
+ 👉 **Original Code:** [https://github.com/Xuekai-Zhu/FlowRL](https://github.com/Xuekai-Zhu/FlowRL)
71
+
72
+ #### Step 1: Installation
73
+
74
+ Install [verl](https://github.com/volcengine/verl) before using FlowRL.
75
+
76
+ #### Step 2: Data Preparation
77
+
78
+ ```bash
79
+ # Option A: Download our pre-processed datasets directly
80
+ bash preprocess/down_load_dataset.sh
81
+ # Move data to default directory
82
+ mv data/xuekai/flowrl-data-collection/math_data data/math_data
83
+ mv data/xuekai/flowrl-data-collection/code_data data/code_data
84
+ ```
85
+
86
+ ```bash
87
+ # Option B: Process data from original sources
88
+ # For detailed processing instructions, see data/README.md
89
+ ```
90
+
91
+ #### Step 3: Model Preparation
92
+
93
+ For Math Tasks: `Qwen/Qwen2.5-7B` (default in script) ; `Qwen/Qwen2.5-32B`
94
+
95
+ For Code Tasks: `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B`
96
+
97
+ ```bash
98
+ # Download default model (Qwen2.5-7B for math)
99
+ bash preprocess/down_load_model.sh
100
+
101
+ # For other models, modify MODEL_NAME in the script before running
102
+ ```
103
+
104
+ #### Step 4: Training Scripts
105
+
106
+ ```bash
107
+ cd verl_FlowRL
108
+
109
+ # For 7B math training
110
+ bash command/training/math/flowrl_7B_math.sh
111
+
112
+ # For 32B math training
113
+ bash command/training/math/flowrl_32B_math.sh
114
+
115
+ # For 7B code training
116
+ bash command/training/code/flowrl_7B_code.sh
117
+ ```
118
+ ----
119
+ ### Option 2: Latest verl Recipe FlowRL
120
+
121
+ For running FlowRL using the latest verl framework:
122
+
123
+ **Latest verl:**
124
+
125
+ - verl recipe: [https://github.com/volcengine/verl/tree/main/recipe/flowrl](https://github.com/volcengine/verl/tree/main/recipe/flowrl)
126
+
127
+ #### Step 1: Prepare Data and Model
128
+
129
+ ```bash
130
+ # Prepare dataset
131
+ bash recipe/flowrl/prepare/prepare_data.sh
132
+
133
+ # Prepare model
134
+ bash recipe/flowrl/prepare/prepare_model.sh
135
+ ```
136
+
137
+ #### Step 2: Run Training
138
+
139
+ ```bash
140
+ # Train FlowRL with Qwen2.5-7B
141
+ bash recipe/flowrl/run_flowrl_qwen2.5_7b.sh
142
+ ```
143
+ ----
144
+ ### Option 3: Implement FlowRL Yourself
145
+
146
+ If you want to implement FlowRL in your own codebase, we provide a detailed implementation guide:
147
+
148
+ 📖 **[FlowRL Implementation Guide](FLOWRL_SIMPLE_GUIDE.md)**
149
+
150
+ This guide walks you through the key components and steps needed to integrate FlowRL into your existing training pipeline.
151
+
152
+ ## Testing
153
+
154
+ After training your FlowRL models, you can evaluate them using the following commands:
155
+
156
+ ```bash
157
+ cd verl_Test
158
+
159
+ # First merge the model
160
+ bash command/eval/merge_model.sh
161
+
162
+ # For math testing
163
+ bash command/eval/math/flowrl_math_test.sh
164
+
165
+ # For code testing
166
+ bash command/eval/code/flowrl_code_test.sh
167
+ ```
168
+
169
+ **Reference:** For the verl v0.5.0.dev model-merge script, see [merge_model.sh](https://github.com/Xuekai-Zhu/verl_FlowRL/blob/flowrl-v0.5.0.dev/recipe/flowrl/eval/merge_model.sh)
170
+
171
+ ## Citation
172
+
173
+ If this repo helps you, please consider citing our paper:
174
+
175
+ ```bibtex
176
+ @article{zhu2025flowrl,
177
+ title={FlowRL: Matching Reward Distributions for LLM Reasoning},
178
+ author={Zhu, Xuekai and Cheng, Daixuan and Zhang, Dinghuai and Li, Hengli and Zhang, Kaiyan and Jiang, Che and Sun, Youbang and Hua, Ermo and Zuo, Yuxin and Lv, Xingtai and others},
179
+ journal={arXiv preprint arXiv:2509.15207},
180
+ year={2025}
181
+ }
182
+ ```
ICL/DAPO/verl-recipe/flowrl/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """FlowRL recipe package."""
16
+
17
+ __all__ = []
ICL/DAPO/verl-recipe/flowrl/flowrl_fsdp_worker.py ADDED
@@ -0,0 +1,495 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """FlowRL FSDP Worker that uses FlowRLActor instead of standard DPActor."""
16
+
17
+ import logging
18
+ import os
19
+ import warnings
20
+
21
+ import torch
22
+ import torch.distributed
23
+ from peft import LoraConfig, TaskType, get_peft_model
24
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
25
+
26
+ try:
27
+ # for torch 2.5+
28
+ from torch.distributed.tensor import DTensor
29
+ except ImportError:
30
+ from torch.distributed._tensor import DTensor
31
+
32
+ from recipe.flowrl.flowrl_actor import FlowRLActor, ProjZModule
33
+
34
+ from verl.models.transformers.monkey_patch import apply_monkey_patch
35
+ from verl.single_controller.base.decorator import Dispatch, register
36
+ from verl.utils import hf_processor, hf_tokenizer
37
+ from verl.utils.activation_offload import enable_activation_offloading
38
+ from verl.utils.config import omega_conf_to_dataclass
39
+ from verl.utils.device import (
40
+ get_device_id,
41
+ get_torch_device,
42
+ set_expandable_segments,
43
+ )
44
+ from verl.utils.fsdp_utils import (
45
+ CPUOffloadPolicy,
46
+ MixedPrecisionPolicy,
47
+ apply_fsdp2,
48
+ collect_lora_params,
49
+ fsdp2_load_full_state_dict,
50
+ get_fsdp_wrap_policy,
51
+ get_init_weight_context_manager,
52
+ get_shard_placement_fn,
53
+ init_fn,
54
+ load_fsdp_model_to_gpu,
55
+ offload_fsdp_model_to_cpu,
56
+ replace_lora_wrapper,
57
+ )
58
+ from verl.utils.memory_utils import aggressive_empty_cache
59
+ from verl.utils.model import convert_weight_keys
60
+ from verl.utils.profiler import log_gpu_memory_usage
61
+ from verl.utils.py_functional import convert_to_regular_types
62
+ from verl.workers.config import FSDPEngineConfig
63
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, get_sharding_strategy, get_vl_model_vision_tower
64
+
65
+ logger = logging.getLogger(__file__)
66
+ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
67
+
68
+
69
+ class FlowRLActorRolloutRefWorker(ActorRolloutRefWorker):
70
+ """
71
+ FlowRL version of ActorRolloutRefWorker that uses FlowRLActor.
72
+
73
+ This worker adds FlowRL-specific modifications:
74
+ - ProjZModule for log Z estimation (added in _build_model_optimizer)
75
+ - FlowRLActor with trajectory balance loss (replaces standard DPActor)
76
+ """
77
+
78
+ def _build_model_optimizer(
79
+ self,
80
+ model_path,
81
+ fsdp_config: FSDPEngineConfig,
82
+ optim_config,
83
+ override_model_config,
84
+ use_remove_padding=False,
85
+ use_fused_kernels=False,
86
+ enable_gradient_checkpointing=False,
87
+ trust_remote_code=False,
88
+ use_liger=False,
89
+ role="actor",
90
+ enable_activation_offload=False,
91
+ ):
92
+ from torch import optim
93
+ from torch.distributed.fsdp import CPUOffload, MixedPrecision
94
+ from transformers import (
95
+ AutoConfig,
96
+ AutoModel,
97
+ AutoModelForCausalLM,
98
+ AutoModelForImageTextToText,
99
+ AutoModelForVision2Seq,
100
+ )
101
+
102
+ from verl.utils.model import get_generation_config, print_model_size, update_model_config
103
+ from verl.utils.torch_dtypes import PrecisionType
104
+
105
+ assert role in ["actor", "ref"]
106
+
107
+ log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
108
+ local_path = model_path
109
+
110
+ # note that we have to create the model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
111
+ # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
112
+ self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
113
+ self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)
114
+
115
+ if self.config.model.get("custom_chat_template", None) is not None:
116
+ if self.processor is not None:
117
+ self.processor.chat_template = self.config.model.custom_chat_template
118
+ else:
119
+ self.tokenizer.chat_template = self.config.model.custom_chat_template
120
+
121
+ vllm_dtype = PrecisionType.to_dtype(self.config.rollout.dtype)
122
+ torch_dtype = fsdp_config.get("model_dtype", None)
123
+ if torch_dtype is None:
124
+ torch_dtype = torch.float32 if self._is_actor else vllm_dtype
125
+ else:
126
+ torch_dtype = PrecisionType.to_dtype(torch_dtype)
127
+
128
+ # override model kwargs
129
+ actor_model_config = AutoConfig.from_pretrained(
130
+ local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
131
+ )
132
+ # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53
133
+ # which will be patched by _ulysses_flash_attention_forward, but erroneously misses position_ids
134
+ # Maybe support Ulysses in VisionAttention in the future and remove this patch
135
+ if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"):
136
+ actor_model_config.vision_config._attn_implementation = "eager"
137
+
138
+ # patch for kimi-vl
139
+ if getattr(actor_model_config, "model_type", None) == "kimi_vl":
140
+ actor_model_config.text_config.topk_method = "greedy"
141
+
142
+ self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)
143
+
144
+ override_config_kwargs = {
145
+ "bos_token_id": self.tokenizer.bos_token_id,
146
+ "eos_token_id": self.tokenizer.eos_token_id,
147
+ "pad_token_id": self.tokenizer.pad_token_id,
148
+ }
149
+ override_config_kwargs.update(override_model_config)
150
+ update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
151
+ if self.rank == 0:
152
+ print(f"Model config after override: {actor_model_config}")
153
+
154
+ # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
155
+ init_context = get_init_weight_context_manager(
156
+ use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
157
+ )
158
+
159
+ with init_context(), warnings.catch_warnings():
160
+ warnings.simplefilter("ignore")
161
+ has_remote_code = hasattr(actor_model_config, "auto_map") and any(
162
+ actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values()
163
+ )
164
+ if has_remote_code:
165
+ auto_class = next(
166
+ k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v
167
+ )
168
+ match auto_class:
169
+ case "AutoModelForVision2Seq":
170
+ actor_module_class = AutoModelForVision2Seq
171
+ case "AutoModelForCausalLM":
172
+ actor_module_class = AutoModelForCausalLM
173
+ case "AutoModelForImageTextToText":
174
+ actor_module_class = AutoModelForImageTextToText
175
+ case _:
176
+ actor_module_class = AutoModel
177
+ else:
178
+ if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
179
+ actor_module_class = AutoModelForVision2Seq
180
+ elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys():
181
+ actor_module_class = AutoModelForCausalLM
182
+ elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys():
183
+ actor_module_class = AutoModelForImageTextToText
184
+ else:
185
+ actor_module_class = AutoModel
186
+
187
+ actor_module = actor_module_class.from_pretrained(
188
+ pretrained_model_name_or_path=local_path,
189
+ torch_dtype=torch_dtype,
190
+ config=actor_model_config,
191
+ trust_remote_code=trust_remote_code,
192
+ )
193
+
194
+ # ==== FlowRL: inject ProjZ BEFORE FSDP wrap ====
195
+ if role == "actor" and self._is_actor:
196
+ n_dim = actor_module.config.hidden_size
197
+ proj_layers = getattr(self.config.actor, "proj_layer", 3)
198
+ actor_module.add_module("proj_z", ProjZModule(n_dim, num_layers=proj_layers))
199
+
200
+ if self.rank == 0:
201
+ print(f"[FlowRL] Added proj_z (layers={proj_layers}, hidden={n_dim}) BEFORE FSDP wrap")
202
+ # ===============================================
203
+
204
+ # Apply Liger kernel to the model if use_liger is set to True
205
+ if use_liger:
206
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance
207
+
208
+ _apply_liger_kernel_to_instance(model=actor_module)
209
+
210
+ fused_kernel_options = self.config.model.get("fused_kernel_options", None)
211
+ fused_kernels_backend = (
212
+ fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None
213
+ )
214
+
215
+ apply_monkey_patch(
216
+ model=actor_module,
217
+ use_remove_padding=use_remove_padding,
218
+ ulysses_sp_size=self.ulysses_sequence_parallel_size,
219
+ use_fused_kernels=use_fused_kernels,
220
+ fused_kernels_backend=fused_kernels_backend,
221
+ )
222
+
223
+ # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
224
+ actor_module.to(torch_dtype)
225
+
226
+ if enable_gradient_checkpointing:
227
+ actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
228
+ if self._is_lora:
229
+ print("Applying LoRA to actor module")
230
+ actor_module.enable_input_require_grads()
231
+ # Convert config to regular Python types before creating PEFT model
232
+ lora_config = {
233
+ "task_type": TaskType.CAUSAL_LM,
234
+ "r": self.config.model.lora_rank,
235
+ "lora_alpha": self.config.model.lora_alpha,
236
+ "target_modules": convert_to_regular_types(self.config.model.target_modules),
237
+ "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules),
238
+ "bias": "none",
239
+ }
240
+ actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
241
+
242
+ self.use_orig_params = fsdp_config.get("use_orig_params", False)
243
+ if self.config.actor.get("freeze_vision_tower", False):
244
+ vision_tower = get_vl_model_vision_tower(actor_module)
245
+ if vision_tower is not None:
246
+ vision_tower.requires_grad_(False)
247
+ self.use_orig_params = True
248
+ if self.rank == 0:
249
+ print("[actor model] Vision tower is set to not trainable.")
250
+ else:
251
+ if self.rank == 0:
252
+ print("[actor model] No vision tower found.")
253
+
254
+ torch.distributed.barrier()
255
+
256
+ if self.rank == 0:
257
+ print_model_size(actor_module)
258
+
259
+ log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)
260
+
261
+ # We wrap FSDP for rollout as well
262
+ mixed_precision_config = fsdp_config.get("mixed_precision", None)
263
+ if mixed_precision_config is not None:
264
+ param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
265
+ reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
266
+ buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
267
+ else:
268
+ param_dtype = PrecisionType.to_dtype(self.config.actor.get("dtype", "bfloat16"))
269
+ reduce_dtype = torch.float32
270
+ buffer_dtype = torch.float32
271
+
272
+ mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
273
+
274
+ auto_wrap_policy = get_fsdp_wrap_policy(
275
+ module=actor_module,
276
+ config=fsdp_config.get("wrap_policy", None),
277
+ is_lora=self.config.model.get("lora_rank", 0) > 0,
278
+ )
279
+
280
+ if self._is_rollout and self.config.rollout.name == "hf":
281
+ # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
282
+ auto_wrap_policy = None
283
+
284
+ if self.rank == 0:
285
+ print(f"wrap_policy: {auto_wrap_policy}")
286
+
287
+ fsdp_mesh = self.device_mesh
288
+ sharding_strategy = get_sharding_strategy(fsdp_mesh)
289
+
290
+ # TODO: add transformer policy
291
+ # We force reference policy to use CPUOffload to save memory.
292
+ # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
293
+ cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
294
+ fsdp_strategy = self.config.actor.strategy
295
+ if fsdp_strategy == "fsdp":
296
+ actor_module_fsdp = FSDP(
297
+ actor_module,
298
+ cpu_offload=cpu_offload,
299
+ param_init_fn=init_fn,
300
+ auto_wrap_policy=auto_wrap_policy,
301
+ device_id=get_device_id(),
302
+ sharding_strategy=sharding_strategy, # zero3
303
+ mixed_precision=mixed_precision,
304
+ sync_module_states=True,
305
+ device_mesh=self.device_mesh,
306
+ use_orig_params=self.use_orig_params,
307
+ forward_prefetch=fsdp_config.get("forward_prefetch", False),
308
+ )
309
+ elif fsdp_strategy == "fsdp2":
310
+ assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
311
+ mp_policy = MixedPrecisionPolicy(
312
+ param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
313
+ )
314
+ if role == "actor" and fsdp_config.offload_policy:
315
+ cpu_offload = CPUOffloadPolicy(pin_memory=True)
316
+ self._is_offload_param = False
317
+ self._is_offload_optimizer = False
318
+ else:
319
+ cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)
320
+
321
+ fsdp_kwargs = {
322
+ "mesh": fsdp_mesh,
323
+ "mp_policy": mp_policy,
324
+ "offload_policy": cpu_offload,
325
+ "reshard_after_forward": fsdp_config.reshard_after_forward,
326
+ "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
327
+ }
328
+ full_state = actor_module.state_dict()
329
+ apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
330
+ fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
331
+ actor_module_fsdp = actor_module
332
+ else:
333
+ raise NotImplementedError(f"not implement {fsdp_strategy}")
334
+
335
+ if enable_activation_offload:
336
+ enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)
337
+
338
+ log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)
339
+
340
+ # TODO: add more optimizer args into config
341
+ if role == "actor" and optim_config is not None:
342
+ from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
343
+
344
+ actor_optimizer = optim.AdamW(
345
+ actor_module_fsdp.parameters(),
346
+ lr=optim_config.lr,
347
+ betas=optim_config.get("betas", (0.9, 0.999)),
348
+ weight_decay=optim_config.get("weight_decay", 1e-2),
349
+ )
350
+
351
+ total_steps = optim_config.get("total_training_steps", 0)
352
+ num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
353
+ warmup_style = optim_config.get("warmup_style", "constant")
354
+ min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
355
+ num_cycles = optim_config.get("num_cycles", 0.5)
356
+ if num_warmup_steps < 0:
357
+ num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
358
+ num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
359
+
360
+ if self.rank == 0:
361
+ print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
362
+
363
+ if warmup_style == "constant":
364
+ actor_lr_scheduler = get_constant_schedule_with_warmup(
365
+ optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
366
+ )
367
+ elif warmup_style == "cosine":
368
+ actor_lr_scheduler = get_cosine_schedule_with_warmup(
369
+ optimizer=actor_optimizer,
370
+ num_warmup_steps=num_warmup_steps,
371
+ num_training_steps=total_steps,
372
+ min_lr_ratio=min_lr_ratio,
373
+ num_cycles=num_cycles,
374
+ )
375
+ else:
376
+ raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
377
+
378
+ log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
379
+ else:
380
+ actor_optimizer = None
381
+ actor_lr_scheduler = None
382
+
383
+ return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
384
+
385
+ @register(dispatch_mode=Dispatch.ONE_TO_ALL)
386
+ def init_model(self):
387
+ """Override init_model to use FlowRLActor instead of DataParallelPPOActor."""
388
+ # Call parent's init_model to set up the FSDP model (with proj_z already added)
389
+ super().init_model()
390
+
391
+ # Replace the actor with FlowRLActor if this worker is an actor
392
+ if self._is_actor:
393
+ if self.rank == 0:
394
+ print("[FlowRL] Replacing DataParallelPPOActor with FlowRLActor")
395
+
396
+ # Convert actor config to dataclass
397
+ actor_cfg = omega_conf_to_dataclass(self.config.actor)
398
+
399
+ # Create FlowRLActor with trajectory balance loss
400
+ self.actor = FlowRLActor(
401
+ config=actor_cfg, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
402
+ )
403
+
404
+ async def rollout_mode(self):
405
+ """
406
+ Override rollout_mode to filter out proj_z parameters before syncing to vLLM.
407
+
408
+ FlowRL's proj_z module is only needed during training for estimating log Z.
409
+ It should not be loaded into vLLM since vLLM is only used for rollout generation.
410
+ """
411
+ aggressive_empty_cache(force_sync=True)
412
+
413
+ log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)
414
+ if self._is_offload_param:
415
+ load_fsdp_model_to_gpu(self.actor_module_fsdp)
416
+ log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)
417
+
418
+ peft_config = None
419
+ peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
420
+ if hasattr(peft_model, "peft_config"): # LoRA
421
+ peft_config = peft_model.peft_config.get("default", None)
422
+ params = collect_lora_params(
423
+ module=self.actor_module_fsdp,
424
+ layered_summon=self.config.rollout.get("layered_summon", False),
425
+ base_sync_done=self.base_sync_done,
426
+ )
427
+ if not self.base_sync_done:
428
+ params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
429
+ else:
430
+ params = self.actor_module_fsdp.state_dict()
431
+
432
+ # ==== FlowRL: Filter out proj_z parameters ====
433
+ params = {k: v for k, v in params.items() if not k.startswith("proj_z")}
434
+ num_proj_z_filtered = len([k for k in self.actor_module_fsdp.state_dict().keys() if k.startswith("proj_z")])
435
+ if num_proj_z_filtered > 0 and self.rank == 0:
436
+ print(f"[FlowRL] Filtered {num_proj_z_filtered} proj_z parameters before syncing to vLLM")
437
+ # ===============================================
438
+
439
+ params = convert_weight_keys(
440
+ params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
441
+ )
442
+
443
+ # Special handling for LoRA with sleep_level=2:
444
+ if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
445
+ base_model_params = collect_lora_params(
446
+ module=self.actor_module_fsdp,
447
+ layered_summon=self.layered_summon,
448
+ base_sync_done=False,
449
+ )
450
+ base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
451
+ # Filter proj_z from base model params as well
452
+ base_model_params = {k: v for k, v in base_model_params.items() if not k.startswith("proj_z")}
453
+ base_model_params = convert_weight_keys(
454
+ base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
455
+ )
456
+
457
+ log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
458
+ if self._is_offload_param:
459
+ offload_fsdp_model_to_cpu(self.actor_module_fsdp)
460
+ log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)
461
+
462
+ set_expandable_segments(False)
463
+
464
+ if peft_config is not None and self.base_sync_done:
465
+ per_tensor_param = params
466
+ else:
467
+ device = get_device_id()
468
+ per_tensor_param = (
469
+ (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
470
+ for name, param in params.items()
471
+ )
472
+
473
+ if self.config.rollout.free_cache_engine:
474
+ await self.rollout.resume(tags=["weights"])
475
+ log_gpu_memory_usage("After resume weights", logger=logger)
476
+
477
+ if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
478
+ per_tensor_base_params = (
479
+ (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
480
+ for name, param in base_model_params.items()
481
+ )
482
+ await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
483
+ del base_model_params, per_tensor_base_params
484
+
485
+ await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
486
+ log_gpu_memory_usage("After update_weights", logger=logger)
487
+ del params, per_tensor_param
488
+ aggressive_empty_cache(force_sync=True)
489
+ if self.config.rollout.free_cache_engine:
490
+ await self.rollout.resume(tags=["kv_cache"])
491
+ log_gpu_memory_usage("After resume kv_cache", logger=logger)
492
+
493
+ self.base_sync_done = True
494
+ self.torch_random_states = get_torch_device().get_rng_state()
495
+ get_torch_device().set_rng_state(self.gen_random_states)
ICL/DAPO/verl-recipe/flowrl/main_flowrl.py ADDED
@@ -0,0 +1,185 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Main training script for FlowRL algorithm."""
17
+
18
+ import os
19
+ import socket
20
+
21
+ import hydra
22
+ import ray
23
+ from omegaconf import OmegaConf
24
+
25
+ from verl.trainer.ppo.reward import load_reward_manager
26
+ from verl.utils.device import is_cuda_available
27
+
28
+
29
+ @hydra.main(config_path="config", config_name="flowrl_trainer", version_base=None)
30
+ def main(config):
31
+ run_flowrl(config)
32
+
33
+
34
+ def run_flowrl(config) -> None:
35
+ if not ray.is_initialized():
36
+ # this is for local ray cluster
37
+ default_runtime_env = {
38
+ "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
39
+ }
40
+ ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
41
+ runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
42
+ runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
43
+ ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
44
+ print(f"ray init kwargs: {ray_init_kwargs}")
45
+ ray.init(**OmegaConf.to_container(ray_init_kwargs))
46
+
47
+ try:
48
+ if (
49
+ is_cuda_available
50
+ and config.global_profiler.tool == "nsys"
51
+ and OmegaConf.select(config.global_profiler, "steps") is not None
52
+ and len(OmegaConf.select(config.global_profiler, "steps")) > 0
53
+ ):
54
+ nsight_options = OmegaConf.to_container(
55
+ config.global_profiler.global_tool_config.nsys.controller_nsight_options
56
+ )
57
+ runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
58
+ else:
59
+ runner = TaskRunner.remote()
60
+ ray.get(runner.run.remote(config))
61
+ finally:
62
+ if ray.is_initialized():
63
+ ray.shutdown()
64
+
65
+
66
+ @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
67
+ class TaskRunner:
68
+ def run(self, config):
69
+ # print initial config
70
+ from pprint import pprint
71
+
72
+ from omegaconf import OmegaConf
73
+
74
+ from verl.utils.fs import copy_to_local
75
+
76
+ print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
77
+
78
+ pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
79
+ OmegaConf.resolve(config)
80
+
81
+ # download the checkpoint from hdfs
82
+ local_path = copy_to_local(config.actor_rollout_ref.model.path)
83
+
84
+ # instantiate tokenizer
85
+ from verl.utils import hf_processor, hf_tokenizer
86
+
87
+ tokenizer = hf_tokenizer(local_path)
88
+ processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLMs; may be None
89
+
90
+ from verl.single_controller.ray import RayWorkerGroup
91
+
92
+ # define worker classes
93
+ if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
94
+ assert config.critic.strategy in {"fsdp", "fsdp2"}
95
+
96
+ # Use FlowRL custom worker instead of standard worker
97
+ from recipe.flowrl.flowrl_fsdp_worker import FlowRLActorRolloutRefWorker
98
+
99
+ from verl.workers.fsdp_workers import CriticWorker # , ActorRolloutRefWorker
100
+
101
+ ActorRolloutRefWorker = FlowRLActorRolloutRefWorker
102
+ ray_worker_group_cls = RayWorkerGroup
103
+
104
+ elif config.actor_rollout_ref.actor.strategy == "megatron":
105
+ assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
106
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
107
+
108
+ ray_worker_group_cls = RayWorkerGroup
109
+
110
+ else:
111
+ raise NotImplementedError
112
+
113
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
114
+
115
+ role_worker_mapping = {
116
+ Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
117
+ Role.Critic: ray.remote(CriticWorker),
118
+ }
119
+
120
+ global_pool_id = "global_pool"
121
+ resource_pool_spec = {
122
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
123
+ }
124
+ mapping = {
125
+ Role.ActorRollout: global_pool_id,
126
+ Role.Critic: global_pool_id,
127
+ }
128
+
129
+ # we should adopt a multi-source reward function here
130
+ # - for rule-based rm, we directly call a reward score
131
+ # - for model-based rm, we call a model
132
+ # - for code related prompt, we send to a sandbox if there are test cases
133
+ # - finally, we combine all the rewards together
134
+ # - The reward type depends on the tag of the data
135
+ if config.reward_model.enable:
136
+ if config.reward_model.strategy in {"fsdp", "fsdp2"}:
137
+ from verl.workers.fsdp_workers import RewardModelWorker
138
+ elif config.reward_model.strategy == "megatron":
139
+ from verl.workers.megatron_workers import RewardModelWorker
140
+ else:
141
+ raise NotImplementedError
142
+ role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
143
+ mapping[Role.RewardModel] = global_pool_id
144
+
145
+ # reference model
146
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
147
+ role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
148
+ mapping[Role.RefPolicy] = global_pool_id
149
+
150
+ reward_fn = load_reward_manager(
151
+ config,
152
+ tokenizer,
153
+ 0,
154
+ max_resp_len=config.data.max_response_length,
155
+ overlong_buffer_cfg=config.reward_model.overlong_buffer,
156
+ )
157
+
158
+ # Note that we always use function-based RM for validation
159
+ val_reward_fn = load_reward_manager(
160
+ config,
161
+ tokenizer,
162
+ 1,
163
+ max_resp_len=config.data.max_response_length,
164
+ overlong_buffer_cfg=config.reward_model.overlong_buffer,
165
+ )
166
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
167
+
168
+ from recipe.flowrl.flowrl_ray_trainer import RayFlowRLTrainer
169
+
170
+ trainer = RayFlowRLTrainer(
171
+ config=config,
172
+ tokenizer=tokenizer,
173
+ processor=processor,
174
+ role_worker_mapping=role_worker_mapping,
175
+ resource_pool_manager=resource_pool_manager,
176
+ ray_worker_group_cls=ray_worker_group_cls,
177
+ reward_fn=reward_fn,
178
+ val_reward_fn=val_reward_fn,
179
+ )
180
+ trainer.init_workers()
181
+ trainer.fit()
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()
ICL/DAPO/verl-recipe/flowrl/run_flowrl_qwen2.5_7b.sh ADDED
@@ -0,0 +1,134 @@
1
+ #!/usr/bin/env bash
2
+ set -xeuo pipefail
3
+
4
+ project_name='FlowRL'
5
+ exp_name='FlowRL-Qwen2.5-7B'
6
+
7
+ # Algorithm settings
8
+ adv_estimator=grpo
9
+
10
+ # KL settings (ref policy needed for FlowRL, but KL penalty disabled)
11
+ use_kl_in_reward=False # KL-in-reward disabled; the ref policy is still enabled via use_kl_loss below (ref_log_prob is needed for the FlowRL loss)
12
+ kl_coef=0.0
13
+ use_kl_loss=True
14
+ kl_loss_coef=0.0
15
+
16
+ # Clip parameters
17
+ clip_ratio_low=0.2
18
+ clip_ratio_high=0.28
19
+
20
+ # Sequence lengths
21
+ max_prompt_length=$((1024 * 2))
22
+ max_response_length=$((1024 * 8))
23
+
24
+ # Overlong buffer for very long responses
25
+ enable_overlong_buffer=True
26
+ overlong_buffer_len=$((1024 * 4))
27
+ overlong_penalty_factor=1.0
28
+
29
+ # Batch sizes
30
+ train_prompt_bsz=512
31
+ gen_prompt_bsz=$((train_prompt_bsz * 3))
32
+ n_resp_per_prompt=8
33
+ train_prompt_mini_bsz=32
34
+
35
+ # Checkpoint saving frequency (-1 to disable periodic saves)
36
+ save_freq=-1
37
+
38
+ # Ray
39
+ RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
40
+ WORKING_DIR=${WORKING_DIR:-"${PWD}"}
41
+ RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
42
+ NNODES=${NNODES:-1}
43
+
44
+ # Paths
45
+ MODEL_PATH=${MODEL_PATH:-"${WORKING_DIR}/downloads/models/Qwen/Qwen2.5-7B"}
46
+ CKPTS_DIR=${CKPTS_DIR:-"${WORKING_DIR}/outputs/ckpts/${project_name}/${exp_name}"}
47
+ TRAIN_FILE=${TRAIN_FILE:-"${WORKING_DIR}/downloads/data/dapo-math-17k.parquet"}
48
+ TEST_FILE=${TEST_FILE:-"${WORKING_DIR}/downloads/data/aime-2024.parquet"}
49
+
50
+ # Sampling
51
+ temperature=1.0
52
+ top_p=1.0
53
+ top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
54
+ val_top_p=0.7
55
+
56
+ # Performance Related Parameter
57
+ n_gpus=8
58
+ sp_size=1
59
+ use_dynamic_bsz=True
60
+ actor_ppo_max_token_len=$((max_prompt_length + max_response_length))
61
+ infer_ppo_max_token_len=$((max_prompt_length + max_response_length))
62
+ offload=False
63
+ gen_tp=1
64
+
65
+
66
+ python3 -m recipe.flowrl.main_flowrl \
67
+ data.train_files="${TRAIN_FILE}" \
68
+ data.val_files="${TEST_FILE}" \
69
+ data.prompt_key=prompt \
70
+ data.truncation='left' \
71
+ data.max_prompt_length=${max_prompt_length} \
72
+ data.max_response_length=${max_response_length} \
73
+ data.gen_batch_size=${gen_prompt_bsz} \
74
+ data.train_batch_size=${train_prompt_bsz} \
75
+ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
76
+ algorithm.adv_estimator=${adv_estimator} \
77
+ algorithm.use_kl_in_reward=${use_kl_in_reward} \
78
+ algorithm.kl_ctrl.kl_coef=${kl_coef} \
79
+ actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
80
+ actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
81
+ actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
82
+ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
83
+ actor_rollout_ref.actor.clip_ratio_c=10.0 \
84
+ actor_rollout_ref.model.use_remove_padding=True \
85
+ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
86
+ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
87
+ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
88
+ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
89
+ actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
90
+ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
91
+ actor_rollout_ref.model.path="${MODEL_PATH}" \
92
+ actor_rollout_ref.model.enable_gradient_checkpointing=True \
93
+ actor_rollout_ref.actor.optim.lr=1e-6 \
94
+ actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
95
+ actor_rollout_ref.actor.optim.warmup_style='constant' \
96
+ actor_rollout_ref.actor.optim.weight_decay=0.1 \
97
+ actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
98
+ actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
99
+ actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
100
+ actor_rollout_ref.actor.entropy_coeff=0 \
101
+ actor_rollout_ref.actor.grad_clip=1.0 \
102
+ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
103
+ actor_rollout_ref.rollout.calculate_log_probs=True \
104
+ actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
105
+ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
106
+ actor_rollout_ref.rollout.enable_chunked_prefill=True \
107
+ actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
108
+ actor_rollout_ref.rollout.temperature=${temperature} \
109
+ actor_rollout_ref.rollout.top_p=${top_p} \
110
+ actor_rollout_ref.rollout.top_k="${top_k}" \
111
+ actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
112
+ actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
113
+ actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
114
+ actor_rollout_ref.rollout.val_kwargs.do_sample=True \
115
+ actor_rollout_ref.rollout.val_kwargs.n=1 \
116
+ actor_rollout_ref.rollout.name=vllm \
117
+ actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
118
+ actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
119
+ actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
120
+ reward_model.reward_manager=dapo \
121
+ reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
122
+ reward_model.overlong_buffer.len=${overlong_buffer_len} \
123
+ reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
124
+ trainer.logger='["console","wandb"]' \
125
+ trainer.project_name="${project_name}" \
126
+ trainer.experiment_name="${exp_name}" \
127
+ trainer.n_gpus_per_node=${n_gpus} \
128
+ trainer.nnodes="${NNODES}" \
129
+ trainer.val_before_train=True \
130
+ trainer.test_freq=10 \
131
+ trainer.save_freq=${save_freq} \
132
+ trainer.total_epochs=1 \
133
+ trainer.default_local_dir="${CKPTS_DIR}" \
134
+ trainer.resume_mode=auto
ICL/DAPO/verl-recipe/infigui-g1/README.md ADDED
@@ -0,0 +1,56 @@
1
+ # Recipe for InfiGUI-G1
2
+
3
+ This directory contains the official implementation for the paper [InfiGUI-G1: Advancing GUI Grounding with Adaptive Exploration Policy Optimization](https://arxiv.org/abs/2508.05731).
4
+
5
+ This work introduces Adaptive Exploration Policy Optimization (AEPO), a policy optimization framework designed to enhance GUI grounding in Multimodal Large Language Models (MLLMs). AEPO improves exploration efficiency by employing a multi-answer generation strategy and a theoretically grounded Adaptive Exploration Reward (AER) function. This approach effectively addresses the challenge of semantic alignment in complex GUI grounding tasks.
6
+
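+ To make the idea concrete, here is a toy sketch of a multi-answer exploration reward (our illustration only; the paper derives the actual AER function theoretically, so consult the paper for the real formula):
+
+ ```python
+ def toy_exploration_reward(is_correct: list[bool], alpha: float = 0.5) -> float:
+     """Reward the first correct candidate among the model's multiple proposed
+     answers, discounted by rank so earlier (higher-confidence) hits score more."""
+     for rank, ok in enumerate(is_correct):
+         if ok:
+             return 1.0 / (1.0 + alpha * rank)
+     return 0.0
+ ```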
7
+ We provide training scripts for both 3B and 7B models, configured for a single machine with 8 GPUs by default.
8
+
9
+ ## Environment Setup
10
+
11
+ Please follow the main environment setup guide for `verl`.
12
+
13
+ The provided scripts use the following Docker image: `verlai/verl:app-verl0.5-transformers4.55.4-sglang0.4.10.post2-mcore0.13.0-te2.2`
14
+
15
+ ## Data Preparation
16
+
17
+ Before starting the training, you need to download the example dataset. This dataset is a filtered version of [omniact](https://huggingface.co/datasets/Writer/omniact), containing only grounding tasks and excluding easy samples.
18
+
19
+ The data is hosted on the Hugging Face Hub. You can download it using the `huggingface-cli`:
20
+
21
+ ```bash
22
+ huggingface-cli download --repo-type dataset --resume-download InfiX-ai/omniact_grounding_filtered --local-dir data/omniact_grounding_filtered
23
+ ```
24
+
25
+ This command will download the training and validation parquet files into the `data/omniact_grounding_filtered` directory, which is the default path used by the scripts.
26
+
27
+ ## Training
28
+
29
+ We provide scripts to train the 3B and 7B models. Please run them from the root directory of `verl`.
30
+
31
+ - **Train the 3B model:**
32
+
33
+ ```bash
34
+ bash recipe/infigui-g1/run_3b.sh
35
+ ```
36
+
37
+ - **Train the 7B model:**
38
+
39
+ ```bash
40
+ bash recipe/infigui-g1/run_7b.sh
41
+ ```
42
+
43
+ ## Using Custom Data
44
+
45
+ If you wish to train on your own dataset, please format your data to match the structure of the example files located in `data/omniact_grounding_filtered`.
46
+
47
+ Once your data is ready, you need to update the data path arguments in the training script.
48
+
49
+ In `run_3b.sh` or `run_7b.sh`, modify the following lines:
50
+
51
+ ```bash
52
+ data.train_files=./path/to/your/train_data.parquet \
53
+ data.val_files=./path/to/your/val_data.parquet \
54
+ ```
55
+
56
+ Replace the paths with the location of your custom data files.
ICL/DAPO/verl-recipe/langgraph_agent/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
ICL/DAPO/verl-recipe/langgraph_agent/chat_model.py ADDED
@@ -0,0 +1,393 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Ref: https://python.langchain.com/docs/how_to/custom_chat_model/
16
+ """
17
+
18
+ import asyncio
19
+ import json
20
+ import logging
21
+ import os
22
+ import uuid
23
+ from typing import Any, Optional
24
+
25
+ from langchain_core.language_models import BaseChatModel
26
+ from langchain_core.language_models.base import LanguageModelInput
27
+ from langchain_core.messages import (
28
+ AIMessage,
29
+ BaseMessage,
30
+ convert_to_openai_messages,
31
+ )
32
+ from langchain_core.messages.tool import InvalidToolCall, ToolCall
33
+ from langchain_core.outputs import ChatGeneration, ChatResult
34
+ from langchain_core.runnables import Runnable, RunnableConfig
35
+ from langchain_core.tools import StructuredTool
36
+ from langchain_core.utils.function_calling import convert_to_openai_tool
37
+ from pydantic import Field
38
+
39
+ from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager
40
+ from verl.experimental.agent_loop.tool_parser import ToolParser
41
+ from verl.experimental.agent_loop.utils import add_generation_prompt_for_gpt_oss, format_gpt_oss_tool_response_manually
42
+
43
+ logger = logging.getLogger(__file__)
44
+ logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
45
+
46
+
47
+ class MaxTokenExceededError(Exception):
48
+ """Indicate that history chat messages + tool message exceeds LLM max_tokens."""
49
+
50
+ pass
51
+
52
+
53
+ class ChatModel(BaseChatModel):
54
+ model_name: str = Field(alias="model")
55
+ """The name of the model"""
56
+
57
+ client: AsyncLLMServerManager
58
+ """AsyncLLM server manager"""
59
+
60
+ tokenizer: Any
61
+ """Tokenizer for the model"""
62
+
63
+ max_tokens: int
64
+ """Max tokens to generate"""
65
+
66
+ tool_parser: str = "hermes"
67
+ """Tool parser for the model"""
68
+
69
+ max_parallel_calls: int = 1
70
+ """Max parallel tool calls"""
71
+
72
+ temperature: float = 1.0
73
+ """Temperature for sampling"""
74
+
75
+ top_p: float = 1.0
76
+ """Top p for sampling"""
77
+
78
+ repetition_penalty: float = 1.0
79
+ """Repetition penalty for sampling"""
80
+
81
+ def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]:
82
+ """Bind tools to the model.
83
+
84
+ Args:
85
+ tools: Sequence of tools to bind to the model.
86
+
87
+ Returns:
88
+ A Runnable that returns a message.
89
+ """
90
+ formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools]
91
+
92
+ # used to remove system prompt prefix when encoding tool response
93
+ system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
94
+ kwargs["system_prompt"] = system_prompt
95
+
96
+ return self.bind(tools=formatted_tools, **kwargs)
97
+
98
+ def with_structured_output(
99
+ self,
100
+ schema: dict | type,
101
+ *,
102
+ include_raw: bool = False,
103
+ **kwargs: Any,
104
+ ) -> Runnable[LanguageModelInput, dict | BaseChatModel]:
105
+ """Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/"""
106
+ raise NotImplementedError
107
+
108
+ def _generate(
109
+ self,
110
+ messages: list[BaseMessage],
111
+ stop: Optional[list[str]] = None,
112
+ **kwargs: Any,
113
+ ) -> ChatResult:
114
+ raise NotImplementedError
115
+
116
+ async def _agenerate(
117
+ self,
118
+ messages: list[BaseMessage],
119
+ stop: Optional[list[str]] = None,
120
+ **kwargs: Any,
121
+ ) -> ChatResult:
122
+ """Asynchronously generate chat completion message.
123
+
124
+ Args:
125
+ messages (list[BaseMessage]): List of list of messages.
126
+ stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the
127
+ first occurrence of any of these substrings. Defaults to None.
128
+
129
+ Returns:
130
+ ChatResult: Chat result.
131
+ """
132
+ request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs)
133
+
134
+ sampling_params = {
135
+ "temperature": self.temperature,
136
+ "top_p": self.top_p,
137
+ "repetition_penalty": self.repetition_penalty,
138
+ }
139
+ if "sampling_params" in kwargs:
140
+ sampling_params.update(kwargs["sampling_params"])
141
+
142
+ output = await self.client.generate(
143
+ request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
144
+ )
145
+
146
+ message = await self._postprocess(request_id, prompt_ids, response_mask, output.token_ids, **kwargs)
147
+ generation = ChatGeneration(message=message)
148
+ return ChatResult(generations=[generation])
149
+
150
+ @property
151
+ def _llm_type(self) -> str:
152
+ """Get the type of language model used by this chat model."""
153
+ return self.model_name
154
+
155
+ async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]:
156
+ """Preprocess messages for chat completion.
157
+
158
+ To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out
159
+ instead of messages list.
160
+
161
+ But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory
162
+ (prompt_ids, response_mask) in lastest AIMessage.response_metadata.
163
+
164
+ 1. Encode ToolMessage to token ids.
165
+ 2. Retrieve trajectory (prompt_ids, response_mask) from lastest AIMessage.response_metadata.
166
+ 3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask.
167
+
168
+ Ref: https://python.langchain.com/docs/concepts/chat_history/
169
+
170
+ Args:
171
+ messages (list[BaseMessage]): List of messages.
172
+
173
+ Returns:
174
+ tuple[str, list[int], list[int]]: Request id, prompt ids, response mask.
175
+ """
176
+ # messages: [system], human, ai, human|tool, ai, human|tool, ...
177
+ assert messages[-1].type in ["human", "tool"], (
178
+ f"Last message must be human or tool, but got {messages[-1].type}"
179
+ )
180
+ loop = asyncio.get_running_loop()
181
+
182
+ # Case 1: initial chat completion: [system], human
183
+ if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"):
184
+ prompt_ids = await loop.run_in_executor(
185
+ None,
186
+ lambda: self.tokenizer.apply_chat_template(
187
+ convert_to_openai_messages(messages),
188
+ tools=kwargs.get("tools"),
189
+ add_generation_prompt=True,
190
+ tokenize=True,
191
+ ),
192
+ )
193
+ return str(uuid.uuid4()), prompt_ids, []
194
+
195
+ # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ...
196
+ for i in range(len(messages) - 1, -1, -1):
197
+ if messages[i].type == "ai":
198
+ break
199
+ assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata"
200
+ assert "response_mask" in messages[i].response_metadata, (
201
+ "Last message must have response_mask in response_metadata"
202
+ )
203
+
204
+ # encode tool response
205
+ tool_responses = convert_to_openai_messages(messages[i + 1 :])
206
+ if self.tool_parser == "hermes":
207
+ tool_response_ids = await loop.run_in_executor(
208
+ None,
209
+ lambda messages=tool_responses: self.tokenizer.apply_chat_template(
210
+ messages, add_generation_prompt=True, tokenize=True
211
+ ),
212
+ )
213
+ tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :]
214
+ elif self.tool_parser == "gpt-oss":
215
+ # Format tool responses manually
216
+ # since gpt-oss chat template requires tool call messages to parse tool response messages
217
+ # we need to format the tool response messages manually
218
+ tool_response_texts = []
219
+ for tool_msg in tool_responses:
220
+ if tool_msg["role"] == "tool":
221
+ # Use tool message's name if available (for multiple tool calls)
222
+ actual_tool_name = tool_msg.get("name", "unknown")
223
+ if actual_tool_name == "unknown":
224
+ logger.error(f"actual_tool_name: {actual_tool_name}")
225
+ formatted = format_gpt_oss_tool_response_manually(tool_msg["content"], actual_tool_name)
226
+ tool_response_texts.append(formatted)
227
+
228
+ # Tokenize the manually formatted tool responses
229
+ tool_response_text = "".join(tool_response_texts)
230
+ # need to add generation tokens for gpt-oss manually since add_generation_prompt is True
231
+ tool_response_text = add_generation_prompt_for_gpt_oss(tool_response_text)
232
+ logger.debug(f"tool_response_text: {tool_response_text}")
233
+
234
+ tool_response_ids = await loop.run_in_executor(
235
+ None, lambda: self.tokenizer.encode(tool_response_text, add_special_tokens=False)
236
+ )
237
+ else:
238
+ raise ValueError(f"Unsupported tool parser: {self.tool_parser}")
239
+
240
+ # stop generation if response length exceeds max response length
241
+ if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens:
242
+ raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded")
243
+
244
+ # append tool response to prompt
245
+ request_id = messages[i].response_metadata.pop("request_id")
246
+ prompt_ids = messages[i].response_metadata.pop("prompt_ids")
247
+ response_mask = messages[i].response_metadata.pop("response_mask")
248
+ prompt_ids += tool_response_ids
249
+ response_mask += [0] * len(tool_response_ids)
250
+
251
+ return request_id, prompt_ids, response_mask
252
+
253
+ async def _postprocess(
254
+ self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any
255
+ ) -> AIMessage:
256
+ """Postprocess response_ids when chat completion is done.
257
+
258
+ 1. Decode response_ids, parse tool calls to AIMessage.
259
+ 2. Append response_ids to prompt_ids, and append 1 to response_mask.
260
+ 3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata.
261
+
262
+ Args:
263
+ request_id (str): Unique request id.
264
+ prompt_ids (list[int]): Input prompt token ids in this chat completion.
265
+ response_mask (list[int]): Response mask before this chat completion.
266
+ response_ids (list[int]): LLM generated token ids in this chat completion.
267
+
268
+ Returns:
269
+ AIMessage: Postprocessed message.
270
+ """
271
+ prompt_ids += response_ids
272
+ response_mask += [1] * len(response_ids)
273
+
274
+ tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer)
275
+ content, function_calls = await tool_parser.extract_tool_calls(response_ids)
276
+
277
+ tool_calls, invalid_tool_calls = [], []
278
+
279
+ for function_call in function_calls:
280
+ error = None
281
+ try:
282
+ args = json.loads(function_call.arguments)
283
+ if not isinstance(args, dict):
284
+ error = f"Tool arguments must be a JSON object, got {type(args).__name__}"
285
+ except json.JSONDecodeError as e:
286
+ error = f"Invalid JSON tool arguments: {e}"
287
+
288
+ if error:
289
+ logger.warning(error)
290
+ invalid_tool_calls.append(
291
+ InvalidToolCall(
292
+ name=function_call.name,
293
+ args=function_call.arguments,
294
+ id=str(uuid.uuid4()),
295
+ error=error,
296
+ )
297
+ )
298
+ else:
299
+ tool_calls.append(
300
+ ToolCall(
301
+ name=function_call.name,
302
+ args=args,
303
+ id=str(uuid.uuid4()),
304
+ )
305
+ )
306
+
307
+ message = AIMessage(
308
+ content=content,
309
+ tool_calls=tool_calls[: self.max_parallel_calls],
310
+ invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls],
311
+ response_metadata={
312
+ "request_id": request_id,
313
+ "prompt_ids": prompt_ids,
314
+ "response_mask": response_mask,
315
+ },
316
+ )
317
+ return message
318
+
319
+
320
+ class TruncateStructuredTool(StructuredTool):
321
+ """Structured tool with response truncation."""
322
+
323
+ tool_response_truncate_side: str
324
+ """truncate side of tool response: left, middle, right"""
325
+
326
+ max_tool_response_length: int
327
+ """max length of tool response"""
328
+
329
+ async def _arun(
330
+ self,
331
+ *args: Any,
332
+ config: RunnableConfig,
333
+ **kwargs: Any,
334
+ ) -> Any:
335
+ tool_response = await super()._arun(*args, config=config, **kwargs)
336
+ tool_response = str(tool_response)
337
+
338
+ if len(tool_response) > self.max_tool_response_length:
339
+ if self.tool_response_truncate_side == "left":
340
+ tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
341
+ elif self.tool_response_truncate_side == "right":
342
+ tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
343
+ else:
344
+ length = self.max_tool_response_length // 2
345
+ tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]
346
+
347
+ return tool_response
348
+
349
+
350
+ def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput:
351
+ """Convert messages to AgentLoopOutput.
352
+
353
+ Args:
354
+ messages (List[BaseMessage]): List of messages, last message must be assistant
355
+ with response_metadata containing `prompt_ids` and `response_mask`.
356
+ response_length (int): Max length of response.
357
+
358
+ Returns:
359
+ AgentLoopOutput: agent loop output trajectory used for training.
360
+ """
361
+ # skip last tool calls
362
+ for i in range(len(messages) - 1, -1, -1):
363
+ if messages[i].type != "tool":
364
+ break
365
+ last_message = messages[i]
366
+ assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}"
367
+ assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata"
368
+ assert "response_mask" in last_message.response_metadata, (
369
+ "Last message must have response_mask in response_metadata"
370
+ )
371
+
372
+ num_turns = 0
373
+ for i in range(len(messages)):
374
+ if messages[i].type == "system":
375
+ continue
376
+ # parallel tool calls are in single turn
377
+ if i == 0 or messages[i].type != messages[i - 1].type:
378
+ num_turns += 1
379
+
380
+ prompt_ids = last_message.response_metadata["prompt_ids"]
381
+ response_mask = last_message.response_metadata["response_mask"]
382
+
383
+ response_ids = prompt_ids[-len(response_mask) :]
384
+ prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
385
+
386
+ output = AgentLoopOutput(
387
+ prompt_ids=prompt_ids,
388
+ response_ids=response_ids[:response_length],
389
+ response_mask=response_mask[:response_length],
390
+ num_turns=num_turns,
391
+ metrics={},
392
+ )
393
+ return output
ICL/DAPO/verl-recipe/langgraph_agent/react_agent_loop.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ LangGraph React Agent Loop.
16
+
17
+ This implementation is exact same as `ToolAgentLoop`.
18
+
19
+ Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/
20
+ """
21
+
22
+ import logging
23
+ from typing import Any, Literal
24
+
25
+ from langchain_core.messages import AIMessage
26
+ from langchain_core.runnables import RunnableConfig
27
+ from langgraph.graph import END, MessagesState, StateGraph
28
+ from langgraph.prebuilt import ToolNode
29
+ from recipe.langgraph_agent.chat_model import (
30
+ ChatModel,
31
+ MaxTokenExceededError,
32
+ convert_to_agent_output,
33
+ )
34
+
35
+ from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
+ async def call_model(state: MessagesState, config: RunnableConfig):
41
+ model = config["configurable"]["model"]
42
+ sampling_params = config["configurable"]["sampling_params"]
43
+ try:
44
+ message = await model.ainvoke(state["messages"], sampling_params=sampling_params)
45
+ return {"messages": [message]}
46
+ except MaxTokenExceededError:
47
+ # last message is ToolMessage
48
+ return {"messages": []}
49
+
50
+
51
+ def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]:
52
+ # Safely extract max_assistant_turns from config
53
+ max_assistant_turns = None
54
+ try:
55
+ if config and "configurable" in config:
56
+ max_assistant_turns = config["configurable"].get("max_assistant_turns")
57
+ except Exception as e:
58
+ logger.warning(f"Failed to extract max_assistant_turns from config: {e}")
59
+
60
+ num_assistant_turns = 0
61
+ for message in state["messages"]:
62
+ if message.type == "ai":
63
+ num_assistant_turns += 1
64
+
65
+ last_message = state["messages"][-1]
66
+
67
+ # LLM call failed, e.g: max response length exceeded
68
+ if last_message.type == "tool":
69
+ return END
70
+
71
+ # max assistant turns exceeded
72
+ # Use a reasonable default limit (25) if max_assistant_turns is not set
73
+ # This prevents infinite loops
74
+ effective_max_turns = max_assistant_turns if max_assistant_turns is not None else 25
75
+ if num_assistant_turns >= effective_max_turns:
76
+ return END
77
+
78
+ # no tool calls
79
+ if not getattr(last_message, "tool_calls", None):
80
+ return END
81
+
82
+ return "tools"
83
+
84
+
85
+ class ReactAgentLoop(AgentLoopBase):
86
+ # Recursion limit calculation constants
87
+ DEFAULT_MAX_ASSISTANT_TURNS = 25
88
+ MIN_RECURSION_LIMIT = 50
89
+ NODES_PER_TURN = 2 # Each AI turn involves agent + tools nodes
90
+ RECURSION_LIMIT_SAFETY_FACTOR = 1.5 # 50% buffer for edge cases
91
+
92
+ @classmethod
93
+ def init_class(cls, config, tokenizer, **kwargs):
94
+ if cls._class_initialized:
95
+ return
96
+ cls._class_initialized = True
97
+ print("Performing class-level ReactAgentLoop initialization")
98
+
99
+ # build graph
100
+ cls.graph = cls.build_graph()
101
+
102
+ @classmethod
103
+ def build_graph(cls) -> StateGraph:
104
+ workflow = StateGraph(MessagesState)
105
+
106
+ workflow.add_node("agent", call_model)
107
+ workflow.add_node("tools", ToolNode(cls.tools))
108
+ workflow.set_entry_point("agent")
109
+ workflow.add_conditional_edges(
110
+ "agent",
111
+ should_continue,
112
+ {
113
+ "tools": "tools",
114
+ END: END,
115
+ },
116
+ )
117
+
118
+ workflow.add_edge("tools", "agent")
119
+ graph = workflow.compile()
120
+ return graph
121
+
122
+ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
123
+ messages = list(kwargs["raw_prompt"])
124
+
125
+ model_path = self.config.actor_rollout_ref.model.path
126
+ model_name = "/".join(model_path.split("/")[-2:])
127
+
128
+ rollout = self.config.actor_rollout_ref.rollout
129
+ model = ChatModel(
130
+ model=model_name,
131
+ client=self.server_manager,
132
+ tokenizer=self.tokenizer,
133
+ max_tokens=rollout.response_length,
134
+ max_parallel_calls=rollout.multi_turn.max_parallel_calls,
135
+ tool_parser=rollout.multi_turn.format,
136
+ )
137
+
138
+ model = model.bind_tools(self.tools, tool_choice="any")
139
+
140
+ # Calculate recursion_limit dynamically based on max_assistant_turns
141
+ max_assistant_turns = (
142
+ rollout.multi_turn.max_assistant_turns
143
+ if rollout.multi_turn.max_assistant_turns
144
+ else self.DEFAULT_MAX_ASSISTANT_TURNS
145
+ )
146
+
147
+ # Formula: nodes_per_turn * max_turns * safety_buffer, with minimum threshold
148
+ recursion_limit = max(
149
+ self.MIN_RECURSION_LIMIT,
150
+ int(max_assistant_turns * self.NODES_PER_TURN * self.RECURSION_LIMIT_SAFETY_FACTOR),
151
+ )
152
+ logger.info(f"Configured recursion_limit={recursion_limit} (max_assistant_turns={max_assistant_turns})")
153
+
154
+ config = {
155
+ "configurable": {
156
+ "model": model,
157
+ "sampling_params": sampling_params,
158
+ "max_user_turns": rollout.multi_turn.max_user_turns,
159
+ "max_assistant_turns": rollout.multi_turn.max_assistant_turns,
160
+ },
161
+ "recursion_limit": recursion_limit,
162
+ }
163
+
164
+ # TODO: how to handle multiple trajectories in an graph invocation?
165
+ # Each graph node may has its own LLM calls and state, e.g:
166
+ # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart
167
+ try:
168
+ state = await self.graph.ainvoke(input={"messages": messages}, config=config)
169
+ except Exception as e:
170
+ logger.error(f"Agent loop execution failed: {type(e).__name__}: {e}")
171
+ logger.error("Falling back to a minimal dummy trajectory.")
172
+
173
+ # Fallback to a minimal assistant message so that
174
+ # convert_to_agent_output and downstream padding logic
175
+ # can still run without crashing.
176
+ dummy_id = 0
177
+ fallback_message = AIMessage(
178
+ content="[Agent execution failed - no valid trajectory]",
179
+ response_metadata={
180
+ "request_id": "fallback",
181
+ "prompt_ids": [dummy_id, dummy_id],
182
+ "response_mask": [1],
183
+ },
184
+ )
185
+ state = {"messages": [fallback_message]}
186
+
187
+ output = convert_to_agent_output(state["messages"], rollout.response_length)
188
+ return output
ICL/DAPO/verl-recipe/langgraph_agent/test_react_agent_loop.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+
17
+ import numpy as np
18
+ import pytest
19
+ import ray
20
+ from langchain_core.tools import tool
21
+ from omegaconf import DictConfig
22
+ from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop
23
+ from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
24
+
25
+ from verl.protocol import DataProto
26
+ from verl.utils import hf_tokenizer
27
+
28
+
29
+ @pytest.fixture
30
+ def init_config() -> DictConfig:
31
+ from hydra import compose, initialize_config_dir
32
+
33
+ with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
34
+ config = compose(config_name="ppo_trainer")
35
+ model_path = "Qwen/Qwen2.5-1.5B-Instruct"
36
+ config.actor_rollout_ref.model.path = model_path
37
+ config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
38
+ config.actor_rollout_ref.rollout.mode = "async"
39
+ config.actor_rollout_ref.rollout.prompt_length = 4096
40
+ config.actor_rollout_ref.rollout.response_length = 4096
41
+ config.actor_rollout_ref.rollout.n = 4
42
+ config.actor_rollout_ref.rollout.agent.num_workers = 2
43
+
44
+ config.actor_rollout_ref.actor.use_dynamic_bsz = True
45
+ # test sleep/wake_up with fsdp offload
46
+ config.actor_rollout_ref.actor.fsdp_config.param_offload = True
47
+ config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
48
+
49
+ return config
50
+
51
+
52
+ @tool(parse_docstring=True)
53
+ def get_current_temperature(location: str, unit: str = "celsius"):
54
+ """Get current temperature at a location.
55
+
56
+ Args:
57
+ location: The location to get the temperature for, in the format "City, State, Country".
58
+ unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
59
+
60
+ Returns:
61
+ the temperature, the location, and the unit in a dict
62
+ """
63
+ print(f"[DEBUG] get_current_temperature: {location}, {unit}")
64
+ return {
65
+ "temperature": 26.1,
66
+ "location": location,
67
+ "unit": unit,
68
+ }
69
+
70
+
71
+ @tool(parse_docstring=True)
72
+ def get_temperature_date(location: str, date: str, unit: str = "celsius"):
73
+ """Get temperature at a location and date.
74
+
75
+ Args:
76
+ location: The location to get the temperature for, in the format "City, State, Country".
77
+ date: The date to get the temperature for, in the format "Year-Month-Day".
78
+ unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"])
79
+
80
+ Returns:
81
+ the temperature, the location, the date and the unit in a dict
82
+ """
83
+ print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}")
84
+ return {
85
+ "temperature": 25.9,
86
+ "location": location,
87
+ "date": date,
88
+ "unit": unit,
89
+ }
90
+
91
+
92
+ class TestReactAgentLoop(ReactAgentLoop):
93
+ @classmethod
94
+ def init_class(cls, config, tokenizer, **kwargs):
95
+ # TODO: find better way to configure tools
96
+ cls.tools = [get_current_temperature, get_temperature_date]
97
+ super().init_class(config, tokenizer, **kwargs)
98
+
99
+
100
+ def test_react_agent(init_config):
101
+ ray.init(
102
+ runtime_env={
103
+ "env_vars": {
104
+ "TOKENIZERS_PARALLELISM": "true",
105
+ "NCCL_DEBUG": "WARN",
106
+ "VLLM_LOGGING_LEVEL": "INFO",
107
+ "VLLM_USE_V1": "1",
108
+ }
109
+ }
110
+ )
111
+
112
+ # =========================== 1. Init rollout manager ===========================
113
+ agent_loop_config = [
114
+ {
115
+ "_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop",
116
+ "name": "react_agent",
117
+ },
118
+ ]
119
+ agent_loop_config_path = "/tmp/agent_loop_config.json"
120
+ with open(agent_loop_config_path, "w") as f:
121
+ json.dump(agent_loop_config, f)
122
+
123
+ n = 2
124
+ init_config.actor_rollout_ref.rollout.n = n
125
+ # init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path
126
+ init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2
127
+ init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path
128
+ agent_loop_manager = init_agent_loop_manager(init_config)
129
+
130
+ # =========================== 2. Generate sequences ===========================
131
+ raw_prompts = [
132
+ [
133
+ {"role": "user", "content": "How are you?"},
134
+ ],
135
+ [
136
+ {"role": "user", "content": "What's the temperature in Los Angeles now?"},
137
+ ],
138
+ [
139
+ {"role": "user", "content": "What's the temperature in New York now?"},
140
+ ],
141
+ [
142
+ {
143
+ "role": "system",
144
+ "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n"
145
+ "Current Date: 2024-09-30",
146
+ },
147
+ {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"},
148
+ ],
149
+ ]
150
+ batch = DataProto(
151
+ non_tensor_batch={
152
+ "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
153
+ "agent_name": np.array(["react_agent"] * len(raw_prompts)),
154
+ "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
155
+ "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
156
+ },
157
+ )
158
+ batch = batch.repeat(n)
159
+ result = agent_loop_manager.generate_sequences(prompts=batch)
160
+ assert len(result) == len(raw_prompts) * n
161
+
162
+ # Check turns
163
+ num_turns = result.non_tensor_batch["__num_turns__"]
164
+ print(f"num_turns: {num_turns}")
165
+ for i in range(len(num_turns)):
166
+ if i // n == 0:
167
+ # [user, assistant]
168
+ assert num_turns[i] == 2
169
+ else:
170
+ # [user, assistant, tool, assistant]
171
+ assert num_turns[i] == 4
172
+
173
+ # Check response_mask
174
+ tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path)
175
+ responses = result.batch["responses"]
176
+ response_mask = result.batch["response_mask"]
177
+ attention_mask = result.batch["attention_mask"]
178
+ assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"
179
+ response_length = response_mask.size(1)
180
+
181
+ for i in range(len(responses)):
182
+ # response with tool response
183
+ valid_tokens = responses[i][attention_mask[i][-response_length:].bool()]
184
+ response_with_obs = tokenizer.decode(valid_tokens)
185
+
186
+ # response without tool response
187
+ valid_tokens = responses[i][response_mask[i].bool()]
188
+ response_without_obs = tokenizer.decode(valid_tokens)
189
+
190
+ assert "<tool_response>" not in response_without_obs, (
191
+ f"found <tool_response> in response: {response_without_obs}"
192
+ )
193
+ assert "</tool_response>" not in response_without_obs, (
194
+ f"found </tool_response> in response: {response_without_obs}"
195
+ )
196
+ print("=========================")
197
+ print(response_with_obs)
198
+ print("---")
199
+ print(response_without_obs)
200
+
201
+ print("Test passed!")
202
+ ray.shutdown()
ICL/DAPO/verl-recipe/minicpmo/rl_dataset.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023-2024 SGLang Team
3
+ # Copyright 2025 ModelBest Inc. and/or its affiliates
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import copy
18
+ import logging
19
+ import math
20
+ import os
21
+ import re
22
+ from typing import Optional
23
+
24
+ import datasets
25
+ import torch
26
+ from omegaconf import DictConfig, ListConfig
27
+ from PIL import Image
28
+ from torch.utils.data import Dataset
29
+ from torchvision import transforms
30
+ from transformers import PreTrainedTokenizer, ProcessorMixin
31
+
32
+ import verl.utils.torch_functional as verl_F
33
+ from verl.utils.dataset.vision_utils import process_image
34
+ from verl.utils.model import compute_position_id_with_mask
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ def build_transform():
40
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
41
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
42
+ return transforms.Compose(
43
+ [
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
46
+ ]
47
+ )
48
+
49
+
50
+ def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None):
51
+ if new_schema:
52
+ start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id)
53
+ end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id)
54
+ else:
55
+ start_cond = input_ids == tokenizer.im_start_id
56
+ end_cond = input_ids == tokenizer.im_end_id
57
+ image_start_tokens = torch.where(start_cond)[0]
58
+ image_start_tokens += 1
59
+ image_end_tokens = torch.where(end_cond)[0]
60
+ if len(image_start_tokens) != len(image_end_tokens):
61
+ logger.error("image start token != image end tokens")
62
+ raise Exception("image start token != image end tokens")
63
+ if len(image_start_tokens) > 0:
64
+ image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
65
+ else:
66
+ image_bound = []
67
+ return image_bound
68
+
69
+
70
+ def preprocess(
71
+ images_dict,
72
+ conversations,
73
+ tokenizer,
74
+ transform,
75
+ query_nums=64,
76
+ slice_config=None,
77
+ llm_type=None,
78
+ patch_size=14,
79
+ batch_vision=False,
80
+ max_length=2048,
81
+ truncation="error",
82
+ apply_chat_template_kwargs=None,
83
+ logger=None,
84
+ ):
85
+ """
86
+ single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation
87
+ """
88
+ conversations = copy.deepcopy(conversations)
89
+ assert conversations[0]["role"] == "user", "the first role must be user"
90
+
91
+ if slice_config is not None:
92
+ assert isinstance(slice_config, dict)
93
+ assert "patch_size" in slice_config
94
+ assert "max_slice_nums" in slice_config
95
+ assert "scale_resolution" in slice_config
96
+ default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
97
+ new_schema = False
98
+ use_image_id = False
99
+ if llm_type == "qwen":
100
+ new_schema = True
101
+ use_image_id = True
102
+ image_placeholder_dict = {}
103
+ images = []
104
+ image_id_cnt = 0
105
+ for img_name, image in images_dict.items():
106
+ if slice_config:
107
+ source_image, patches, best_grid = slice_image(
108
+ image,
109
+ slice_config["max_slice_nums"],
110
+ slice_config["scale_resolution"],
111
+ slice_config["patch_size"],
112
+ )
113
+ images.append(source_image)
114
+ image_placeholder = default_image_placeholder
115
+ if len(patches) > 0:
116
+ for i in range(len(patches)):
117
+ for j in range(len(patches[0])):
118
+ images.append(patches[i][j])
119
+ if use_image_id:
120
+ image_placeholder = (
121
+ f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
122
+ )
123
+ image_id_cnt += 1
124
+ image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema)
125
+ image_placeholder_dict[img_name] = image_placeholder
126
+ else:
127
+ images.append(image)
128
+ if use_image_id:
129
+ image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder
130
+ image_id_cnt += 1
131
+ else:
132
+ image_placeholder = default_image_placeholder
133
+ image_placeholder_dict[img_name] = image_placeholder
134
+
135
+ images = [transform(i) for i in images]
136
+
137
+ if len(images_dict) == 1 and "<image>" in images_dict:
138
+ if "<image>" in conversations[0]["content"]:
139
+ conversations[0]["content"] = conversations[0]["content"].replace("<image>", image_placeholder)
140
+ else:
141
+ conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"]
142
+ else:
143
+ pattern = r"<image_\d+>"
144
+ new_conversations = []
145
+ for conversation in conversations:
146
+ content = conversation["content"]
147
+ parts = re.split(f"({pattern})", content)
148
+ for i, part in enumerate(parts):
149
+ if not part.strip():
150
+ continue
151
+ if re.match(pattern, part):
152
+ if part in image_placeholder_dict:
153
+ parts[i] = image_placeholder_dict[part]
154
+ else:
155
+ raise Exception(f"not found {part} in image dict")
156
+ conversation["content"] = "\n".join(parts)
157
+ new_conversations.append(conversation)
158
+ conversations = new_conversations
159
+
160
+ # TODO change role in conversation for different llm
161
+ prompt_with_chat_template = tokenizer.apply_chat_template(
162
+ conversations, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {})
163
+ )
164
+
165
+ input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
166
+ prompt=prompt_with_chat_template,
167
+ tokenizer=tokenizer,
168
+ max_length=max_length,
169
+ pad_token_id=tokenizer.pad_token_id,
170
+ left_pad=True,
171
+ truncation=truncation,
172
+ )
173
+ position_ids = compute_position_id_with_mask(attention_mask)
174
+ image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger)
175
+
176
+ input_dict = {
177
+ "input_ids": input_ids[0],
178
+ "attention_mask": attention_mask[0],
179
+ "position_ids": position_ids[0],
180
+ "image_bound": image_bound,
181
+ }
182
+
183
+ if batch_vision:
184
+ tgt_sizes = []
185
+ reshape_images = []
186
+ for image in images:
187
+ H, W = image.shape[1:]
188
+ reshape_image = reshape_by_patch(image, patch_size)
189
+ reshape_images.append(reshape_image)
190
+ tgt_sizes.append([H // patch_size, W // patch_size])
191
+ if tgt_sizes:
192
+ tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
193
+
194
+ input_dict["pixel_values"] = reshape_images
195
+ input_dict["tgt_sizes"] = tgt_sizes
196
+
197
+ else:
198
+ input_dict["pixel_values"] = images
199
+ input_dict["tgt_sizes"] = []
200
+
201
+ return input_dict
202
+
203
+
204
+ def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
205
+ original_size = image.size
206
+ original_width, original_height = original_size
207
+ log_ratio = math.log(original_width / original_height)
208
+ ratio = original_width * original_height / (scale_resolution * scale_resolution)
209
+ multiple = min(math.ceil(ratio), max_slice_nums)
210
+
211
+ source_image = None
212
+ best_grid = None
213
+ patches = []
214
+
215
+ if multiple <= 1 or never_split:
216
+ # dont need to slice, upsample
217
+ best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
218
+ source_image = image.resize(best_size, Image.Resampling.BICUBIC)
219
+ else:
220
+ candidate_split_grids_nums = []
221
+ for i in [multiple - 1, multiple, multiple + 1]:
222
+ if i == 1 or i > max_slice_nums:
223
+ continue
224
+ candidate_split_grids_nums.append(i)
225
+
226
+ # source image, down-sampling and ensure divided by patch_size
227
+ best_resize = find_best_resize(original_size, scale_resolution, patch_size)
228
+ source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
229
+ candidate_grids = []
230
+
231
+ # find best grid
232
+ for split_grids_nums in candidate_split_grids_nums:
233
+ m = 1
234
+ while m <= split_grids_nums:
235
+ if split_grids_nums % m == 0:
236
+ candidate_grids.append([m, split_grids_nums // m])
237
+ m += 1
238
+
239
+ best_grid = [1, 1]
240
+ min_error = float("inf")
241
+ for grid in candidate_grids:
242
+ error = abs(log_ratio - math.log(grid[0] / grid[1]))
243
+ if error < min_error:
244
+ best_grid = grid
245
+ min_error = error
246
+
247
+ refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True)
248
+
249
+ refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
250
+ patches = split_to_patches(refine_image, best_grid)
251
+
252
+ return source_image, patches, best_grid
253
+
254
+
255
+ def ensure_divide(length, patch_size):
256
+ return max(round(length / patch_size) * patch_size, patch_size)
257
+
258
+
259
+ def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
260
+ width, height = original_size
261
+ if (width * height > scale_resolution * scale_resolution) or allow_upscale:
262
+ r = width / height
263
+ height = int(scale_resolution / math.sqrt(r))
264
+ width = int(height * r)
265
+ best_width = ensure_divide(width, patch_size)
266
+ best_height = ensure_divide(height, patch_size)
267
+ return (best_width, best_height)
268
+
269
+
270
+ def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False):
271
+ width, height = original_size
272
+ grid_x, grid_y = grid
273
+
274
+ refine_width = ensure_divide(width, grid_x)
275
+ refine_height = ensure_divide(height, grid_y)
276
+
277
+ grid_width = refine_width / grid_x
278
+ grid_height = refine_height / grid_y
279
+
280
+ best_grid_size = find_best_resize(
281
+ (grid_width, grid_height),
282
+ scale_resolution,
283
+ patch_size,
284
+ allow_upscale=allow_upscale,
285
+ )
286
+
287
+ refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
288
+
289
+ return refine_size
290
+
291
+
292
+ def split_to_patches(image, grid):
293
+ patches = []
294
+ width, height = image.size
295
+ grid_x = int(width / grid[0])
296
+ grid_y = int(height / grid[1])
297
+
298
+ for i in range(0, height, grid_y):
299
+ images = []
300
+ for j in range(0, width, grid_x):
301
+ box = (j, i, j + grid_x, i + grid_y)
302
+ patch = image.crop(box)
303
+ images.append(patch)
304
+ patches.append(images)
305
+
306
+ return patches
307
+
308
+
309
+ def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False):
310
+ if new_schema:
311
+ image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end
312
+ else:
313
+ image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
314
+
315
+ cols = grid[0]
316
+ rows = grid[1]
317
+ slices = []
318
+ for i in range(rows):
319
+ lines = []
320
+ for j in range(cols):
321
+ lines.append(image_placeholder)
322
+ slices.append("".join(lines))
323
+ if new_schema:
324
+ slice_placeholder = "\n".join(slices)
325
+ else:
326
+ slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end
327
+ return slice_placeholder
328
+
329
+
330
+ def reshape_by_patch(image_tensor, patch_size):
331
+ """
332
+ :param image_tensor: shape [3, H, W]
333
+ :param patch_size:
334
+ :return: [3, patch_size, HW/patch_size]
335
+ """
336
+ patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size))
337
+
338
+ patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
339
+ patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)
340
+ return patches
341
+
342
+
343
+ def init_minicpmo_config(processor, config):
344
+ """Initialize MiniCPM-o specific configuration"""
345
+ minicpmo_config = {
346
+ "transform": build_transform(),
347
+ "patch_size": config.get("patch_size", 14),
348
+ "query_nums": config.get("query_nums", 64),
349
+ "slice_config": config.get(
350
+ "slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448}
351
+ ),
352
+ "llm_type": config.get("llm_type", "qwen"),
353
+ "batch_vision": config.get("batch_vision", True),
354
+ }
355
+ return minicpmo_config
356
+
357
+
358
+ def process_minicpmo_data(
359
+ row_dict,
360
+ messages,
361
+ tokenizer,
362
+ minicpmo_config,
363
+ image_key,
364
+ max_prompt_length,
365
+ truncation,
366
+ apply_chat_template_kwargs,
367
+ logger,
368
+ ):
369
+ """Process data for MiniCPM-o model"""
370
+ if len(row_dict[image_key]) == 1:
371
+ multi_modal_data = {}
372
+ image = process_image(row_dict.pop(image_key)[0])
373
+ multi_modal_data["image"] = [image]
374
+ images_dict = {"<image>": image}
375
+ else:
376
+ raise NotImplementedError
377
+
378
+ model_inputs = preprocess(
379
+ images_dict,
380
+ messages,
381
+ tokenizer,
382
+ minicpmo_config["transform"],
383
+ query_nums=minicpmo_config["query_nums"],
384
+ slice_config=minicpmo_config["slice_config"],
385
+ llm_type=minicpmo_config["llm_type"],
386
+ patch_size=minicpmo_config["patch_size"],
387
+ batch_vision=minicpmo_config["batch_vision"],
388
+ max_length=max_prompt_length,
389
+ truncation=truncation,
390
+ apply_chat_template_kwargs=apply_chat_template_kwargs,
391
+ logger=logger,
392
+ )
393
+
394
+ raw_prompt = tokenizer.apply_chat_template(
395
+ messages, add_generation_prompt=True, tokenize=False, **(apply_chat_template_kwargs or {})
396
+ )
397
+ raw_prompt = raw_prompt.replace("<image>", "(<image>./</image>)")
398
+
399
+ return model_inputs, multi_modal_data, raw_prompt
400
+
401
+
402
+ class RLHFDataset(Dataset):
403
+ """
404
+ Load and preprocess RLHF data from Parquet files.
405
+
406
+ - Caches files locally.
407
+ - Reads into a HuggingFace Dataset and tokenizes prompts.
408
+ - Optionally handles images/videos via a ProcessorMixin.
409
+ - Filters prompts over a max length.
410
+ - Supports resuming from checkpoints.
411
+
412
+ Args:
413
+ data_files (str or list): Path(s) to Parquet file(s).
414
+ tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs.
415
+ config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc.
416
+ processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos.
417
+ """
418
+
419
+ def __init__(
420
+ self,
421
+ data_files: str | list[str],
422
+ tokenizer: PreTrainedTokenizer,
423
+ config: DictConfig,
424
+ processor: Optional[ProcessorMixin] = None,
425
+ ):
426
+ if not isinstance(data_files, list | ListConfig):
427
+ data_files = [data_files]
428
+
429
+ self.data_files = copy.deepcopy(data_files)
430
+ self.original_data_files = copy.deepcopy(data_files) # use for resume
431
+ self.tokenizer = tokenizer
432
+ self.processor = processor
433
+ self.config = config
434
+
435
+ self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
436
+ self.prompt_key = config.get("prompt_key", "prompt")
437
+ self.image_key = config.get("image_key", "images")
438
+ self.video_key = config.get("video_key", "videos")
439
+ self.max_prompt_length = config.get("max_prompt_length", 1024)
440
+ self.return_raw_chat = config.get("return_raw_chat", False)
441
+ self.return_full_prompt = config.get("return_full_prompt", False)
442
+ self.truncation = config.get("truncation", "error")
443
+ self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)
444
+ self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {})
445
+
446
+ self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
447
+ self.num_workers = min(self.num_workers, os.cpu_count())
448
+ self.use_shm = config.get("use_shm", False)
449
+ self.chat_template_func = config.get("chat_template_func", None)
450
+ self.need_tools_kwargs = config.get("need_tools_kwargs", False)
451
+ self.filter_prompts = config.get("filter_prompts", True)
452
+ self.serialize_dataset = False
453
+ self.minicpmo_config = init_minicpmo_config(self.processor, config)
454
+ self._download()
455
+ self._read_files_and_tokenize()
456
+
457
+ def _download(self, use_origin_parquet=False):
458
+ from verl.utils.fs import copy_to_local
459
+
460
+ data_files = self.data_files if not use_origin_parquet else self.original_data_files
461
+ for i, parquet_file in enumerate(data_files):
462
+ self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm)
463
+
464
+ def _read_files_and_tokenize(self):
465
+ dataframes = []
466
+ for parquet_file in self.data_files:
467
+ # read parquet files and cache
468
+ dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
469
+ dataframes.append(dataframe)
470
+ self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)
471
+
472
+ print(f"dataset len: {len(self.dataframe)}")
473
+
474
+ def resume_dataset_state(self):
475
+ self.serialize_dataset = not hasattr(self, "original_data_files")
476
+ # resume dataframe if not it's serialized in data.pt
477
+ if not self.serialize_dataset:
478
+ self._download(use_origin_parquet=True) # download and resume from original parquet files
479
+ self._read_files_and_tokenize()
480
+ else:
481
+ print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")
482
+
483
+ def __len__(self):
484
+ return len(self.dataframe)
485
+
486
+ def _build_messages(self, example: dict):
487
+ return example.pop(self.prompt_key)
488
+
489
+ def __getitem__(self, item):
490
+ """
491
+ Note that we also return the raw_input_ids so that it can be combined with other chat template
492
+ """
493
+ row_dict: dict = self.dataframe[item]
494
+ messages = self._build_messages(row_dict)
495
+ model_inputs = {}
496
+
497
+ if self.processor is not None:
498
+ model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data(
499
+ row_dict,
500
+ messages,
501
+ self.tokenizer,
502
+ self.minicpmo_config,
503
+ self.image_key,
504
+ self.max_prompt_length,
505
+ self.truncation,
506
+ self.apply_chat_template_kwargs,
507
+ logger,
508
+ )
509
+ input_ids = model_inputs.pop("input_ids")
510
+ attention_mask = model_inputs.pop("attention_mask")
511
+ position_ids = model_inputs.pop("position_ids")
512
+
513
+ # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
514
+ row_dict["multi_modal_data"] = multi_modal_data
515
+ row_dict["multi_modal_inputs"] = dict(model_inputs)
516
+ else:
517
+ raw_prompt = self.tokenizer.apply_chat_template(
518
+ messages, add_generation_prompt=True, tokenize=False, **self.apply_chat_template_kwargs
519
+ )
520
+ model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
521
+ input_ids = model_inputs.pop("input_ids")
522
+ attention_mask = model_inputs.pop("attention_mask")
523
+ position_ids = compute_position_id_with_mask(attention_mask)
524
+
525
+ row_dict["input_ids"] = input_ids
526
+ row_dict["attention_mask"] = attention_mask
527
+ row_dict["position_ids"] = position_ids
528
+
529
+ raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
530
+ if len(raw_prompt_ids) > self.max_prompt_length:
531
+ if self.truncation == "left":
532
+ raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
533
+ elif self.truncation == "right":
534
+ raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
535
+ elif self.truncation == "middle":
536
+ left_half = self.max_prompt_length // 2
537
+ right_half = self.max_prompt_length - left_half
538
+ raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:]
539
+ elif self.truncation == "error":
540
+ raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
541
+
542
+ row_dict["raw_prompt_ids"] = raw_prompt_ids
543
+ # encode prompts without chat template
544
+ if self.return_raw_chat:
545
+ row_dict["raw_prompt"] = messages
546
+
547
+ # get prompts with chat template
548
+ if self.return_full_prompt:
549
+ row_dict["full_prompts"] = raw_prompt # array of strings
550
+
551
+ # add index for each prompt
552
+ index = row_dict.get("extra_info", {}).get("index", 0)
553
+ tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
554
+ interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {})
555
+ need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
556
+ if need_tools_kwargs and not tools_kwargs:
557
+ logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"])
558
+ row_dict["index"] = index
559
+ row_dict["tools_kwargs"] = tools_kwargs
560
+ row_dict["interaction_kwargs"] = interaction_kwargs
561
+ return row_dict
562
+
563
+ def __getstate__(self):
564
+ if not self.serialize_dataset:
565
+ state = self.__dict__.copy()
566
+
567
+ if "dataframe" in state:
568
+ del state["dataframe"]
569
+ return state
570
+
571
+ return self.__dict__.copy()
ICL/DAPO/verl-recipe/prime/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 PRIME team and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
ICL/DAPO/verl-recipe/prime/prime_core_algos.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 PRIME team and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ import verl
18
+ import verl.utils.torch_functional as verl_F
19
+
20
+
+def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config):
+    # calculate rloo reward on different reward sources, and sum again
+    def masked_rloo(reward_tensor_original, mask_tensor):
+        reward_tensor = reward_tensor_original.clone()
+        reward_tensor[~mask_tensor] = 0
+        for start_pos in range(0, reward_tensor.shape[0], n_samples):
+            cur_rewards_mean = torch.cat(
+                [
+                    reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True)
+                    for pos in range(start_pos, start_pos + n_samples)
+                ],
+                dim=0,
+            )
+            cur_rewards_sum = cur_rewards_mean.sum()
+            cur_reward_baseline = cur_rewards_sum / (n_samples - 1)
+            reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = (
+                reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]]
+                * (n_samples / (n_samples - 1))
+                - cur_reward_baseline
+            )
+
+        return reward_tensor
+
+    reward_tensors = []
+
+    with torch.no_grad():
+        if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0:
+            reward_tensor = data.batch["rm_scores"]
+            reward_mask = response_mask.bool()
+
+            reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef)
+
+        if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0:
+            reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32)
+            reward_mask = torch.zeros_like(response_mask, dtype=torch.bool)
+
+            prompt_ids = data.batch["prompts"]
+            prompt_length = prompt_ids.shape[-1]
+            valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1)
+
+            reward_mask[
+                torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
+                valid_response_length - 1,
+            ] = True
+            reward_tensor[
+                torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device),
+                valid_response_length - 1,
+            ] = data.batch["acc"]
+
+            reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef)
+
+    final_reward_tensor = sum(reward_tensors)
+
+    returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+
+    advantages = returns.clone()
+    advantages = verl_F.masked_whiten(advantages, response_mask)
+
+    return advantages, returns
+
+
+def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta):
+    cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid()
+    cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc)
+    return cur_dpo_loss
+
+
+def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"):
+    # we always assume that the BoN size equals n_samples
+    # mode1: use acc as rm
+    # mode2: use Q as rm
+    cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta
+    other_Q = torch.zeros_like(cur_Q)
+    for i in range(token_level_scores.shape[0]):
+        Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]]
+        if len(Q_chosen) > 0:
+            other_Q[i] = Q_chosen.mean() * beta
+        else:
+            other_Q[i] = 0
+    dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1)))
+    if bon_mode == "none":
+        dpo_loss = dpo_loss.mean()
+    else:
+        weight = torch.zeros_like(dpo_loss)
+        n_samples = acc_bc.shape[1]
+        if bon_mode == "bon_rm":
+            for i in range(token_level_scores.shape[0]):
+                weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1)
+        elif bon_mode == "bon_acc":
+            for i in range(token_level_scores.shape[0]):
+                weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1)
+        else:
+            raise NotImplementedError
+        dpo_loss = (dpo_loss * weight).sum()
+
+    return dpo_loss
+
+
+def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples):
+    dpo_acc = []
+    for start_id in range(0, token_level_scores.shape[0], n_samples):
+        cur_scores = (
+            token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples]
+        ).sum(dim=1)
+
+        def get_upper_triangle(tensor_x):
+            diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0)
+            upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1)
+            return diff_matrix[upper_tri_indices]
+
+        cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples])  # in range [-1,1]
+        cur_score_diff = get_upper_triangle(cur_scores)  # in R
+        cur_score_prediction = (cur_score_diff > 0).float()  # in [0,1]
+        if cur_acc_diff.abs().sum() == 0:
+            cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5
+        else:
+            cur_acc = (
+                ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs()
+            ).sum() / cur_acc_diff.abs().sum()
+
+        dpo_acc.append(cur_acc.unsqueeze(0))
+
+    return torch.cat(dpo_acc, dim=0).mean()
+
+
+def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples):
+    return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean()
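The leave-one-out baseline in `masked_rloo` above is easy to miss inside the masked tensor indexing. The following standalone sketch (with hypothetical toy rewards, not part of the recipe) checks that the closed form `r * n/(n-1) - sum(r)/(n-1)` equals "each reward minus the mean of the other n-1 rewards":

```python
import torch

n = 4
r = torch.tensor([1.0, 0.0, 0.0, 1.0])  # toy per-sample rewards for one prompt group

# direct leave-one-out advantage: r_i minus the mean of the other samples
loo = torch.stack([r[i] - torch.cat([r[:i], r[i + 1:]]).mean() for i in range(n)])

# closed form used in masked_rloo above
closed = r * (n / (n - 1)) - r.sum() / (n - 1)

assert torch.allclose(loo, closed)  # both give tensor([ 0.6667, -0.6667, -0.6667,  0.6667])
```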
ICL/DAPO/verl-recipe/prime/run_prime_qwen_code.sh ADDED
@@ -0,0 +1,61 @@
+set -x
+
+
+# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data
+code_train_path=$HOME/data/code/train.parquet
+code_test_path=$HOME/data/code/test.parquet
+
+train_files="['$code_train_path']"
+test_files="['$code_test_path']"
+
+model_path=PRIME-RL/Eurus-2-7B-SFT
+# model_path=Qwen/Qwen2.5-0.5B-Instruct
+
+python3 -m recipe.prime.main_prime \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.train_batch_size=64 \
+    data.val_batch_size=6312 \
+    data.max_prompt_length=1024 \
+    data.max_response_length=3072 \
+    data.filter_overlong_prompts=True \
+    data.filter_accuracy=True \
+    data.accuracy_lower_bound=0.2 \
+    data.accuracy_upper_bound=0.8 \
+    data.oversample_factor=4 \
+    actor_rollout_ref.model.path=$model_path \
+    actor_rollout_ref.actor.optim.lr=5e-7 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=True \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
+    algorithm.adv_estimator=rloo \
+    algorithm.use_kl_in_reward=True \
+    algorithm.kl_penalty=kl \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    reward_model.model.path=$model_path \
+    reward_model.micro_batch_size_per_gpu=1 \
+    reward_model.model.update=before \
+    reward_model.model.beta_train=0.05 \
+    reward_model.model.optim.lr=1e-6 \
+    reward_model.model.optim.grad_clip=10.0 \
+    reward_model.model.input_tokenizer=null \
+    reward_model.mini_batch_size=64 \
+    trainer.val_before_train=False \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name='prime_example' \
+    trainer.experiment_name='Eurus-2-7B-SFT-code' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=64 \
+    trainer.test_freq=64 \
+    trainer.total_epochs=15 $@
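The `data.filter_accuracy`, `data.accuracy_lower_bound`/`upper_bound`, and `data.oversample_factor` options above drive PRIME's online prompt filtering: prompts are oversampled, each prompt's group of rollouts is scored, and only prompts whose mean accuracy falls strictly between the bounds are kept. A minimal sketch of the idea (a hypothetical helper, not the recipe's actual implementation):

```python
import torch

def filter_by_accuracy(acc: torch.Tensor, n_samples: int, lower: float = 0.2, upper: float = 0.8) -> torch.Tensor:
    """Return a boolean mask over prompt groups whose mean accuracy lies in (lower, upper)."""
    group_acc = acc.view(-1, n_samples).float().mean(dim=1)  # one mean accuracy per prompt
    return (group_acc > lower) & (group_acc < upper)

acc = torch.tensor([1, 0, 0, 1, 1, 1, 1, 1])  # two prompts, n_samples=4
print(filter_by_accuracy(acc, n_samples=4))   # tensor([ True, False]): all-correct prompts are dropped
```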
ICL/DAPO/verl-recipe/r1/run_r1_distill_qwen.sh ADDED
@@ -0,0 +1,33 @@
+MODEL_PATH=Qwen/DeepSeek-R1-Distill-Qwen-1.5B
+DATA_PATH=/workspace/datasets/r1_bench
+
+# Eval Data Process
+python3 -m recipe.r1.data_process \
+    --local_dir $DATA_PATH \
+    --tasks all
+
+# Generation
+python3 -m verl.trainer.main_generation \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=8 \
+    data.path=$DATA_PATH/test.parquet \
+    data.prompt_key=prompt \
+    data.batch_size=1024 \
+    data.n_samples=8 \
+    data.output_path=$DATA_PATH/test-output-8.parquet \
+    model.path=$MODEL_PATH \
+    rollout.temperature=0.6 \
+    rollout.top_p=0.95 \
+    rollout.prompt_length=1024 \
+    rollout.response_length=32768 \
+    rollout.tensor_model_parallel_size=1 \
+    rollout.gpu_memory_utilization=0.9 \
+    rollout.max_num_batched_tokens=65536
+
+# Evaluation
+python3 -m recipe.r1.main_eval \
+    data.path=$DATA_PATH/test-output-8.parquet \
+    data.prompt_key=prompt \
+    data.response_key=responses \
+    custom_reward_function.path=recipe/r1/reward_score.py \
+    custom_reward_function.name=reward_func
ICL/DAPO/verl-recipe/r1_ascend/Dockerfile.vllm_ascend.mindspeed.deepseekV3 ADDED
@@ -0,0 +1,82 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM quay.io/ascend/cann:8.2.rc1-a3-openeuler22.03-py3.11
+
+ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple"
+ARG COMPILE_CUSTOM_KERNELS=1
+
+# Define environments
+ENV DEBIAN_FRONTEND=noninteractive
+ENV COMPILE_CUSTOM_KERNELS=${COMPILE_CUSTOM_KERNELS}
+
+RUN yum install -y patch
+
+WORKDIR /workspace
+
+RUN pip config set global.index-url ${PIP_INDEX_URL}
+
+# Install torch and torch-npu
+RUN python3 -m pip install torch==2.5.1 torch-npu==2.5.1.post1
+
+# Compile/Install apex
+RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \
+    source /usr/local/Ascend/nnal/atb/set_env.sh && \
+    source /usr/local/Ascend/nnal/asdsip/set_env.sh && \
+    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
+    git clone -b master https://gitcode.com/ascend/apex.git && \
+    cd apex/ && bash scripts/build.sh --python=3.11 && \
+    cd apex/dist/ && \
+    python3 -m pip install --upgrade apex-0.1+ascend-*.whl
+
+# verl
+RUN git clone https://github.com/volcengine/verl.git
+
+# MindSpeed
+RUN git clone https://gitcode.com/Ascend/MindSpeed.git && \
+    cd MindSpeed && \
+    git checkout f6688 && \
+    pip install -r requirements.txt && \
+    cp -r mindspeed ../verl
+
+# Install vLLM
+RUN git clone https://github.com/vllm-project/vllm.git && \
+    cd vllm && \
+    git checkout v0.9.1 && \
+    cp -r vllm ../verl
+# On x86, triton is installed by vLLM. On Ascend, triton doesn't work correctly, so we uninstall it.
+RUN VLLM_TARGET_DEVICE="empty" python3 -m pip install -e /workspace/vllm/ --extra-index-url https://download.pytorch.org/whl/cpu/ && \
+    python3 -m pip uninstall -y triton && \
+    python3 -m pip cache purge
+
+# Install vllm-ascend
+RUN git clone https://github.com/vllm-project/vllm-ascend.git && \
+    cd vllm-ascend && \
+    git checkout 8c7bc45 && \
+    cp -r vllm_ascend ../verl
+
+# Append the `libascend_hal.so` path (devlib) to LD_LIBRARY_PATH
+RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh && \
+    source /usr/local/Ascend/nnal/atb/set_env.sh && \
+    export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
+    export CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/usr/include/c++/12:/usr/include/c++/12/`uname -i`-openEuler-linux && \
+    python3 -m pip install -v -e /workspace/vllm-ascend/ --exists-action=i --extra-index-url https://download.pytorch.org/whl/cpu/ && \
+    python3 -m pip cache purge
+
+# Install modelscope (for fast download), ray (for multinode), Megatron-LM, and others
+RUN python3 -m pip install modelscope ray "transformers<4.54.0" mathruler cbor2 && \
+    pip install pybase64 fastapi zmq uvicorn openai msgspec blake3 py-cpuinfo gguf openai-harmony && \
+    pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 && \
+    python3 -m pip cache purge
+
+CMD ["/bin/bash"]
ICL/DAPO/verl-recipe/r1_ascend/README.md ADDED
@@ -0,0 +1,119 @@
+# DeepSeek-R1-Zero on Ascend NPU
+This recipe provides a sample for fine-tuning the Deepseek-V3-Base model using Reinforcement Learning from Human Feedback (RLHF) on Ascend NPUs, specifically utilizing the GRPO algorithm with rule-based rewards on the deepscaler dataset.
+
+## Implementation Details
+To implement RL training for the DeepSeek model on Ascend NPUs, this example includes the following key code additions and modifications:
+- We implemented a simple rule-based reward function in `deepscaler.py`, referencing `verl/utils/reward_score/gsm8k.py`.
+- We provided a dataset file conversion script, `json_to_parquet.py`, which adds a template to the prompts to stimulate model thinking during the data file format conversion.
+- Due to potentially incomplete memory offloading during sleep operations for vLLM on NPUs, we added patches to manually handle the offloading and onloading of the rollout model and KV cache on NPUs. The related code is in `vllm_rollout_spmd.py` and `megatron_workers.py`.
+- To enable vLLM to utilize all ranks for expert parallelism, support for vLLM's data parallelism was necessary. For this purpose, we added patches to construct the correct data parallel communication group. The related code is in `vllm_parallel_state.py` and `vllm_rollout_spmd.py`. Additionally, the `VLLM_DP_SIZE` environment variable must be correctly set to `world_size / vllm_tp_size` (see the sketch after this list).
+- The MindSpeed training framework for NPUs invalidates torch.compile to avoid compilation failures during training, but this prevents its use for accelerating inference. To resolve this, we added patches that allow compilation during inference but not during training. The related code is in `megatron_workers.py`.
+- During RL training, multiple KV cache scheduling operations in vLLM on NPUs could lead to inconsistent memory allocation and memory corruption. The fix for this issue is patched in `engine_core.py`.
+
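To make the `VLLM_DP_SIZE` relationship concrete, here is a minimal bash sketch (the 256-rank figure matches the deployment described below; the variable names are illustrative):

```bash
# Hypothetical example: 256 total ranks with vLLM TP=2 gives DP=128
WORLD_SIZE=256
VLLM_TP_SIZE=2
export VLLM_DP_SIZE=$((WORLD_SIZE / VLLM_TP_SIZE))  # 128, as set in ray_start_grpo_npu.sh
```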
+By searching globally for `# NPU-ADAPTATION`, you can see the actual changes made by the patch code.
+
+For more technical details, please refer to [the Technical Report (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/docs/deepseek/deepseek_rl_train_optimization.md).
+
+## Training Details
+### Hyperparameters
+This example fine-tunes the DeepSeek-671B Base model on the deepscaler dataset using a combination of simple format rewards and answer accuracy rewards. The key hyperparameters are as follows:
+
+| iteration | learning rate | global batchsize | n_samples | temperature | kl-coef | prompt_max_len | response_max_len | rule reward | reward model |
+|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
+| 70 | 1e-6 (constant) | 512 | 16 | 1.0 | 0.001 | 1024 | 2048 | format + acc | - |
+
+### Resource Allocation and Performance
+This recipe was trained on an Ascend Atlas 800T A3 hyper-node server, utilizing 128 A3 NPUs, which is equivalent to 256 accelerator ranks. The specific deployment strategy is as follows:
+
+| Rollout Deployment | Actor Deployment | Reference Deployment | Offload Strategy |
+|:----:|:----:|:----:|:----:|
+| TP2 EP256 | EP32 PP8 | Same as Actor | Full offload, optimizer utilizes the [Mindspeed Swap Optimizer feature](https://gitee.com/ascend/MindSpeed/blob/master/docs/features/swap-optimizer.md) |
+
+The performance metrics for one training step are shown below (throughput varies with the model's response length during training):
+
+| step | prompt_len_mean | response_len_mean | timing_step (s) | throughput (tps/A3) | timing_gen (s) | timing_reward (s) | timing_old_prob (s) | timing_ref_prob (s) | timing_update (s) |
+|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
+| 2 | 175.1 | 1385.0 | 1044.8 | 95.5 | 482.2 | 20.4 | 105.5 | 92.7 | 342.9 |
+
+### Training Metrics
+<div align="center">
+<img src="./figures/rewards.png" width="33%" />
+<img src="./figures/response_len.png" width="33%" />
+<img src="./figures/val_score.png" width="33%" />
+</div>
+
+## Quick Start
+
+### Environment Setup
+For setting up the Ascend NPU environment for verl, please refer to [ascend_quick_start.rst (in Chinese)](../../docs/ascend_tutorial/ascend_quick_start.rst).
+
+Alternatively, you can use the provided Dockerfile to build the project's runtime environment locally: `docker build -f Dockerfile.vllm_ascend.mindspeed.deepseekV3 -t REPOSITORY:TAG ./`
+
+Prepare the source code with the following steps:
+```bash
+# Clone verl
+git clone https://github.com/volcengine/verl.git
+
+# Clone and setup vLLM (v0.9.1)
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+git checkout v0.9.1
+cp -r vllm ../verl
+cd ..
+
+# Clone and setup vLLM-Ascend (commit 8c7bc45)
+git clone https://github.com/vllm-project/vllm-ascend.git
+cd vllm-ascend
+git checkout 8c7bc45
+cp -r vllm_ascend ../verl
+cd ..
+
+# Clone and setup MindSpeed (commit f6688)
+git clone https://gitcode.com/Ascend/MindSpeed.git
+cd MindSpeed
+git checkout f6688
+cp -r mindspeed ../verl
+cd ..
+
+# Install Megatron-LM.core and other dependencies
+pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
+pip install mathruler
+```
+
+### Prepare the Training Dataset
+This example uses the deepscaler dataset. Prepare it as follows:
+- Download the dataset [JSON file](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset/blob/main/deepscaler.json).
+- Generate the `train.parquet` and `test.parquet` files and place them in the `./data/deepscaler` directory:
+```bash
+# Execute from the verl project directory
+python recipe/r1_ascend/json_to_parquet.py --output_dir ./data/deepscaler --json_path path/to/deepscaler.json --train_data_ratio 0.9
+```
+
+The processed prompts used during training will include a specific template, for example: `A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Put your final answer within \boxed{}. <|User|>{problem}<|Assistant|>`
+
+### Prepare Model Weights
+Prepare the DeepSeek-V3-Base model weights as follows:
+- Place the model configuration files (excluding the weights) into the `./DeepSeek-V3-hf` directory. The `config.json` file needs to be replaced to remove quantization and MTP configurations. Refer to [this link (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for details.
+- Download the FP8 model weights from [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base) or [ModelScope](https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3-Base). Ensure the target disk has over 650GB of free space.
+- Convert the FP8 weights to BF16 weights. Refer to [this link (in Chinese)](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for instructions. This step requires over 1300GB of free space on the target disk.
+
+This example uses pre-sharded distributed weights. Therefore, the following weight sharding step is also required:
+- The distributed weights will be stored in `ckpts/DeepseekV3-dist-ckpts`.
+- Use the script `verl/scripts/converter_hf_to_mcore.py` to shard the original BF16 weights into distributed weights. In practice, we found that 2TB of CPU RAM was insufficient for sharding the 671B model. Therefore, we adapted this script for expert parallelism and performed the weight sharding using a distributed strategy of EP8 PP8 across 64 NPUs.
+
+### Other Code Modifications
+In practice, to achieve the above results for on-policy RL training, we need to replace the `old_log_prob = data["old_log_probs"]` code in `verl/workers/actor/megatron_actor.py` with:
+
+```python
+on_policy = self.config.ppo_epochs == 1
+if on_policy:
+    old_log_prob = log_prob.detach()  # guarantee exact numerical equality
+else:
+    old_log_prob = data["old_log_probs"]
+```
+
+### Execute RL Fine-tuning
+```bash
+# Start the RL fine-tuning for DeepSeekV3 from the verl directory
+bash ./recipe/r1_ascend/ray_start_grpo_npu.sh
+```
ICL/DAPO/verl-recipe/r1_ascend/README_zh.md ADDED
@@ -0,0 +1,119 @@
+# DeepSeek-R1-Zero on Ascend NPU
+This recipe is a sample for RLHF post-training of the Deepseek-V3-Base model on NPUs, based on GRPO with rule-based rewards, using the deepscaler dataset.
+
+## Implementation Details
+To enable RL training of the DeepSeek model on Ascend NPUs, this sample adds the following code:
+- Referencing `verl/utils/reward_score/gsm8k.py`, we implemented a simple rule-based reward function in `deepscaler.py`.
+- We provide the dataset conversion script `json_to_parquet.py`, which adds a template that stimulates model thinking to the prompts while converting the data file format.
+- vLLM sleep on NPUs may not offload memory completely, so patches were added to manually offload and reload the rollout model and KV cache on NPUs. The related code is in `vllm_rollout_spmd.py` and `megatron_workers.py`.
+- For vLLM to use all cards for expert parallelism, vLLM data parallelism must be supported. Patches were added to build the correct DP communication group; the related code is in `vllm_parallel_state.py` and `vllm_rollout_spmd.py`. In addition, the `VLLM_DP_SIZE` environment variable must be set to `world_size / vllm_tp_size`.
+- The MindSpeed training framework on NPUs disables torch.compile to avoid compile failures on the training side, which also prevents using torch.compile to accelerate inference. To solve this, this sample adds patches so that compilation happens during inference but not during training. The related code is in `megatron_workers.py`.
+- During RL training, repeated vLLM KV cache scheduling on NPUs may cause inconsistent memory allocation and memory corruption; the fix is patched in `engine_core.py`.
+
+By searching globally for `# NPU-ADAPTATION`, you can see the actual changes made by the patch code.
+
+For more technical details, see the [technical report](https://gitcode.com/cann/cann-recipes-train/blob/master/docs/deepseek/deepseek_rl_train_optimization.md).
+
+## Training Details
+### Hyperparameters
+
+This sample trains the DeepSeek-671B Base model on the deepscaler dataset with a simple format reward plus an answer-accuracy reward. The hyperparameters are as follows:
+
+| iterations | learning rate | gbs | n_samples | temperature | kl-coef | prompt length | response length | rule reward | reward model |
+|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
+| 70 | 1e-6 (constant) | 512 | 16 | 1.0 | 0.001 | 1024 | 2048 | format + acc | - |
+
+### Resources and Performance
+This sample was trained on an Ascend Atlas 800T A3 hyper-node server using 128 A3 NPUs, equivalent to 256 accelerator cards. The deployment is as follows:
+
+| Rollout deployment | Actor deployment | Reference deployment | Offload strategy |
+|:----:|:----:|:----:|:----:|
+| TP2 EP256 | EP32 PP8 | Same as Actor | Full offload; the optimizer uses the [MindSpeed swap-optimizer feature](https://gitee.com/ascend/MindSpeed/blob/master/docs/features/swap-optimizer.md) |
+
+Per-step training performance is shown below (throughput changes as the model's response length evolves during training):
+
+| step | mean prompt length | mean response length | step time (s) | throughput (tps/A3) | gen time (s) | reward time (s) | old_prob time (s) | ref_prob time (s) | update time (s) |
+|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|
+| 2 | 175.1 | 1385.0 | 1044.8 | 95.5 | 482.2 | 20.4 | 105.5 | 92.7 | 342.9 |
+
+### Training Curves
+<div align="center">
+<img src="./figures/rewards.png" width="33%" />
+<img src="./figures/response_len.png" width="33%" />
+<img src="./figures/val_score.png" width="33%" />
+</div>
+
+## Quick Start
+
+### Environment Setup
+To set up the NPU environment for verl, refer to [ascend_quick_start.rst](../../docs/ascend_tutorial/ascend_quick_start.rst).
+
+Alternatively, you can build the project's runtime environment locally with the provided Dockerfile: `docker build -f Dockerfile.vllm_ascend.mindspeed.deepseekV3 -t REPOSITORY:TAG ./`
+
+Prepare the source code for this sample as follows:
+```bash
+# verl
+git clone https://github.com/volcengine/verl.git
+
+# vLLM (v0.9.1)
+git clone https://github.com/vllm-project/vllm.git
+cd vllm
+git checkout v0.9.1
+cp -r vllm ../verl
+cd ..
+
+# vLLM-Ascend (v0.9.1-dev)
+git clone https://github.com/vllm-project/vllm-ascend.git
+cd vllm-ascend
+git checkout 8c7bc45
+cp -r vllm_ascend ../verl
+cd ..
+
+# MindSpeed (commit-id: f6688)
+git clone https://gitcode.com/Ascend/MindSpeed.git
+cd MindSpeed
+git checkout f6688
+cp -r mindspeed ../verl
+cd ..
+
+# Megatron-LM.core and others
+pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1
+pip install mathruler
+```
+
+### Prepare the Training Dataset
+This sample uses the deepscaler dataset. Prepare it as follows:
+- Download the dataset [JSON file](https://huggingface.co/datasets/agentica-org/DeepScaleR-Preview-Dataset/blob/main/deepscaler.json).
+- Generate the `train.parquet` and `test.parquet` files and place them under `./data/deepscaler`:
+
+```bash
+# Run from the verl project directory
+python recipe/r1_ascend/json_to_parquet.py --output_dir ./data/deepscaler --json_path path/to/deepscaler.json --train_data_ratio 0.9
+```
+
+The processed prompts used in training will contain a template, for example: `A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Put your final answer within \boxed{}. <|User|>{problem}<|Assistant|>`
+
+### Prepare Model Weights
+Prepare the DeepSeek-V3-Base model weights as follows:
+- Place the model configuration files (without the weights) into the `./DeepSeek-V3-hf` directory, and replace `config.json` to remove the quantization and MTP configurations. See [this link](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87) for this step.
+- Download the FP8 model weights: [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base), [ModelScope](https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3-Base). This step requires more than 650GB of free space on the target disk.
+- Convert the FP8 weights to BF16 weights; see [this link](https://gitcode.com/cann/cann-recipes-train/blob/master/rl_train/deepseek/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87). This step requires more than 1300GB of free space on the target disk.
+
+This sample uses pre-sharded distributed weights, so the following weight-sharding steps are also required:
+- Store the distributed weights under `ckpts/DeepseekV3-dist-ckpts`.
+- Use `verl/scripts/converter_hf_to_mcore.py` to shard the original BF16 weights into distributed weights. In practice we found that 2TB of CPU RAM was not enough to shard the 671B model, so we adapted the script for expert parallelism and sharded the weights on 64 NPUs with an EP8 PP8 distributed strategy.
+
+### Other Code Modifications
+In practice, to obtain the on-policy training results above, we replaced the code segment `old_log_prob = data["old_log_probs"]` in `verl/workers/actor/megatron_actor.py` with the following code:
+```python
+on_policy = self.config.ppo_epochs == 1
+if on_policy:
+    old_log_prob = log_prob.detach()  # guarantee exact numerical equality
+else:
+    old_log_prob = data["old_log_probs"]
+```
+
+### Run RL Post-Training
+```bash
+# Start DeepSeekV3 RL post-training from the verl directory
+bash ./recipe/r1_ascend/ray_start_grpo_npu.sh
+```
ICL/DAPO/verl-recipe/r1_ascend/ray_start_grpo_npu.sh ADDED
@@ -0,0 +1,82 @@
+ray stop --force
+
+export RAY_DEDUP_LOGS=0   # 0: disable Ray log deduplication; 1: enable it
+export HYDRA_FULL_ERROR=1 # display the full error stack
+
+ulimit -n 32768
+mkdir -p logs
+
+NNODES=16         # number of nodes
+NPUS_PER_NODE=16  # number of NPUs on each node
+export WORLD_SIZE=$(($NNODES*$NPUS_PER_NODE))
+
+RAY_START_PORT=6766
+RAY_DASHBOARD_PORT=8260
+
+MASTER_ADDR="IP FOR MASTER NODE"               # modify to the IP of the master node
+SOCKET_IFNAME="SOCKET IFNAME FOR CURRENT NODE" # modify to the communication NIC of the current node
+# obtain the current node IP
+CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}')
+export MASTER_PORT=29444
+export HCCL_IF_BASE_PORT=64247
+export TP_SOCKET_IFNAME=$SOCKET_IFNAME
+export HCCL_SOCKET_IFNAME=$SOCKET_IFNAME
+export GLOO_SOCKET_IFNAME=$SOCKET_IFNAME
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
+export TASK_QUEUE_ENABLE=2 # enable level-2 optimization of the Ascend operator dispatch queue
+export HCCL_BUFFSIZE=300   # the buffer size of HCCL
+
+export HCCL_CONNECT_TIMEOUT=600
+export HCCL_EXEC_TIMEOUT=600
+
+export ASCEND_LAUNCH_BLOCKING=0 # for debugging; seriously hurts performance when enabled, but gives accurate error stacks
+
+export VLLM_USE_V1=1              # use the V1 engine of vLLM
+export VLLM_ENABLE_GRAPH_MODE=1   # enable vLLM graph mode
+export HCCL_OP_EXPANSION_MODE=AIV # enable the AIV communication mode
+export VLLM_ENABLE_MC2=1          # enable MC2 communication
+export VLLM_DP_SIZE=128           # the DP size of vLLM; tied to the number of vLLM instances
+
+# with the vLLM log level set to INFO, enabling this prints the prefill and decode times
+export VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE=0
+
+if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then
+    # start the master node
+    ray start --head --port=$RAY_START_PORT --dashboard-host=0.0.0.0 --node-ip-address=$CURRENT_IP --dashboard-port=$RAY_DASHBOARD_PORT --resources='{"NPU": '$NPUS_PER_NODE'}'
+
+    while true; do
+        ray_status_output=$(ray status)
+        npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1)
+        npu_count_int=$(echo "$npu_count" | awk '{print int($1)}')
+        device_count=$((npu_count_int / $NPUS_PER_NODE))
+
+        # check whether device_count equals NNODES
+        if [ "$device_count" -eq "$NNODES" ]; then
+            echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script."
+            ray status
+            bash ./recipe/r1_ascend/run_deepseekv3_671b_grpo_megatron_npu.sh
+            break
+        else
+            echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count"
+            sleep 5
+        fi
+    done
+else
+    # worker nodes keep trying to register with the master node until they succeed
+    while true; do
+        # try to connect to the Ray cluster
+        ray start --address="$MASTER_ADDR:$RAY_START_PORT" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP
+
+        # check if the connection succeeded
+        ray status
+        if [ $? -eq 0 ]; then
+            echo "Successfully connected to the Ray cluster!"
+            break
+        else
+            echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..."
+            sleep 5
+        fi
+    done
+fi
@@ -0,0 +1,347 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Adapted from
+# https://github.com/volcengine/verl/blob/main/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+
+import logging
+import os
+from typing import Generator
+
+import torch
+import torch.distributed
+from omegaconf import ListConfig
+from torch.distributed.device_mesh import DeviceMesh
+from vllm import LLM, SamplingParams
+from vllm.config import CompilationConfig, CompilationLevel
+
+from verl.third_party.vllm import VLLM_SLEEP_LEVEL
+from verl.utils.device import get_device_name
+from verl.utils.memory_utils import aggressive_empty_cache
+from verl.workers.config import HFModelConfig, RolloutConfig
+from verl.workers.rollout.vllm_rollout import vLLMRollout as vLLMRolloutBase
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+class vLLMRollout(vLLMRolloutBase):
+    def __init__(
+        self,
+        config: RolloutConfig,
+        model_config: HFModelConfig,
+        device_mesh: DeviceMesh,
+    ):
+        self.config = config
+        self.model_config = model_config
+        self.device_mesh = device_mesh
+        # NPU-ADAPTATION: import vLLM-Ascend patch
+        from recipe.r1_ascend import engine_core  # noqa: F401
+        from vllm_ascend.patch import (
+            platform,  # noqa: F401
+            worker,  # noqa: F401
+        )
+        # NPU-ADAPTATION END
+
+        if config.layered_summon:
+            self.sleep_level = 1
+        else:
+            self.sleep_level = VLLM_SLEEP_LEVEL
+
+        model_path = model_config.local_path
+        tokenizer = model_config.tokenizer
+        model_hf_config = model_config.hf_config
+        trust_remote_code = model_config.trust_remote_code
+        self.lora_kwargs = (
+            {"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank}
+            if model_config.lora_rank > 0
+            else {}
+        )
+
+        tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
+        assert tensor_parallel_size <= torch.distributed.get_world_size(), (
+            "tensor parallel size should be less than or equal to the world size"
+        )
+        max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192)
+
+        # NPU-ADAPTATION: when VLLM_DP_SIZE is configured, the DP communication group needs to be
+        # explicitly initialized
+        if int(os.environ.get("VLLM_DP_SIZE", "1")) > 1:
+            from recipe.r1_ascend.vllm_parallel_state import init_parallel_state
+
+            init_parallel_state(tensor_parallel_size)
+        # NPU-ADAPTATION END
+
+        rope_scaling_config = getattr(model_hf_config, "rope_scaling", None)
+        if not rope_scaling_config:
+            max_position_embeddings = None
+            if hasattr(model_hf_config, "max_position_embeddings"):
+                max_position_embeddings = model_hf_config.max_position_embeddings
+            elif hasattr(model_hf_config, "llm_config") and hasattr(
+                model_hf_config.llm_config, "max_position_embeddings"
+            ):
+                max_position_embeddings = model_hf_config.llm_config.max_position_embeddings
+            elif hasattr(model_hf_config, "text_config") and hasattr(
+                model_hf_config.text_config, "max_position_embeddings"
+            ):
+                max_position_embeddings = model_hf_config.text_config.max_position_embeddings
+            if max_position_embeddings is None:
+                raise ValueError("max_position_embeddings not found in model_hf_config")
+            assert max_position_embeddings >= config.prompt_length + config.response_length, (
+                "model context length should be greater than total sequence length"
+            )
+        else:
+            # handle the case where there's a length extension factor
+            # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support
+            # for using yarn as an example
+            rope_scaling_factor = rope_scaling_config.get("factor", 1.0)
+
+            assert (
+                model_hf_config.max_position_embeddings * rope_scaling_factor
+                >= config.prompt_length + config.response_length
+            ), (
+                "model context length should be greater than total sequence length, "
+                + f"got rope_scaling_factor={rope_scaling_factor} and "
+                + f"max_position_embeddings={model_hf_config.max_position_embeddings}"
+            )
+
+        max_model_len = int(config.max_model_len or config.prompt_length + config.response_length)
+
+        load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format
+
+        # copy it to avoid secretly modifying the engine config
+        engine_kwargs = config.get("engine_kwargs", {}).get("vllm", {}) or {}
+
+        # For each vLLM engine parameter,
+        # - `None` means not setting it, so we pop it, and leave it to the vLLM default value
+        #   (which can vary across different vLLM versions);
+        # - Otherwise it's the desired value we want to explicitly set.
+        engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None}
+        if config.get("limit_images", None):  # support for multi-image data
+            engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")}
+
+        compilation_config = {}
+
+        cudagraph_capture_sizes = config.get("cudagraph_capture_sizes")
+        # enforce_eager must be False to use cudagraph
+        if not config.enforce_eager and cudagraph_capture_sizes:
+            if isinstance(cudagraph_capture_sizes, ListConfig):
+                compilation_config["compilation_config"] = CompilationConfig(
+                    level=CompilationLevel.PIECEWISE, cudagraph_capture_sizes=cudagraph_capture_sizes
+                )
+            else:
+                logger.warning(f"cudagraph_capture_sizes must be a list, but got {cudagraph_capture_sizes}")
+
+        VLLM_ENABLE_GRAPH_MODE = int(os.environ.get("VLLM_ENABLE_GRAPH_MODE", "0"))
+        self.inference_engine = LLM(
+            model=model_path,
+            # NPU-ADAPTATION: Enable inference EP and disable sleep mode.
+            enable_sleep_mode=False,
+            enable_expert_parallel=True,
+            # NPU-ADAPTATION END
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend="external_launcher",
+            dtype=config.dtype,
+            enforce_eager=config.enforce_eager,
+            gpu_memory_utilization=config.gpu_memory_utilization,
+            disable_custom_all_reduce=True,
+            skip_tokenizer_init=False,
+            max_model_len=max_model_len,
+            max_num_seqs=config.max_num_seqs,
+            load_format=load_format,
+            disable_log_stats=config.disable_log_stats,
+            max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=config.enable_chunked_prefill,
+            enable_prefix_caching=False,
+            trust_remote_code=trust_remote_code,
+            seed=config.get("seed", 0),
+            # NPU-ADAPTATION: Enable graph mode and configure the parameters.
+            additional_config={
+                "torchair_graph_config": {
+                    "enabled": VLLM_ENABLE_GRAPH_MODE,
+                    "use_cached_graph": False,
+                    "graph_batch_sizes_init": False,
+                    "graph_batch_sizes": [config.max_num_seqs],
+                    "enable_multistream_mla": False,
+                    "enable_multistream_moe": False,
+                    "enable_view_optimize": False,
+                    "enable_kv_nz": False,
+                    "enable_frozen_parameter": False,
+                },
+                "ascend_scheduler_config": {
+                    "enabled": True,
+                },
+                "refresh": True,
+            },
+            # NPU-ADAPTATION END
+            **compilation_config,
+            **self.lora_kwargs,
+            **engine_kwargs,
+        )
+        # NPU-ADAPTATION: Weight onload and offload, and initialization configurations such as kv_cache.
+        self.model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
+        self.kv_cache_configs = None
+        self.cpu_model = {}
+        self.gpu_buffers = None
+        for name, params in self.model.named_parameters():
+            self.cpu_model[name] = torch.empty_like(params, device="cpu")
+        # NPU-ADAPTATION END
+
+        kwargs = dict(
+            n=1,
+            logprobs=0,  # can be set to 0 and let the actor recompute
+            max_tokens=config.response_length,
+            repetition_penalty=config.get("repetition_penalty", 1.0),
+        )
+
+        kwargs["detokenize"] = False
+
+        # support adding any sampling params from the config file
+        for k in config.keys():
+            if hasattr(SamplingParams(), str(k)) and k != "seed":
+                kwargs[k] = config.get(k)
+        kwargs["n"] = 1  # already repeated in ray_trainer
+        logger.info(f"vllm sampling kwargs: {kwargs}")
+        self.sampling_params = SamplingParams(**kwargs)
+
+        self.pad_token_id = tokenizer.pad_token_id
+
+    # NPU-ADAPTATION: Weight onload and offload, kv_cache init and free functions
+    # NOTE: Due to potentially incomplete memory offloading during sleep operations for vLLM on NPUs, we add
+    # patches to manually handle the off/on loading of the rollout model and KV cache on NPUs.
+    def init_cache_engine(self):
+        if os.environ["VLLM_USE_V1"] == "1":
+            worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker
+            if not worker.model_runner.kv_caches:
+                # v1 uses an explicit initialization method
+                self.inference_engine.llm_engine.engine_core.engine_core.model_executor.initialize_from_config(
+                    self.inference_engine.llm_engine.engine_core.engine_core.kv_cache_configs
+                )
+                self.inference_engine.llm_engine.reset_prefix_cache()
+        else:
+            if self.inference_engine.llm_engine.model_executor.driver_worker.worker.cache_engine is None:
+                self.inference_engine.llm_engine.model_executor.driver_worker.worker._init_cache_engine()
+
+    def onload_model_weights(self):
+        self.gpu_buffers = {}
+        for name, param in self.model.named_parameters():
+            self.gpu_buffers[name] = torch.empty_like(param, device=get_device_name())
+        for name, param in self.model.named_parameters():
+            param.data = self.gpu_buffers[name]
+
+    def offload_model_weights(self):
+        for name, params in self.model.named_parameters():
+            params.data = self.cpu_model[name]
+        if hasattr(self.model.model.layers[0].self_attn, "mla_attn"):
+            for i in range(self.model.model.start_layer, self.model.model.end_layer):
+                mla = self.model.model.layers[i].self_attn.mla_attn.impl
+                if hasattr(mla, "w_kc"):
+                    mla.w_kc = None
+                    mla.w_vc = None
+                if hasattr(mla, "W_UV"):
+                    mla.W_UV = None
+                    mla.W_UK_T = None
+
+        self.gpu_buffers = None
+        aggressive_empty_cache()
+
+    def free_cache_engine(self):
+        if os.environ["VLLM_USE_V1"] == "1":
+            worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker
+            ctx = worker.model_runner.vllm_config.compilation_config.static_forward_context
+        else:
+            compilation_config = self.inference_engine.llm_engine.model_executor.driver_worker.worker.compilation_config
+            ctx = compilation_config.static_forward_context
+        from vllm.attention import AttentionType
+
+        layer_need_kv_cache = []
+        for layer_name in ctx:
+            if hasattr(ctx[layer_name], "attn_type") and ctx[layer_name].attn_type in (
+                AttentionType.DECODER,
+                AttentionType.ENCODER_DECODER,
+            ):
+                layer_need_kv_cache.append(layer_name)
+
+        pipeline_parallel_size = self.inference_engine.llm_engine.vllm_config.parallel_config.pipeline_parallel_size
+        for layer_name in layer_need_kv_cache:
+            kv_cache = []
+            for _ in range(pipeline_parallel_size):
+                kv_cache.append(torch.tensor([]))
+            ctx[layer_name].kv_cache = kv_cache
+        if os.environ["VLLM_USE_V1"] == "1":
+            worker = self.inference_engine.llm_engine.model_executor.driver_worker.worker
+
+            worker.model_runner.kv_caches = []
+        else:
+            self.inference_engine.llm_engine.model_executor.driver_worker.worker.cache_engine = None
+            self.inference_engine.llm_engine.model_executor.driver_worker.worker.gpu_cache = None
+
+        if hasattr(self.model.model.layers[0].self_attn, "attn"):
+            for i in range(self.model.model.start_layer, self.model.model.end_layer):
+                attn_impl = self.model.model.layers[i].self_attn.attn.impl
+                if hasattr(attn_impl, "key_cache"):
+                    attn_impl.key_cache = None
+                    attn_impl.value_cache = None
+
+        aggressive_empty_cache()
+
+    def _process_mla(self, load_weight=False):
+        for i in range(self.model.model.start_layer, self.model.model.end_layer):
+            mla = self.model.model.layers[i].self_attn.mla_attn.impl
+            if hasattr(mla, "w_kc"):
+                mla.w_kc = None
+                mla.w_vc = None
+            if hasattr(mla, "W_UV"):
+                mla.W_UV = None
+                mla.W_UK_T = None
+            if load_weight:
+                mla.process_weights_after_loading(None)
+
+    async def resume(self, tags: list[str]):
+        """Resume rollout weights or kv cache in NPU memory.
+
+        Args:
+            tags: weights or kv_cache.
+        """
+        if not self.config.free_cache_engine:
+            return
+
+        if "weights" in tags:
+            self.onload_model_weights()
+        elif "kv_cache" in tags:
+            self.init_cache_engine()
+
+    async def release(self):
+        """Release weights and kv cache in NPU memory."""
+        if not self.config.free_cache_engine:
+            return
+
+        self.free_cache_engine()
+        self.offload_model_weights()
+
+        if hasattr(self.model.model.layers[0].self_attn, "mla_attn"):
+            self._process_mla()
+
+    async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
+        """Update the weights of the rollout model.
+
+        Args:
+            weights: A generator that yields the name of the weight tensor and the tensor itself.
+        """
+        await super().update_weights(weights, **kwargs)
+
+        if hasattr(self.model.model.layers[0].self_attn, "mla_attn"):
+            self._process_mla(load_weight=True)
+
+    # NPU-ADAPTATION END
ICL/DAPO/verl-recipe/rep_exp/README.md ADDED
@@ -0,0 +1,71 @@
+<div align="center">
+
+# Representation-Based Exploration for Language Models: <br> From Test-Time to Post-Training
+
+[📄 arXiv](https://arxiv.org/abs/2510.11686) &nbsp; &nbsp; [🌐 Website](https://rep-exp.github.io) &nbsp; &nbsp; [🐦 Twitter / X ](https://x.com/JensTuyls/status/1978244454617128993)
+
+</div>
+
+## Installation 🔌
+
+Install the following commit of verl:
+```bash
+pip install verl@git+https://github.com/volcengine/verl.git@b9bd00efba253ea90072555c45692054cf703de2
+```
+
+The only other package to install is scikit-learn, which we'll use for applying a sparse projection.
+```bash
+pip install scikit-learn
+```
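For intuition, the sketch below shows what such a sparse projection does: it compresses a batch of hidden states down to the `$SPARSE_DIM` used in the training scripts. The batch size and hidden size here are made up for illustration:

```python
import numpy as np
from sklearn.random_projection import SparseRandomProjection

hidden_states = np.random.randn(256, 3584)  # hypothetical: one hidden vector per sampled response
projector = SparseRandomProjection(n_components=32, random_state=42)  # e.g., SPARSE_DIM=32 as in the MATH/GSM8K runs below
projected = projector.fit_transform(hidden_states)
print(projected.shape)  # (256, 32)
```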
+
+## Running the Experiments 🚀
+
+You can reproduce or extend our experiments by running the following commands:
+
+```bash
+# General format
+sh rep_exp/train_elliptical.sh $TASK $SPARSE_DIM $BETA $SEED
+
+# MATH
+sh rep_exp/train_elliptical.sh math 32 0.01 42
+
+# GSM8K
+sh rep_exp/train_elliptical.sh gsm8k 32 0.01 42
+
+# DAPO-WITH-AIME
+sh rep_exp/train_elliptical.sh dapo-with-aime24 128 0.01 42
+```
+where `$TASK` is the task name, `$SPARSE_DIM` is the sparse dimension, `$BETA` is the beta parameter, and `$SEED` is the seed.
+
+## Evaluation 📊
+Once training is done, you can evaluate the model on the test set in two steps.
+1. Merge the model checkpoint.
+
+   This is necessary because the model checkpoint is saved in multiple shards (depending on the number of GPUs), and we need to merge them into a single checkpoint.
+
+   ```bash
+   sh rep_exp/model_merge.sh /path/to/global_step_X/actor # where X is the global step of the checkpoint with the best pass@1 on dev
+   ```
+
+2. Evaluate the merged model.
+
+   ```bash
+   sh rep_exp/eval.sh $TASK /path/to/global_step_X/actor/hf # where X is the global step of the checkpoint with the best pass@1 on dev
+   ```
+
+The results should be in a folder named `eval` and saved as a JSON file.
+
+## Citation 📝
+
+```bibtex
+@article{tuyls2025representation,
+  title={Representation-Based Exploration for Language Models: From Test-Time to Post-Training},
+  author={Tuyls, Jens and Foster, Dylan J and Krishnamurthy, Akshay and Ash, Jordan T},
+  journal={arXiv preprint arXiv:2510.11686},
+  year={2025}
+}
+```
+
+## Contact 📬
+
+If you have any questions or suggestions, feel free to reach out at [jtuyls@princeton.edu](mailto:jtuyls@princeton.edu).
ICL/DAPO/verl-recipe/rep_exp/eval.sh ADDED
@@ -0,0 +1,83 @@
+TASK=${1} # math, gsm8k, dapo-with-aime24
+
+# Custom model path for evaluation after training
+MODEL_PATH=${2} # /path/to/global_step_X/actor/hf, where X is the global step of the checkpoint with the best pass@1 on dev
+
+# If you want to evaluate the base model before training
+# MODEL_PATH=Qwen/Qwen2.5-7B-Instruct
+
+train_path=$HOME/data/${TASK}/train.parquet
+train_files="['$train_path']"
+CHECKPOINT_SAVE_CONTENTS='["model"]'
+
+if [ ${TASK} == "dapo-with-aime24" ]; then
+    MAX_PROMPT_LENGTH=$((1024 * 2))
+    MAX_RESPONSE_LENGTH=$((1024 * 8))
+    MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))
+    test_path=$HOME/data/${TASK}/dev.parquet
+else
+    MAX_PROMPT_LENGTH=1024
+    MAX_RESPONSE_LENGTH=1024
+    MAX_NUM_BATCHED_TOKENS=8192
+    test_path=$HOME/data/${TASK}/test.parquet
+fi
+
+test_files="['$test_path']"
+
+# If you're on a cluster with no internet access, set OFFLINE=True
+OFFLINE=False
+
+PYTHONUNBUFFERED=1 WANDB_MODE=disabled TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m rep_exp.main_rep_exp \
+    algorithm.adv_estimator=grpo \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.train_batch_size=1024 \
+    data.max_prompt_length=$MAX_PROMPT_LENGTH \
+    data.max_response_length=$MAX_RESPONSE_LENGTH \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    data.val_batch_size=128 \
+    actor_rollout_ref.model.path="$MODEL_PATH" \
+    actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_SAVE_CONTENTS \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.actor.kl_loss_coef=0.0 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.ppo_epochs=1 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.mode=sync \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.45 \
+    actor_rollout_ref.rollout.val_kwargs.n=256 \
+    actor_rollout_ref.rollout.n=8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    reward_model.model.path="$MODEL_PATH" \
+    reward_model.model.use_remove_padding=False \
+    reward_model.model.fsdp_config.param_offload=True \
+    reward_model.micro_batch_size_per_gpu=32 \
+    reward_model.model.input_tokenizer=null \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger='["console","json_eval"]' \
+    trainer.project_name='rep-exp' \
+    trainer.experiment_name="${TASK}_eval" \
+    trainer.n_gpus_per_node=1 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=1 \
+    trainer.total_epochs=100 \
+    trainer.val_only=True \
+    trainer.resume_mode=disable \
+    trainer.resume_from_path=''
+
+exit 0
ICL/DAPO/verl-recipe/rep_exp/main_rep_exp.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
16
+ """
17
+
18
+ import os
19
+ import socket
20
+ import warnings
21
+
22
+ import hydra
23
+ import ray
24
+ from omegaconf import OmegaConf
25
+
26
+ from verl.experimental.dataset.sampler import AbstractSampler
27
+ from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
28
+ from verl.trainer.ppo.reward import load_reward_manager
29
+ from verl.trainer.ppo.utils import need_critic, need_reference_policy
30
+ from verl.utils.config import validate_config
31
+ from verl.utils.device import is_cuda_available
32
+ from verl.utils.import_utils import load_extern_type
33
+
34
+ from .rep_exp_trainer import RayRepExpTrainer
35
+
36
+
37
+ @hydra.main(config_path="config", config_name="rep_exp_trainer", version_base=None)
38
+ def main(config):
39
+ """Main entry point for PPO training with Hydra configuration management.
40
+
41
+ Args:
42
+ config_dict: Hydra configuration dictionary containing training parameters.
43
+ """
44
+ run_ppo(config)
45
+
46
+
47
+ # Define a function to run the PPO-like training process
48
+ def run_ppo(config, task_runner_class=None) -> None:
49
+ """Initialize Ray cluster and run distributed PPO training process.
50
+
51
+ Args:
52
+ config: Training configuration object containing all necessary parameters
53
+ for distributed PPO training including Ray initialization settings,
54
+ model paths, and training hyperparameters.
55
+ task_runner_class: For recipe to change TaskRunner.
56
+ """
57
+ # Check if Ray is not initialized
58
+ if not ray.is_initialized():
59
+ # Initialize Ray with a local cluster configuration
60
+ # Set environment variables in the runtime environment to control tokenizer parallelism,
61
+ # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
62
+ # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
63
+ default_runtime_env = get_ppo_ray_runtime_env()
64
+ ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
65
+ runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
66
+
67
+ if config.transfer_queue.enable:
68
+ # Add runtime environment variables for transfer queue
69
+ runtime_env_vars = runtime_env_kwargs.get("env_vars", {})
70
+ runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1"
71
+ runtime_env_kwargs["env_vars"] = runtime_env_vars
72
+
73
+ runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
74
+ ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
75
+ print(f"ray init kwargs: {ray_init_kwargs}")
76
+ ray.init(**OmegaConf.to_container(ray_init_kwargs))
77
+
78
+ if task_runner_class is None:
79
+ task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head
80
+
81
+ # Create a remote instance of the TaskRunner class, and
82
+ # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
83
+ if (
84
+ is_cuda_available
85
+ and config.global_profiler.tool == "nsys"
86
+ and config.global_profiler.get("steps") is not None
87
+ and len(config.global_profiler.get("steps", [])) > 0
88
+ ):
89
+ from verl.utils.import_utils import is_nvtx_available
90
+
91
+ assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'"
92
+ nsight_options = OmegaConf.to_container(
93
+ config.global_profiler.global_tool_config.nsys.controller_nsight_options
94
+ )
95
+ runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
96
+ else:
97
+ runner = task_runner_class.remote()
98
+ ray.get(runner.run.remote(config))
99
+
100
+ # [Optional] get the path of the timeline trace file from the configuration, default to None
101
+ # This file is used for performance analysis
102
+ timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
103
+ if timeline_json_file:
104
+ ray.timeline(filename=timeline_json_file)
105
+
106
+
107
+ class TaskRunner:
108
+ """Ray remote class for executing distributed PPO training tasks.
109
+
110
+ This class encapsulates the main training logic and runs as a Ray remote actor
111
+ to enable distributed execution across multiple nodes and GPUs.
112
+
113
+ Attributes:
114
+ role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes
115
+ mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation
116
+ """
117
+
118
+ def __init__(self):
119
+ self.role_worker_mapping = {}
120
+ self.mapping = {}
121
+
122
+ def add_actor_rollout_worker(self, config):
123
+ """Add actor rollout worker based on the actor strategy."""
124
+ from verl.single_controller.ray import RayWorkerGroup
125
+ from verl.trainer.ppo.ray_trainer import Role
126
+
127
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
128
+
129
+ # use new model engine implementation
130
+ if use_legacy_worker_impl == "disable":
131
+ from verl.workers.engine_workers import ActorRolloutRefWorker
132
+
133
+ actor_rollout_cls = ActorRolloutRefWorker
134
+ ray_worker_group_cls = RayWorkerGroup
135
+ # NOTE: In the new model engine, ref policy and actor rollout live in the same ActorRolloutRefWorker,
136
+ # while in the legacy model engine, ref policy runs in a separate ActorRolloutRefWorker.
137
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
138
+ role = Role.ActorRolloutRef
139
+ else:
140
+ role = Role.ActorRollout
141
+ self.role_worker_mapping[role] = ray.remote(actor_rollout_cls)
142
+ self.mapping[role] = "global_pool"
143
+ return actor_rollout_cls, ray_worker_group_cls
144
+
145
+ if config.actor_rollout_ref.rollout.mode == "sync":
146
+ warnings.warn("spmd rollout mode is deprecated and will be removed in v0.6.2", stacklevel=2)
147
+
148
+ if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
149
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
150
+
151
+ actor_rollout_cls = (
152
+ AsyncActorRolloutRefWorker
153
+ if config.actor_rollout_ref.rollout.mode == "async"
154
+ else ActorRolloutRefWorker
155
+ )
156
+ ray_worker_group_cls = RayWorkerGroup
157
+
158
+ elif config.actor_rollout_ref.actor.strategy == "megatron":
159
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
160
+
161
+ actor_rollout_cls = (
162
+ AsyncActorRolloutRefWorker
163
+ if config.actor_rollout_ref.rollout.mode == "async"
164
+ else ActorRolloutRefWorker
165
+ )
166
+ ray_worker_group_cls = RayWorkerGroup
167
+
168
+ else:
169
+ raise NotImplementedError
170
+
171
+ self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
172
+ self.mapping[Role.ActorRollout] = "global_pool"
173
+ return actor_rollout_cls, ray_worker_group_cls
174
+
175
+ def add_critic_worker(self, config):
176
+ """Add critic worker to role mapping."""
177
+ if config.critic.strategy in {"fsdp", "fsdp2"}:
178
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
179
+ if use_legacy_worker_impl in ["auto", "enable"]:
180
+ from verl.workers.fsdp_workers import CriticWorker
181
+ elif use_legacy_worker_impl == "disable":
182
+ from verl.workers.roles import CriticWorker
183
+
184
+ print("Using new worker implementation")
185
+ else:
186
+ raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
187
+
188
+ elif config.critic.strategy == "megatron":
189
+ from verl.workers.megatron_workers import CriticWorker
190
+
191
+ else:
192
+ raise NotImplementedError
193
+
194
+ from verl.trainer.ppo.ray_trainer import Role
195
+
196
+ self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker)
197
+ self.mapping[Role.Critic] = "global_pool"
198
+
199
+ def init_resource_pool_mgr(self, config):
200
+ """Initialize resource pool manager."""
201
+
202
+ global_pool_id = "global_pool"
203
+ resource_pool_spec = {
204
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
205
+ }
206
+ # TODO: use the new registration method here to support dynamic registration of roles.
207
+ if config.reward_model.enable_resource_pool:
208
+ if config.reward_model.n_gpus_per_node <= 0:
209
+ raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
210
+ if config.reward_model.nnodes <= 0:
211
+ raise ValueError("config.reward_model.nnodes must be greater than 0")
212
+
213
+ reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
214
+ resource_pool_spec["reward_pool"] = reward_pool
215
+
216
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager
217
+
218
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)
219
+ return resource_pool_manager
220
+
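+ # Worked example of the spec built above: trainer.nnodes=2 with
+ # trainer.n_gpus_per_node=8 gives {"global_pool": [8, 8]}; enabling
+ # reward_model.enable_resource_pool with nnodes=1 and n_gpus_per_node=4
+ # additionally yields {"reward_pool": [4]}.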
221
+ def add_reward_model_worker(self, config):
222
+ """Add reward model worker if enabled."""
223
+ from verl.trainer.ppo.ray_trainer import Role
224
+
225
+ if config.reward_model.enable:
226
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
227
+ if use_legacy_worker_impl in ["auto", "enable"]:
228
+ if config.reward_model.strategy in {"fsdp", "fsdp2"}:
229
+ if config.reward_model.elliptical:
230
+ from .workers.elliptical_reward_model_worker import (
231
+ EllipticalRewardModelWorker as RewardModelWorker,
232
+ )
233
+ else:
234
+ from verl.workers.fsdp_workers import RewardModelWorker
235
+ elif config.reward_model.strategy == "megatron":
236
+ from verl.workers.megatron_workers import RewardModelWorker
237
+ else:
238
+ raise NotImplementedError
239
+ elif use_legacy_worker_impl == "disable":
240
+ from verl.workers.roles import RewardModelWorker
241
+
242
+ print("Using new worker implementation")
243
+ else:
244
+ raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
245
+
246
+ self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
247
+ if config.reward_model.enable_resource_pool:
248
+ self.mapping[Role.RewardModel] = "reward_pool"
249
+ else:
250
+ self.mapping[Role.RewardModel] = "global_pool"
251
+
252
+ def add_ref_policy_worker(self, config, ref_policy_cls):
253
+ """Add reference policy worker if KL loss or KL reward is used."""
254
+ from verl.trainer.ppo.ray_trainer import Role
255
+
256
+ # Ref policy has been fused into ActorRolloutRefWorker in the new model engine,
257
+ # so we don't need to add a separate ref policy worker group.
258
+ use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
259
+ if use_legacy_worker_impl == "disable":
260
+ return
261
+
262
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
263
+ self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls)
264
+ self.mapping[Role.RefPolicy] = "global_pool"
265
+
266
+ def run(self, config):
267
+ """Execute the main PPO training workflow.
268
+
269
+ This method sets up the distributed training environment, initializes
270
+ workers, datasets, and reward functions, then starts the training process.
271
+
272
+ Args:
273
+ config: Training configuration object containing all parameters needed
274
+ for setting up and running the PPO training process.
275
+ """
276
+ # Print the initial configuration. `resolve=True` will evaluate symbolic values.
277
+ from pprint import pprint
278
+
279
+ from omegaconf import OmegaConf
280
+
281
+ from verl.utils.fs import copy_to_local
282
+
283
+ print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
284
+ pprint(OmegaConf.to_container(config, resolve=True))
285
+ OmegaConf.resolve(config)
286
+
287
+ actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
288
+ self.add_critic_worker(config)
289
+
290
+ # We should adopt a multi-source reward function here:
291
+ # - for rule-based rm, we directly call a reward score
292
+ # - for model-based rm, we call a model
293
+ # - for code related prompt, we send to a sandbox if there are test cases
294
+ # finally, we combine all the rewards together
295
+ # The reward type depends on the tag of the data
296
+ self.add_reward_model_worker(config)
297
+
298
+ # Add a reference policy worker if KL loss or KL reward is used.
299
+ self.add_ref_policy_worker(config, actor_rollout_cls)
300
+
301
+ # validate config
302
+ validate_config(
303
+ config=config,
304
+ use_reference_policy=need_reference_policy(self.role_worker_mapping),
305
+ use_critic=need_critic(config),
306
+ )
307
+
308
+ # Download the checkpoint from HDFS to the local machine.
309
+ # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on
310
+ local_path = copy_to_local(
311
+ config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
312
+ )
313
+
314
+ # Instantiate the tokenizer and processor.
315
+ from verl.utils import hf_processor, hf_tokenizer
316
+
317
+ trust_remote_code = config.data.get("trust_remote_code", False)
318
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
319
+ # Used for multimodal LLMs; may be None
320
+ processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
321
+
322
+ # Make sure the elliptical reward manager is registered
323
+ from .reward_manager.elliptical_reward_manager import EllipticalRewardManager # noqa: F401
324
+
325
+ # Load the reward manager for training and validation.
326
+ reward_manager_name = config.reward_model.get("reward_manager", "naive")
327
+ reward_fn = load_reward_manager(
328
+ config,
329
+ tokenizer,
330
+ num_examine=0,
331
+ **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}),
332
+ )
333
+ val_reward_fn = load_reward_manager(
334
+ config,
335
+ tokenizer,
336
+ num_examine=1,
337
+ **config.reward_model.get("reward_kwargs", {}).get(reward_manager_name, {}),
338
+ )
339
+
340
+ resource_pool_manager = self.init_resource_pool_mgr(config)
341
+
342
+ from verl.utils.dataset.rl_dataset import collate_fn
343
+
344
+ # Create training and validation datasets.
345
+ train_dataset = create_rl_dataset(
346
+ config.data.train_files,
347
+ config.data,
348
+ tokenizer,
349
+ processor,
350
+ is_train=True,
351
+ max_samples=config.data.get("train_max_samples", -1),
352
+ )
353
+ val_dataset = create_rl_dataset(
354
+ config.data.val_files,
355
+ config.data,
356
+ tokenizer,
357
+ processor,
358
+ is_train=False,
359
+ max_samples=config.data.get("val_max_samples", -1),
360
+ )
361
+ train_sampler = create_rl_sampler(config.data, train_dataset)
362
+
363
+ # Initialize the PPO trainer.
364
+ trainer = RayRepExpTrainer(
365
+ config=config,
366
+ tokenizer=tokenizer,
367
+ processor=processor,
368
+ role_worker_mapping=self.role_worker_mapping,
369
+ resource_pool_manager=resource_pool_manager,
370
+ ray_worker_group_cls=ray_worker_group_cls,
371
+ reward_fn=reward_fn,
372
+ val_reward_fn=val_reward_fn,
373
+ train_dataset=train_dataset,
374
+ val_dataset=val_dataset,
375
+ collate_fn=collate_fn,
376
+ train_sampler=train_sampler,
377
+ )
378
+ # Initialize the workers of the trainer.
379
+ trainer.init_workers()
380
+
381
+ # Start the training process.
382
+ trainer.fit()
383
+
384
+
385
+ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1):
386
+ """Create a dataset.
387
+
388
+ Arguments:
389
+ data_paths: List of paths to data files.
390
+ data_config: The data config.
391
+ tokenizer (Tokenizer): The tokenizer.
392
+ processor (Processor): The processor.
+ is_train: Whether this dataset is the training split.
+ max_samples: Maximum number of samples to load (-1 loads all).
393
+
394
+ Returns:
395
+ dataset (Dataset): The dataset.
396
+ """
397
+ from torch.utils.data import Dataset
398
+
399
+ from verl.utils.dataset.rl_dataset import RLHFDataset
400
+
401
+ # Check if a custom dataset class is specified in the data configuration
402
+ # and if the path to the custom class is provided
403
+ if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
404
+ # Dynamically load the custom dataset class
405
+ dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
406
+ # Verify that the custom dataset class inherits from torch.utils.data.Dataset
407
+ if not issubclass(dataset_cls, Dataset):
408
+ raise TypeError(
409
+ f"The custom dataset class '{data_config.custom_cls.name}' from "
410
+ f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset"
411
+ )
412
+ elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train:
413
+ # If a data generation strategy is specified, use the DynamicGenDataset class
414
+ from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset
415
+
416
+ dataset_cls = DynamicGenDataset
417
+ print("Using DynamicGenDataset for data generation.")
418
+ else:
419
+ # Use the default RLHFDataset class if no custom class is specified
420
+ dataset_cls = RLHFDataset
421
+ print(f"Using dataset class: {dataset_cls.__name__}")
422
+
423
+ # Instantiate the dataset using the determined dataset class
424
+ dataset = dataset_cls(
425
+ data_files=data_paths,
426
+ tokenizer=tokenizer,
427
+ processor=processor,
428
+ config=data_config,
429
+ max_samples=max_samples,
430
+ )
431
+
432
+ return dataset
433
+
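+ # Illustrative data config fragment (YAML; path and class name are
+ # placeholders) that routes this factory to a custom dataset class:
+ #
+ #     data:
+ #       custom_cls:
+ #         path: recipe/rep_exp/my_dataset.py
+ #         name: MyRLHFDataset   # must subclass torch.utils.data.Dataset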
434
+
435
+ def create_rl_sampler(data_config, dataset):
436
+ """Create a sampler for the dataset.
437
+
438
+ Arguments:
439
+ data_config: The data config.
440
+ dataset (Dataset): The dataset.
441
+
442
+ Returns:
443
+ sampler (Sampler): The sampler.
444
+ """
445
+ import torch
446
+ from torch.utils.data import SequentialSampler
447
+
448
+ # torch.utils.data.RandomSampler could not recover properly
449
+ from torchdata.stateful_dataloader.sampler import RandomSampler
450
+
451
+ if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None:
452
+ curriculum_class = load_extern_type(
453
+ data_config.sampler.class_path,
454
+ data_config.sampler.class_name,
455
+ )
456
+ sampler = curriculum_class(
457
+ data_source=dataset,
458
+ data_config=data_config,
459
+ )
460
+ assert isinstance(sampler, AbstractSampler)
461
+ assert data_config.get("dataloader_num_workers", 8) == 0, (
462
+ "If using curriculum, num_workers must be 0 to prevent data caching. "
463
+ "If the dataloader caches data before the batch is done the "
464
+ "curriculum sampler won't have the opportunity to reorder it. "
465
+ )
466
+
467
+ # Use a sampler to facilitate checkpoint resumption.
468
+ # If shuffling is enabled in the data configuration, create a random sampler.
469
+ elif data_config.shuffle:
470
+ train_dataloader_generator = torch.Generator()
471
+ seed = data_config.get("seed")
472
+ if seed is not None:
473
+ train_dataloader_generator.manual_seed(seed)
474
+ sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
475
+ else:
476
+ # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order.
477
+ sampler = SequentialSampler(data_source=dataset)
478
+
479
+ return sampler
480
+
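+ # Illustrative curriculum sampler config (placeholders): the class loaded from
+ # class_path/class_name must subclass AbstractSampler, and
+ # dataloader_num_workers must be 0, as asserted above:
+ #
+ #     data:
+ #       sampler:
+ #         class_path: recipe/rep_exp/my_sampler.py
+ #         class_name: MyCurriculumSampler
+ #       dataloader_num_workers: 0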
481
+
482
+ if __name__ == "__main__":
483
+ main()
ICL/DAPO/verl-recipe/rep_exp/metric_utils.py ADDED
@@ -0,0 +1,382 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Metrics related to the RepExp trainer.
16
+ """
17
+
18
+ from collections import defaultdict
19
+ from functools import partial
20
+ from typing import Any
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from verl import DataProto
26
+ from verl.trainer.ppo.metric_utils import _compute_response_info, bootstrap_metric, calc_maj_val
27
+
28
+
29
+ def _compute_three_case_stats(data: DataProto, extrinsic_reward_tensor: torch.Tensor) -> dict:
30
+ """
31
+ Compute the fraction of samples that have no rollouts correct, some rollouts correct, and all rollouts correct.
32
+
33
+ Args:
34
+ data (DataProto): The data proto containing the batch data.
35
+ extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor.
36
+
37
+ Returns:
38
+ dict[str, float]: A dictionary containing the fraction of samples that have no rollouts correct,
39
+ some rollouts correct, and all rollouts correct.
40
+ """
41
+ no_rollouts_correct = 0
42
+ some_rollouts_correct = 0
43
+ all_rollouts_correct = 0
44
+
45
+ visited_uids = set()
46
+ for uid in data.non_tensor_batch["uid"]:
47
+ if uid in visited_uids:
48
+ continue
49
+
50
+ visited_uids.add(uid)
51
+ mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid)
52
+
53
+ # Split into three cases
54
+ if extrinsic_reward_tensor[mask].sum() == 0:
55
+ no_rollouts_correct += 1
56
+ elif extrinsic_reward_tensor[mask].sum() == mask.sum():
57
+ all_rollouts_correct += 1
58
+ elif extrinsic_reward_tensor[mask].sum() > 0 and extrinsic_reward_tensor[mask].sum() < mask.sum():
59
+ some_rollouts_correct += 1
60
+ else:
61
+ raise ValueError(f"Invalid extrinsic reward tensor: {extrinsic_reward_tensor[mask].sum()}")
62
+
63
+ # Sanity checks
64
+ assert len(visited_uids) == no_rollouts_correct + some_rollouts_correct + all_rollouts_correct
65
+
66
+ return {
67
+ "no_rollouts_correct_frac": no_rollouts_correct / len(visited_uids),
68
+ "some_rollouts_correct_frac": some_rollouts_correct / len(visited_uids),
69
+ "all_rollouts_correct_frac": all_rollouts_correct / len(visited_uids),
70
+ }
71
+
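+ # Worked example: with 2 prompts and 4 rollouts each, extrinsic rewards
+ # [1, 1, 1, 1] for the first uid and [0, 1, 0, 0] for the second give
+ # {"no_rollouts_correct_frac": 0.0, "some_rollouts_correct_frac": 0.5,
+ # "all_rollouts_correct_frac": 0.5}.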
72
+
73
+ def compute_data_metrics(batch: DataProto, use_critic: bool = True, elliptical: bool = False) -> dict[str, Any]:
74
+ """
75
+ Computes various metrics from a batch of data for PPO training.
76
+
77
+ This function calculates metrics related to scores, rewards, advantages, returns, values,
78
+ and sequence lengths from a batch of data. It provides statistical information (mean, max, min)
79
+ for each metric category.
80
+
81
+ Args:
82
+ batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc.
83
+ use_critic: Whether to include critic-specific metrics. Defaults to True.
84
+ elliptical: Whether to include elliptical-specific metrics. Defaults to False.
85
+
86
+ Returns:
87
+ A dictionary of metrics including:
88
+ - critic/score/mean, max, min: Statistics about sequence scores
89
+ - critic/rewards/mean, max, min: Statistics about sequence rewards
90
+ - critic/advantages/mean, max, min: Statistics about advantages
91
+ - critic/returns/mean, max, min: Statistics about returns
92
+ - critic/values/mean, max, min: Statistics about critic values (if use_critic=True)
93
+ - critic/vf_explained_var: Explained variance of the value function (if use_critic=True)
94
+ - response_length/mean, max, min, clip_ratio: Statistics about response lengths
95
+ - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths
96
+ - num_turns/mean, max, min: Statistics about the number of multi-turn conversations
97
+ """
98
+ sequence_score = batch.batch["token_level_scores"].sum(-1)
99
+ sequence_reward = batch.batch["token_level_rewards"].sum(-1)
100
+
101
+ if elliptical:
102
+ sequence_intrinsic_reward = batch.non_tensor_batch["intrinsic_reward"].sum(-1)
103
+ sequence_beta_scaled_intrinsic_reward = batch.non_tensor_batch["beta_scaled_intrinsic_reward"].sum(-1)
104
+ sequence_extrinsic_reward = batch.non_tensor_batch["extrinsic_reward"].sum(-1)
105
+ sequence_total_reward = batch.non_tensor_batch["total_reward"].sum(-1)
106
+ sequence_raw_bonuses = batch.non_tensor_batch["raw_bonuses"].sum(-1)
107
+
108
+ three_case_stats = _compute_three_case_stats(batch, batch.non_tensor_batch["extrinsic_reward"])
109
+
110
+ advantages = batch.batch["advantages"]
111
+ returns = batch.batch["returns"]
112
+
113
+ max_response_length = batch.batch["responses"].shape[-1]
114
+
115
+ prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
116
+ response_mask = batch.batch["response_mask"].bool()
117
+
118
+ max_prompt_length = prompt_mask.size(-1)
119
+
120
+ response_info = _compute_response_info(batch)
121
+ prompt_length = response_info["prompt_length"]
122
+ response_length = response_info["response_length"]
123
+
124
+ aborted_mask = (response_length == 0).bool()
125
+ non_aborted_mask = ~aborted_mask
126
+
127
+ non_aborted_sequence_score = sequence_score[non_aborted_mask]
128
+ non_aborted_sequence_reward = sequence_reward[non_aborted_mask]
129
+
130
+ score_mean = torch.mean(non_aborted_sequence_score).detach().item()
131
+ score_max = torch.max(non_aborted_sequence_score).detach().item()
132
+ score_min = torch.min(non_aborted_sequence_score).detach().item()
133
+
134
+ reward_mean = torch.mean(non_aborted_sequence_reward).detach().item()
135
+ reward_max = torch.max(non_aborted_sequence_reward).detach().item()
136
+ reward_min = torch.min(non_aborted_sequence_reward).detach().item()
137
+
138
+ valid_adv = torch.masked_select(advantages, response_mask)
139
+ valid_returns = torch.masked_select(returns, response_mask)
140
+
141
+ if use_critic:
142
+ values = batch.batch["values"]
143
+ valid_values = torch.masked_select(values, response_mask)
144
+ return_diff_var = torch.var(valid_returns - valid_values)
145
+ return_var = torch.var(valid_returns)
146
+
147
+ # Aborted samples and non-aborted response length statistics
148
+ # response_length_non_aborted/*: statistics computed on non-aborted samples only
149
+ aborted_ratio = torch.mean(aborted_mask.float()).detach().item()
150
+
151
+ non_aborted_response_length = response_length[non_aborted_mask]
152
+ if non_aborted_response_length.numel() > 0:
153
+ non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item()
154
+ non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item()
155
+ non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item()
156
+ non_aborted_response_length_clip_ratio = (
157
+ torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item()
158
+ )
159
+ else:
160
+ raise ValueError("All samples are aborted, this should not happen.")
161
+
162
+ metrics = {
163
+ # score
164
+ "critic/score/mean": score_mean,
165
+ "critic/score/max": score_max,
166
+ "critic/score/min": score_min,
167
+ # reward
168
+ "critic/rewards/mean": reward_mean,
169
+ "critic/rewards/max": reward_max,
170
+ "critic/rewards/min": reward_min,
171
+ # adv
172
+ "critic/advantages/mean": torch.mean(valid_adv).detach().item(),
173
+ "critic/advantages/max": torch.max(valid_adv).detach().item(),
174
+ "critic/advantages/min": torch.min(valid_adv).detach().item(),
175
+ # returns
176
+ "critic/returns/mean": torch.mean(valid_returns).detach().item(),
177
+ "critic/returns/max": torch.max(valid_returns).detach().item(),
178
+ "critic/returns/min": torch.min(valid_returns).detach().item(),
179
+ **(
180
+ {
181
+ # values
182
+ "critic/values/mean": torch.mean(valid_values).detach().item(),
183
+ "critic/values/max": torch.max(valid_values).detach().item(),
184
+ "critic/values/min": torch.min(valid_values).detach().item(),
185
+ # vf explained var
186
+ "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
187
+ }
188
+ if use_critic
189
+ else {}
190
+ ),
191
+ **(
192
+ {
193
+ # raw bonuses
194
+ "critic/raw_bonuses/mean": np.mean(sequence_raw_bonuses).item(),
195
+ "critic/raw_bonuses/max": np.max(sequence_raw_bonuses).item(),
196
+ "critic/raw_bonuses/min": np.min(sequence_raw_bonuses).item(),
197
+ "critic/raw_bonuses/std": np.std(sequence_raw_bonuses).item(),
198
+ # intrinsic_reward
199
+ "critic/intrinsic_reward/mean": np.mean(sequence_intrinsic_reward).item(),
200
+ "critic/intrinsic_reward/max": np.max(sequence_intrinsic_reward).item(),
201
+ "critic/intrinsic_reward/min": np.min(sequence_intrinsic_reward).item(),
202
+ "critic/intrinsic_reward/std": np.std(sequence_intrinsic_reward).item(),
203
+ # beta_scaled_intrinsic_reward
204
+ "critic/beta_scaled_intrinsic_reward/mean": np.mean(sequence_beta_scaled_intrinsic_reward).item(),
205
+ "critic/beta_scaled_intrinsic_reward/max": np.max(sequence_beta_scaled_intrinsic_reward).item(),
206
+ "critic/beta_scaled_intrinsic_reward/min": np.min(sequence_beta_scaled_intrinsic_reward).item(),
207
+ "critic/beta_scaled_intrinsic_reward/std": np.std(sequence_beta_scaled_intrinsic_reward).item(),
208
+ # extrinsic_reward
209
+ "critic/extrinsic_reward/mean": np.mean(sequence_extrinsic_reward).item(),
210
+ "critic/extrinsic_reward/max": np.max(sequence_extrinsic_reward).item(),
211
+ "critic/extrinsic_reward/min": np.min(sequence_extrinsic_reward).item(),
212
+ "critic/extrinsic_reward/std": np.std(sequence_extrinsic_reward).item(),
213
+ # three_case_stats
214
+ "critic/extrinsic_reward/no_rollouts_correct_frac": three_case_stats["no_rollouts_correct_frac"],
215
+ "critic/extrinsic_reward/some_rollouts_correct_frac": three_case_stats["some_rollouts_correct_frac"],
216
+ "critic/extrinsic_reward/all_rollouts_correct_frac": three_case_stats["all_rollouts_correct_frac"],
217
+ # total_reward
218
+ "critic/total_reward/mean": np.mean(sequence_total_reward).item(),
219
+ "critic/total_reward/max": np.max(sequence_total_reward).item(),
220
+ "critic/total_reward/min": np.min(sequence_total_reward).item(),
221
+ "critic/total_reward/std": np.std(sequence_total_reward).item(),
222
+ }
223
+ if elliptical
224
+ else {}
225
+ ),
226
+ # response length
227
+ "response_length/mean": torch.mean(response_length).detach().item(),
228
+ "response_length/max": torch.max(response_length).detach().item(),
229
+ "response_length/min": torch.min(response_length).detach().item(),
230
+ "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
231
+ .detach()
232
+ .item(),
233
+ # response length (non-aborted only)
234
+ # These statistics exclude aborted samples to avoid skew from zeros
235
+ "response_length_non_aborted/mean": non_aborted_response_length_mean,
236
+ "response_length_non_aborted/max": non_aborted_response_length_max,
237
+ "response_length_non_aborted/min": non_aborted_response_length_min,
238
+ "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio,
239
+ # aborted ratio
240
+ # Fraction of samples whose response length is zero
241
+ "response/aborted_ratio": aborted_ratio,
242
+ # prompt length
243
+ "prompt_length/mean": torch.mean(prompt_length).detach().item(),
244
+ "prompt_length/max": torch.max(prompt_length).detach().item(),
245
+ "prompt_length/min": torch.min(prompt_length).detach().item(),
246
+ "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
247
+ }
248
+
249
+ # multi-turn conversation
250
+ if "__num_turns__" in batch.non_tensor_batch:
251
+ num_turns = batch.non_tensor_batch["__num_turns__"]
252
+ metrics["num_turns/min"] = num_turns.min()
253
+ metrics["num_turns/max"] = num_turns.max()
254
+ metrics["num_turns/mean"] = num_turns.mean()
255
+
256
+ if "tool_call_counts" in batch.non_tensor_batch:
257
+ tool_call_counts = batch.non_tensor_batch["tool_call_counts"]
258
+ metrics["tool_call_counts/min"] = tool_call_counts.min()
259
+ metrics["tool_call_counts/max"] = tool_call_counts.max()
260
+ metrics["tool_call_counts/mean"] = tool_call_counts.mean()
261
+
262
+ return metrics
263
+
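+ # Note: the elliptical-only statistics above assume the non-tensor batch
+ # carries "raw_bonuses", "intrinsic_reward", "beta_scaled_intrinsic_reward",
+ # "extrinsic_reward", and "total_reward" arrays populated by the reward manager.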
264
+
265
+ def comb_estimator(n: int, c: int, k: int) -> float:
266
+ """Calculates 1 - comb(n - c, k) / comb(n, k)."""
267
+ if n - c < k:
268
+ return 1.0
269
+ return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
270
+
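+ # Worked example: n=4 rollouts with c=2 correct at k=2 gives
+ # 1 - comb(2, 2) / comb(4, 2) = 1 - 1/6 = 5/6; the product form above computes
+ # the same value as 1 - (1 - 2/3) * (1 - 2/4) without large binomials.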
271
+
272
+ def process_validation_metrics(
273
+ data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42
274
+ ) -> dict[str, dict[str, dict[str, float]]]:
275
+ """
276
+ Process validation metrics into a structured format with statistical analysis.
277
+
278
+ This function organizes validation metrics by data source and prompt, then computes
279
+ various statistical measures including means, standard deviations, best/worst values,
280
+ and majority voting results. It also performs bootstrap sampling to estimate statistics
281
+ for different sample sizes.
282
+
283
+ Args:
284
+ data_sources: List of data source identifiers for each sample.
285
+ sample_uids: List of sample uids corresponding to each sample.
286
+ infos_dict: Dictionary mapping variable names to lists of values for each sample.
287
+ seed: Random seed for bootstrap sampling. Defaults to 42.
288
+
289
+ Returns:
290
+ A nested dictionary with the structure:
291
+ {
292
+ data_source: {
293
+ variable_name: {
294
+ metric_name: value
295
+ }
296
+ }
297
+ }
298
+
299
+ Where metric_name includes:
300
+ - "mean@N": Mean value across N samples
301
+ - "std@N": Standard deviation across N samples
302
+ - "best@N/mean": Mean of the best values in bootstrap samples of size N
303
+ - "best@N/std": Standard deviation of the best values in bootstrap samples
304
+ - "worst@N/mean": Mean of the worst values in bootstrap samples
305
+ - "worst@N/std": Standard deviation of the worst values in bootstrap samples
306
+ - "maj@N/mean": Mean of majority voting results in bootstrap samples (if "pred" exists)
307
+ - "maj@N/std": Standard deviation of majority voting results (if "pred" exists)
308
+
309
+ Example:
310
+ >>> data_sources = ["source1", "source1", "source2"]
311
+ >>> sample_uids = ["uid1", "uid1", "uid2"]
312
+ >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]}
313
+ >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict)
314
+ >>> # result will contain statistics for each data source and variable
315
+ """
316
+ # Group metrics by data source, prompt and variable
317
+ data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
318
+ for sample_idx, data_source in enumerate(data_sources):
319
+ uid = sample_uids[sample_idx]
320
+ var2vals = data_src2uid2var2vals[data_source][uid]
321
+ for var_name, var_vals in infos_dict.items():
322
+ var2vals[var_name].append(var_vals[sample_idx])
323
+
324
+ # Calculate metrics for each group
325
+ data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
326
+ for data_source, uid2var2vals in data_src2uid2var2vals.items():
327
+ for uid, var2vals in uid2var2vals.items():
328
+ for var_name, var_vals in var2vals.items():
329
+ if isinstance(var_vals[0], str):
330
+ continue
331
+
332
+ metric = {}
333
+ n_resps = len(var_vals)
334
+ metric[f"mean@{n_resps}"] = np.mean(var_vals)
335
+ metric["pass@1/mean"] = comb_estimator(n_resps, np.sum(var_vals), 1)
336
+
337
+ if n_resps > 1:
338
+ metric[f"std@{n_resps}"] = np.std(var_vals)
339
+
340
+ ns = []
341
+ n = 2
342
+ while n < n_resps:
343
+ ns.append(n)
344
+ n *= 2
345
+ ns.append(n_resps)
346
+
347
+ for n in ns:
348
+ # [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
349
+ # data=var_vals, subset_size=n, reduce_fns=[np.max, np.min], seed=seed
350
+ # )
351
+ # metric[f"best@{n}/mean"], metric[f"best@{n}/std"] = bon_mean, bon_std
352
+ # metric[f"worst@{n}/mean"], metric[f"worst@{n}/std"] = won_mean, won_std
353
+ metric[f"pass@{n}/mean"] = comb_estimator(n_resps, np.sum(var_vals), n)
354
+ if var2vals.get("pred", None) is not None:
355
+ vote_data = [
356
+ {"val": val, "pred": pred} for val, pred in zip(var_vals, var2vals["pred"], strict=True)
357
+ ]
358
+ [(maj_n_mean, maj_n_std)] = bootstrap_metric(
359
+ data=vote_data,
360
+ subset_size=n,
361
+ reduce_fns=[partial(calc_maj_val, vote_key="pred", val_key="val")],
362
+ seed=seed,
363
+ )
364
+ metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std
365
+
366
+ data_src2uid2var2metric[data_source][uid][var_name] = metric
367
+
368
+ # Aggregate metrics across uids
369
+ data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
370
+ for data_source, uid2var2metric in data_src2uid2var2metric.items():
371
+ for uid, var2metric in uid2var2metric.items():
372
+ for var_name, metric in var2metric.items():
373
+ for metric_name, metric_val in metric.items():
374
+ data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val)
375
+
376
+ data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
377
+ for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items():
378
+ for var_name, metric2uid_vals in var2metric2uid_vals.items():
379
+ for metric_name, uid_vals in metric2uid_vals.items():
380
+ data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals)
381
+
382
+ return data_src2var2metric2val
ICL/DAPO/verl-recipe/rep_exp/model_merge.sh ADDED
@@ -0,0 +1,6 @@
1
+ CHECKPOINT_PATH=${1} # /path/to/global_step_X/actor, where X is the global step of the checkpoint with the best pass@1 on dev
2
+
3
+ python3 -m verl.model_merger merge \
4
+ --backend fsdp \
5
+ --local_dir $CHECKPOINT_PATH \
6
+ --target_dir $CHECKPOINT_PATH/hf
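+
+ # Usage sketch (the path is a placeholder):
+ #   bash model_merge.sh checkpoints/rep_exp/global_step_120/actor
+ # The merged HuggingFace-format weights are written to $CHECKPOINT_PATH/hf.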
ICL/DAPO/verl-recipe/rep_exp/plot_pass_at_k.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Code to plot the pass@k results for the RepExp RL training results.
16
+ """
17
+
18
+ import json
19
+ import os
20
+ from collections import defaultdict
21
+
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ import scipy.stats as stats
25
+ import seaborn as sns
26
+ from matplotlib.lines import Line2D
27
+
28
+ # Content configuration
29
+ EVAL_FOLDER = "./eval"
30
+ TASKS = ["math"] # ["math", "gsm8k", "dapo-with-aime24"]
31
+ SEEDS = [41, 42, 43]
32
+ ALGORITHMS = ["elliptical"] # ["grpo", "elliptical", "untrained", "unlikely"]
33
+ LOG_AXES = True
34
+
35
+ # Plot configuration
36
+ FACE_COLOR = "#F7F7FF"
37
+ MARKER = "o"
38
+ LINEWIDTH = 1.275
39
+ MARKERSIZE = 6
40
+ MARKEREDGEWIDTH = 0.9
41
+ LABEL_FONT_SIZE = 10
42
+ TITLE_FONT_SIZE = 11
43
+ TICK_LABEL_FONT_SIZE = 8
44
+ LEGEND_FONT_SIZE = 8
45
+
46
+ TASK_TO_NICE_NAME = {
47
+ "math": "MATH",
48
+ "gsm8k": "GSM8K",
49
+ "dapo-with-aime24": "AIME 2024",
50
+ "countdown-4": "Countdown",
51
+ }
52
+
53
+ ALGO_TO_COLOR = {
54
+ "grpo": sns.color_palette("deep")[-1],
55
+ "untrained": sns.color_palette("deep")[7],
56
+ "elliptical": sns.color_palette("colorblind")[2],
57
+ "unlikely": sns.color_palette("deep")[1],
58
+ }
59
+
60
+ ALGO_TO_NICE_NAME = {
61
+ "grpo": "GRPO",
62
+ "untrained": "Base Model",
63
+ "elliptical": r"RepExp (ours)",
64
+ "unlikely": "Unlikeliness",
65
+ }
66
+
67
+
68
+ def process_data(data: list[dict[str, float]], algorithm: str) -> tuple[dict[int, float], dict[int, float]]:
69
+ """
70
+ Process the pass@k data generated by a given algorithm.
71
+
72
+ Args:
73
+ data (List[Dict]): The data to process.
74
+ algorithm (str): Algorithm that generated the data.
75
+
76
+ Returns:
77
+ Tuple[Dict[int, float], Dict[int, float]]:
78
+ pass_at_k - The mean pass@k values.
79
+ pass_at_k_sem - The standard error of the pass@k values, or None for "untrained" (single seed).
80
+ """
81
+ pass_at_k = defaultdict(list)
82
+ for d in data:
83
+ for key, v in d.items():
84
+ for k in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
85
+ if key.endswith(f"reward/pass@{k}/mean"):
86
+ pass_at_k[k].append(v)
87
+
88
+ # NOTE: we only use a single seed for untrained since there is only one checkpoint for it
89
+ if algorithm != "untrained":
90
+ for k in pass_at_k.keys():
91
+ assert len(pass_at_k[k]) == len(SEEDS)
92
+
93
+ pass_at_k_sem = {k: stats.sem(v) for k, v in pass_at_k.items()} if algorithm != "untrained" else None
94
+ pass_at_k = {k: np.mean(v) for k, v in pass_at_k.items()}
95
+
96
+ return pass_at_k, pass_at_k_sem
97
+
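+ # Worked example: three seeds reporting pass@8 values [0.80, 0.82, 0.84]
+ # yield pass_at_k[8] == 0.82 and pass_at_k_sem[8] == stats.sem([0.80, 0.82, 0.84]),
+ # about 0.0115.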
98
+
99
+ def main():
100
+ # Get all top-level folders in EVAL_FOLDER
101
+ eval_folders = os.listdir(EVAL_FOLDER)
102
+
103
+ # Figure setup
104
+ sns.set_style("whitegrid")
105
+ fig, axs = plt.subplots(1, len(TASKS), figsize=(3 * len(TASKS), 3))
106
+
107
+ for i, task in enumerate(TASKS):
108
+ ax = axs[i] if len(TASKS) > 1 else axs
109
+ algo_to_xs = {}
110
+ algo_to_ys = {}
111
+
112
+ for algorithm in ALGORITHMS:
113
+ # Get all eval folders for the current task and algorithm
114
+ folders = [f for f in eval_folders if f.startswith(f"{task}_{algorithm}")]
115
+ if len(folders) == 0:
116
+ continue
117
+
118
+ data = []
119
+ for folder in folders:
120
+ if algorithm == "untrained":
121
+ with open(os.path.join(EVAL_FOLDER, folder, "eval.json")) as f:
122
+ data.append(json.load(f))
123
+ else:
124
+ # walk all files recursively in folder
125
+ for root, dirs, files in os.walk(os.path.join(EVAL_FOLDER, folder)):
126
+ for file in files:
127
+ if file.endswith("eval.json"):
128
+ with open(os.path.join(root, file)) as f:
129
+ data.append(json.load(f))
130
+ break
131
+
132
+ pass_at_k, pass_at_k_sem = process_data(data, algorithm)
133
+
134
+ xs = np.array(list(pass_at_k.keys()))
135
+ ys = np.array([pass_at_k[k] for k in xs])
136
+ algo_to_xs[algorithm] = xs
137
+ algo_to_ys[algorithm] = ys
138
+
139
+ # Plot the current task - algorithm data
140
+ ax.plot(
141
+ xs,
142
+ ys,
143
+ color=ALGO_TO_COLOR[algorithm],
144
+ label=algorithm,
145
+ markeredgecolor=FACE_COLOR,
146
+ marker=MARKER,
147
+ linewidth=LINEWIDTH,
148
+ markersize=MARKERSIZE,
149
+ markeredgewidth=MARKEREDGEWIDTH,
150
+ alpha=1.0 if algorithm != "untrained" else 0.8,
151
+ )
152
+
153
+ # Plot the standard error in shaded bands
154
+ if algorithm != "untrained":
155
+ sems = np.array([pass_at_k_sem[k] for k in xs])
156
+ ax.fill_between(xs, ys - sems, ys + sems, alpha=0.2, color=ALGO_TO_COLOR[algorithm])
157
+
158
+ # Set y-axis limits
159
+ if task == "math":
160
+ y_min = 0.7
161
+ ax.set_ylim(top=0.95, bottom=y_min)
162
+ elif task == "gsm8k":
163
+ y_min = 0.925
164
+ ax.set_ylim(top=0.995, bottom=y_min)
165
+ elif task == "dapo-with-aime24":
166
+ y_min = 0.1
167
+ ax.set_ylim(bottom=y_min, top=0.63)
168
+
169
+ # Set x-axis limits
170
+ if LOG_AXES:
171
+ ax.set_xlim(left=2 ** (-0.2), right=2 ** (8.2))
172
+ else:
173
+ ax.set_xlim(left=-10, right=266)
174
+
175
+ # Set x-axis scale and ticks
176
+ if LOG_AXES:
177
+ ax.set_xscale("log", base=2)
178
+ x_ticks = [2**i for i in range(int(np.log2(max(xs))) + 1)]
179
+ x_tick_labels = [f"$2^{{{i}}}$" for i in range(int(np.log2(max(xs))) + 1)]
180
+ else:
181
+ # set every 64
182
+ x_ticks = [1, 32, 64, 96, 128, 160, 192, 224, 256]
183
+ x_tick_labels = ["1", "32", "64", "96", "128", "160", "192", "224", "256"]
184
+ ax.set_xticks(x_ticks, x_tick_labels)
185
+
186
+ # Set axes labels
187
+ ax.set_xlabel("k", fontsize=LABEL_FONT_SIZE)
188
+ if i == 0:
189
+ ax.set_ylabel("Pass@k", fontsize=LABEL_FONT_SIZE)
190
+
191
+ # Set title
192
+ ax.set_title(f"{TASK_TO_NICE_NAME[task]}", fontsize=TITLE_FONT_SIZE)
193
+
194
+ # Set font size for tick labels
195
+ for _label in ax.get_xticklabels():
196
+ _label.set_fontsize(TICK_LABEL_FONT_SIZE)
197
+ for _label in ax.get_yticklabels():
198
+ _label.set_fontsize(TICK_LABEL_FONT_SIZE)
199
+
200
+ # Create legend handles
201
+ legend_handles = [
202
+ Line2D(
203
+ [0],
204
+ [0],
205
+ color=ALGO_TO_COLOR[algo],
206
+ marker=MARKER,
207
+ linestyle="-",
208
+ linewidth=LINEWIDTH,
209
+ markersize=MARKERSIZE,
210
+ markeredgewidth=MARKEREDGEWIDTH,
211
+ markeredgecolor=FACE_COLOR,
212
+ label=ALGO_TO_NICE_NAME[algo],
213
+ )
214
+ for algo in ALGORITHMS
215
+ ]
216
+
217
+ # Create legend
218
+ legend = fig.legend(
219
+ handles=legend_handles,
220
+ loc="lower center",
221
+ ncol=len(ALGORITHMS),
222
+ bbox_to_anchor=(0.5, -0.07),
223
+ fontsize=LEGEND_FONT_SIZE,
224
+ )
225
+
226
+ plt.tight_layout()
227
+
228
+ os.makedirs("figures", exist_ok=True)
229
+ # Save figure
230
+ plt.savefig(
231
+ os.path.join("figures", f"rl_pass_at_k_{TASKS}_{'' if LOG_AXES else '_linear_axes'}.pdf"),
232
+ bbox_extra_artists=(legend,),
233
+ bbox_inches="tight",
234
+ )
235
+
236
+ # Close figure
237
+ plt.close()
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
ICL/DAPO/verl-recipe/rep_exp/rep_exp_trainer.py ADDED
@@ -0,0 +1,739 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023-2024 SGLang Team
3
+ # Copyright 2025 ModelBest Inc. and/or its affiliates
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ PPO Trainer with Ray-based single controller.
18
+ This trainer supports model-agnostic model initialization with HuggingFace.
19
+ """
20
+
21
+ import json
22
+ import os
23
+ import uuid
24
+ from collections import defaultdict
25
+ from copy import deepcopy
26
+ from pprint import pprint
27
+
28
+ import numpy as np
29
+ import ray
30
+ import torch
31
+ from omegaconf import OmegaConf
32
+ from tqdm import tqdm
33
+
34
+ from verl import DataProto
35
+ from verl.experimental.dataset.sampler import AbstractCurriculumSampler
36
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
37
+ from verl.single_controller.ray import RayClassWithInitArgs
38
+ from verl.single_controller.ray.base import create_colocated_worker_cls
39
+ from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
40
+ from verl.trainer.ppo.metric_utils import (
41
+ compute_throughout_metrics,
42
+ compute_timing_metrics,
43
+ )
44
+ from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
45
+ from verl.trainer.ppo.reward import compute_reward, compute_reward_async
46
+ from verl.trainer.ppo.utils import Role
47
+ from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
48
+ from verl.utils.config import omega_conf_to_dataclass
49
+ from verl.utils.debug import marked_timer
50
+ from verl.utils.metric import reduce_metrics
51
+ from verl.utils.rollout_skip import RolloutSkip
52
+
53
+ from .metric_utils import compute_data_metrics, process_validation_metrics
54
+
55
+
56
+ class RayRepExpTrainer(RayPPOTrainer):
57
+ """Distributed RepExp trainer using Ray for scalable reinforcement learning.
58
+
59
+ See RayPPOTrainer parent class for more details.
60
+ """
61
+
62
+ def _save_checkpoint(self):
63
+ super()._save_checkpoint()
64
+
65
+ # Write best metric to global steps
66
+ local_best_metric_to_global_step = os.path.join(
67
+ self.config.trainer.default_local_dir, "best_metric_to_global_step.json"
68
+ )
69
+ with open(local_best_metric_to_global_step, "w") as f:
70
+ json.dump(self.best_dev_pass_at_k_to_global_step, f)
71
+
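+ # The dumped JSON maps each tracked k to the global step with the best dev
+ # pass@k so far, e.g. {"1": 120}; that step identifies the checkpoint to pass
+ # to model_merge.sh for export.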
72
+ def _update_best_pass_at(self, val_metrics: dict[str, float], pass_at_k: int) -> bool:
73
+ """
74
+ Save checkpoint if the validation metrics are the best.
75
+
76
+ Args:
77
+ val_metrics: The validation metrics.
78
+ pass_at_k: The pass@k to use for determining whether to save the checkpoint.
79
+ """
80
+ for k in val_metrics.keys():
81
+ if k.endswith(f"reward/pass@{pass_at_k}/mean"):
82
+ if val_metrics[k] > self.best_dev_pass_at_k[pass_at_k]:
83
+ self.best_dev_pass_at_k[pass_at_k] = val_metrics[k]
84
+ self.best_dev_pass_at_k_to_global_step[pass_at_k] = self.global_steps
85
+ return True
86
+
87
+ return False
88
+
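+ # Example: a val_metrics key ending in "reward/pass@1/mean" with value 0.41,
+ # against a stored best of 0.38, updates both best-tracking dicts and returns
+ # True; any other case leaves them unchanged and returns False.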
89
+ def _validate(self):
90
+ data_source_lst = []
91
+ reward_extra_infos_dict: dict[str, list] = defaultdict(list)
92
+
93
+ # Lists to collect samples for the table
94
+ sample_inputs = []
95
+ sample_outputs = []
96
+ sample_gts = []
97
+ sample_scores = []
98
+ sample_turns = []
99
+ sample_uids = []
100
+
101
+ for test_data in tqdm(self.val_dataloader, desc="Validating ..."):
102
+ test_batch = DataProto.from_single_dict(test_data)
103
+
104
+ if "uid" not in test_batch.non_tensor_batch:
105
+ test_batch.non_tensor_batch["uid"] = np.array(
106
+ [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
107
+ )
108
+
109
+ # repeat test batch
110
+ test_batch = test_batch.repeat(
111
+ repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
112
+ )
113
+
114
+ # we only do validation on rule-based rm
115
+ if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
116
+ return {}
117
+
118
+ # Store original inputs
119
+ input_ids = test_batch.batch["input_ids"]
120
+ # TODO: Can we keep special tokens except for padding tokens?
121
+ input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
122
+ sample_inputs.extend(input_texts)
123
+ sample_uids.extend(test_batch.non_tensor_batch["uid"])
124
+
125
+ ground_truths = [
126
+ item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
127
+ ]
128
+ sample_gts.extend(ground_truths)
129
+
130
+ test_gen_batch = self._get_gen_batch(test_batch)
131
+ test_gen_batch.meta_info = {
132
+ "eos_token_id": self.tokenizer.eos_token_id,
133
+ "pad_token_id": self.tokenizer.pad_token_id,
134
+ "recompute_log_prob": False,
135
+ "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
136
+ "validate": True,
137
+ "global_steps": self.global_steps,
138
+ }
139
+ print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
140
+
141
+ # pad to be divisible by dp_size
142
+ size_divisor = (
143
+ self.actor_rollout_wg.world_size
144
+ if not self.async_rollout_mode
145
+ else self.config.actor_rollout_ref.rollout.agent.num_workers
146
+ )
147
+ test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
148
+ if not self.async_rollout_mode:
149
+ test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
150
+ else:
151
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
152
+
153
+ # unpad
154
+ test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
155
+
156
+ print("validation generation end")
157
+
158
+ # Store generated outputs
159
+ output_ids = test_output_gen_batch.batch["responses"]
160
+ output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
161
+ sample_outputs.extend(output_texts)
162
+
163
+ test_batch = test_batch.union(test_output_gen_batch)
164
+ test_batch.meta_info["validate"] = True
165
+
166
+ # evaluate using reward_function
167
+ if self.val_reward_fn is None:
168
+ raise ValueError("val_reward_fn must be provided for validation.")
169
+ result = self.val_reward_fn(test_batch, return_dict=True)
170
+ reward_tensor = result["reward_tensor"]
171
+ scores = reward_tensor.sum(-1).cpu().tolist()
172
+ sample_scores.extend(scores)
173
+
174
+ reward_extra_infos_dict["reward"].extend(scores)
175
+ if "reward_extra_info" in result:
176
+ for key, lst in result["reward_extra_info"].items():
177
+ reward_extra_infos_dict[key].extend(lst)
178
+
179
+ # collect num_turns of each prompt
180
+ if "__num_turns__" in test_batch.non_tensor_batch:
181
+ sample_turns.append(test_batch.non_tensor_batch["__num_turns__"])
182
+
183
+ data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
184
+
185
+ self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
186
+
187
+ # dump generations
188
+ val_data_dir = self.config.trainer.get("validation_data_dir", None)
189
+ if val_data_dir:
190
+ self._dump_generations(
191
+ inputs=sample_inputs,
192
+ outputs=sample_outputs,
193
+ gts=sample_gts,
194
+ scores=sample_scores,
195
+ reward_extra_infos_dict=reward_extra_infos_dict,
196
+ dump_path=val_data_dir,
197
+ )
198
+
199
+ for key_info, lst in reward_extra_infos_dict.items():
200
+ assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
201
+
202
+ data_sources = np.concatenate(data_source_lst, axis=0)
203
+
204
+ data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
205
+ metric_dict = {}
206
+ for data_source, var2metric2val in data_src2var2metric2val.items():
207
+ core_var = "acc" if "acc" in var2metric2val else "reward"
208
+ for var_name, metric2val in var2metric2val.items():
209
+ n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
210
+ for metric_name, metric_val in metric2val.items():
211
+ if (
212
+ (var_name == core_var)
213
+ and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
214
+ and (f"@{n_max}" in metric_name)
215
+ ):
216
+ metric_sec = "val-core"
217
+ else:
218
+ metric_sec = "val-aux"
219
+ pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
220
+ metric_dict[pfx] = metric_val
221
+
222
+ if len(sample_turns) > 0:
223
+ sample_turns = np.concatenate(sample_turns)
224
+ metric_dict["val-aux/num_turns/min"] = sample_turns.min()
225
+ metric_dict["val-aux/num_turns/max"] = sample_turns.max()
226
+ metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()
227
+
228
+ return metric_dict
229
+
230
+ def init_workers(self):
231
+ """Initialize distributed training workers using Ray backend.
232
+
233
+ Creates:
234
+ 1. Ray resource pools from configuration
235
+ 2. Worker groups for each role (actor, critic, etc.)
236
+ """
237
+ self.resource_pool_manager.create_resource_pool()
238
+
239
+ self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
240
+ val_only = self.config.trainer.get("val_only", False)
241
+
242
+ # create actor and rollout
243
+ actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout
244
+ if self.hybrid_engine:
245
+ resource_pool = self.resource_pool_manager.get_resource_pool(actor_role)
246
+ actor_rollout_cls = RayClassWithInitArgs(
247
+ cls=self.role_worker_mapping[actor_role],
248
+ config=self.config.actor_rollout_ref,
249
+ role=str(actor_role),
250
+ )
251
+ self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls
252
+ else:
253
+ raise NotImplementedError
254
+
255
+ # create critic
256
+ if self.use_critic and not val_only:
257
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
258
+ critic_cfg = omega_conf_to_dataclass(self.config.critic)
259
+ critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
260
+ self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
261
+
262
+ # create reference policy if needed
263
+ if self.use_reference_policy and not val_only:
264
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
265
+ ref_policy_cls = RayClassWithInitArgs(
266
+ self.role_worker_mapping[Role.RefPolicy],
267
+ config=self.config.actor_rollout_ref,
268
+ role=str(Role.RefPolicy),
269
+ )
270
+ self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
271
+
272
+ # create a reward model if reward_fn is None
273
+ if self.use_rm and not val_only:
274
+ # we create a RM here
275
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
276
+ rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
277
+ self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
278
+
279
+ # initialize WorkerGroup
280
+ # NOTE: if you want to use a different resource pool for each role, which can support different parallel sizes,
281
+ # you should not use `create_colocated_worker_cls`.
282
+ # Instead, directly pass different resource pool to different worker groups.
283
+ # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
284
+ all_wg = {}
285
+ wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
286
+ if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
287
+ wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
288
+ if OmegaConf.select(self.config.global_profiler, "steps") is not None:
289
+ wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
290
+ # Only require nsight worker options when tool is nsys
291
+ if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
292
+ assert (
293
+ OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
294
+ is not None
295
+ ), "worker_nsight_options must be set when using nsys with profile_steps"
296
+ wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
297
+ OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
298
+ )
299
+ wg_kwargs["device_name"] = self.device_name
300
+
301
+ for resource_pool, class_dict in self.resource_pool_to_cls.items():
302
+ worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
303
+ wg_dict = self.ray_worker_group_cls(
304
+ resource_pool=resource_pool,
305
+ ray_cls_with_init=worker_dict_cls,
306
+ **wg_kwargs,
307
+ )
308
+ spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
309
+ all_wg.update(spawn_wg)
310
+
311
+ if self.use_critic:
312
+ self.critic_wg = all_wg[str(Role.Critic)]
313
+ self.critic_wg.init_model()
314
+
315
+ if self.use_reference_policy and not self.ref_in_actor:
316
+ if str(Role.RefPolicy) in all_wg:
317
+ self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
318
+ self.ref_policy_wg.init_model()
319
+ else:
320
+ # Model engine: ActorRolloutRefWorker
321
+ assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}"
322
+ self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)]
323
+
324
+ self.rm_wg = None
325
+ # initialization of rm_wg will be deprecated in the future
326
+ if self.use_rm:
327
+ self.rm_wg = all_wg[str(Role.RewardModel)]
328
+ self.rm_wg.init_model()
329
+
330
+ # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
331
+ self.actor_rollout_wg = all_wg[str(actor_role)]
332
+ self.actor_rollout_wg.init_model()
333
+
334
+ # create async rollout manager and request scheduler
335
+ self.async_rollout_mode = False
336
+ if self.config.actor_rollout_ref.rollout.mode == "async":
337
+ from verl.experimental.agent_loop import AgentLoopManager
338
+
339
+ self.async_rollout_mode = True
340
+ self.async_rollout_manager = AgentLoopManager(
341
+ config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
342
+ )
343
+
344
+ def fit(self):
345
+ """
346
+ The training loop of PPO.
347
+ The driver process only needs to call the compute functions of the worker group through RPC
348
+ to construct the PPO dataflow.
349
+ The lightweight advantage computation is done on the driver process.
350
+ """
351
+ from omegaconf import OmegaConf
352
+
353
+ from .utils.tracking import Tracking
354
+
355
+ logger = Tracking(
356
+ project_name=self.config.trainer.project_name,
357
+ experiment_name=self.config.trainer.experiment_name,
358
+ default_backend=self.config.trainer.logger,
359
+ config=OmegaConf.to_container(self.config, resolve=True),
360
+ )
361
+
362
+ # global vars to track during training
363
+ self.global_steps = 0
364
+
365
+ self.best_dev_pass_at_k = {
366
+ 1: 0,
367
+ }
368
+ self.best_dev_pass_at_k_to_global_step = {
369
+ 1: 0,
370
+ }
371
+
372
+ # load checkpoint before doing anything
373
+ self._load_checkpoint()
374
+
375
+ current_epoch = self.global_steps // len(self.train_dataloader)
376
+
377
+ # perform validation before training
378
+ # currently, we only support validation using the reward_function.
379
+ if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
380
+ val_metrics = self._validate()
381
+ assert val_metrics, f"{val_metrics=}"
382
+
383
+ # Initialize the best validation metrics for pass@k before training
384
+ self._update_best_pass_at(val_metrics, 1)
385
+ val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1]
386
+
387
+ pprint(f"Initial validation metrics: {val_metrics}")
388
+ logger.log(data=val_metrics, step=self.global_steps)
389
+
390
+ if self.config.trainer.get("val_only", False):
391
+ return
392
+
393
+ if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
394
+ rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
395
+ rollout_skip.wrap_generate_sequences()
396
+
397
+ # add tqdm
398
+ progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
399
+
400
+ # we start from step 1
401
+ self.global_steps += 1
402
+ last_val_metrics = None
403
+ self.max_steps_duration = 0
404
+
405
+ prev_step_profile = False
406
+ curr_step_profile = (
407
+ self.global_steps in self.config.global_profiler.steps
408
+ if self.config.global_profiler.steps is not None
409
+ else False
410
+ )
411
+ next_step_profile = False
412
+
413
+ for epoch in range(current_epoch, self.config.trainer.total_epochs):
414
+ for batch_dict in self.train_dataloader:
415
+ metrics = {}
416
+ timing_raw = {}
417
+
418
+ with marked_timer("start_profile", timing_raw):
419
+ self._start_profiling(
420
+ not prev_step_profile and curr_step_profile
421
+ if self.config.global_profiler.profile_continuous_steps
422
+ else curr_step_profile
423
+ )
424
+ batch: DataProto = DataProto.from_single_dict(batch_dict)
425
+ batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
426
+
427
+ # add uid to batch
428
+ batch.non_tensor_batch["uid"] = np.array(
429
+ [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
430
+ )
431
+
432
+ gen_batch = self._get_gen_batch(batch)
433
+
434
+ # pass global_steps to trace
435
+ gen_batch.meta_info["global_steps"] = self.global_steps
436
+ gen_batch_output = gen_batch.repeat(
437
+ repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
438
+ )
439
+
440
+ is_last_step = self.global_steps >= self.total_training_steps
441
+ with marked_timer("step", timing_raw):
442
+ # generate a batch
443
+ with marked_timer("gen", timing_raw, color="red"):
444
+ if not self.async_rollout_mode:
445
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
446
+ else:
447
+ gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
448
+
449
+ timing_raw.update(gen_batch_output.meta_info["timing"])
450
+ gen_batch_output.meta_info.pop("timing", None)
451
+
452
+ if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
453
+ if self.reward_fn is None:
454
+ raise ValueError("A reward_fn is required for REMAX advantage estimation.")
455
+
456
+ with marked_timer("gen_max", timing_raw, color="purple"):
457
+ gen_baseline_batch = deepcopy(gen_batch)
458
+ gen_baseline_batch.meta_info["do_sample"] = False
459
+ if not self.async_rollout_mode:
460
+ gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
461
+ else:
462
+ gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
463
+ batch = batch.union(gen_baseline_output)
464
+ # compute reward model score on batch
465
+ rm_scores = None
466
+ if self.use_rm and "rm_scores" not in batch.batch.keys():
467
+ rm_scores = self.rm_wg.compute_rm_score(batch)
468
+ batch = batch.union(rm_scores)
469
+ reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
470
+ reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
471
+
472
+ keys_to_pop = set(gen_baseline_output.batch.keys())
473
+ if rm_scores is not None:
474
+ keys_to_pop.update(rm_scores.batch.keys())
475
+ batch.pop(batch_keys=list(keys_to_pop))
476
+
477
+ batch.batch["reward_baselines"] = reward_baseline_tensor
478
+
479
+ del rm_scores, gen_baseline_batch, gen_baseline_output
480
+ # repeat to align with repeated responses in rollout
481
+ batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
482
+ batch = batch.union(gen_batch_output)
483
+
484
+ if "response_mask" not in batch.batch.keys():
485
+ batch.batch["response_mask"] = compute_response_mask(batch)
486
+ # Balance the number of valid tokens across DP ranks.
487
+ # NOTE: This usually changes the order of data in the `batch`,
488
+ # which won't affect the advantage calculation (since it's based on uid),
489
+ # but might affect the loss calculation (due to the change of mini-batching).
490
+ if self.config.trainer.balance_batch:
491
+ self._balance_batch(batch, metrics=metrics)
492
+
493
+ # compute global_valid tokens
494
+ batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
495
+
496
+ with marked_timer("reward", timing_raw, color="yellow"):
497
+ # compute reward model score
498
+ if self.use_rm and "rm_scores" not in batch.batch.keys():
499
+ if self.config.reward_model.elliptical.enable:
500
+ hidden_states = self.rm_wg.compute_hidden_states(batch)
501
+ batch = batch.union(hidden_states)
502
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
503
+ else:
504
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
505
+ batch = batch.union(reward_tensor)
506
+
507
+ if self.config.reward_model.launch_reward_fn_async:
508
+ future_reward = compute_reward_async.remote(
509
+ data=batch, config=self.config, tokenizer=self.tokenizer
510
+ )
511
+ else:
512
+ reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
513
+
514
+ # Operating Mode Selection:
515
+ # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ)
516
+ # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ)
517
+ # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates
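+ # (Sketch, not part of the recipe: in bypass mode the PPO ratio is formed
+ # directly against the rollout policy, exp(logp_theta - rollout_log_probs);
+ # in decoupled mode it is exp(logp_theta - old_log_probs), with the rollout
+ # policy entering only through the correction weights computed later.)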
518
+ rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
519
+ bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
520
+ if bypass_recomputing_logprobs: # Use `rollout_log_probs`
521
+ from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction
522
+
523
+ apply_rollout_correction(
524
+ batch=batch,
525
+ rollout_corr_config=rollout_corr_config,
526
+ policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
527
+ )
528
+ else: # Recompute old_log_probs
529
+ with marked_timer("old_log_prob", timing_raw, color="blue"):
530
+ old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
531
+ entropys = old_log_prob.batch["entropys"]
532
+ response_masks = batch.batch["response_mask"]
533
+ loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
534
+ entropy_agg = agg_loss(
535
+ loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
536
+ )
537
+ old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
538
+ metrics.update(old_log_prob_metrics)
539
+ old_log_prob.batch.pop("entropys")
540
+ batch = batch.union(old_log_prob)
541
+ if "rollout_log_probs" in batch.batch.keys():
542
+ # TODO: we may want to add diff of probs too.
543
+ from verl.utils.debug.metrics import calculate_debug_metrics
544
+
545
+ metrics.update(calculate_debug_metrics(batch))
546
+
547
+ assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'
548
+
549
+ if self.use_reference_policy:
550
+ # compute reference log_prob
551
+ with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
552
+ if not self.ref_in_actor:
553
+ ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
554
+ else:
555
+ ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
556
+ batch = batch.union(ref_log_prob)
557
+
558
+ # compute values
559
+ if self.use_critic:
560
+ with marked_timer("values", timing_raw, color="cyan"):
561
+ values = self.critic_wg.compute_values(batch)
562
+ batch = batch.union(values)
563
+
564
+ with marked_timer("adv", timing_raw, color="brown"):
565
+ # we combine with rule-based rm
566
+ reward_extra_infos_dict: dict[str, list]
567
+ if self.config.reward_model.launch_reward_fn_async:
568
+ reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
569
+ batch.batch["token_level_scores"] = reward_tensor
570
+
571
+ if reward_extra_infos_dict:
572
+ batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
573
+
574
+ # compute rewards. apply_kl_penalty if available
575
+ if self.config.algorithm.use_kl_in_reward:
576
+ batch, kl_metrics = apply_kl_penalty(
577
+ batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
578
+ )
579
+ metrics.update(kl_metrics)
580
+ else:
581
+ batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
582
+
583
+ # Compute rollout correction: IS weights, rejection sampling, and metrics
584
+ # Only runs in decoupled mode (computes once per batch using stable π_old)
585
+ # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout
586
+ if (
587
+ rollout_corr_config is not None
588
+ and "rollout_log_probs" in batch.batch
589
+ and not bypass_recomputing_logprobs # Only in decoupled mode
590
+ ):
591
+ from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch
592
+
593
+ # Compute IS weights, apply rejection sampling, compute metrics
594
+ batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config)
595
+ # IS and off-policy metrics already have rollout_corr/ prefix
596
+ metrics.update(is_metrics)
597
+
598
+ # compute advantages, executed on the driver process
599
+ norm_adv_by_std_in_grpo = self.config.algorithm.get(
600
+ "norm_adv_by_std_in_grpo", True
601
+ ) # GRPO adv normalization factor
602
+
603
+ batch = compute_advantage(
604
+ batch,
605
+ adv_estimator=self.config.algorithm.adv_estimator,
606
+ gamma=self.config.algorithm.gamma,
607
+ lam=self.config.algorithm.lam,
608
+ num_repeat=self.config.actor_rollout_ref.rollout.n,
609
+ norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
610
+ config=self.config.algorithm,
611
+ )
612
+
613
+ # update critic
614
+ if self.use_critic:
615
+ with marked_timer("update_critic", timing_raw, color="pink"):
616
+ critic_output = self.critic_wg.update_critic(batch)
617
+ critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
618
+ metrics.update(critic_output_metrics)
619
+
620
+ # implement critic warmup
621
+ if self.config.trainer.critic_warmup <= self.global_steps:
622
+ # update actor
623
+ with marked_timer("update_actor", timing_raw, color="red"):
624
+ rollout_config = self.config.actor_rollout_ref.rollout
625
+ batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
626
+ # TODO: Make "temperature" single source of truth from generation.
627
+ batch.meta_info["temperature"] = rollout_config.temperature
628
+ actor_output = self.actor_rollout_wg.update_actor(batch)
629
+ actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
630
+ metrics.update(actor_output_metrics)
631
+
632
+ # Log rollout generations if enabled
633
+ rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
634
+ if rollout_data_dir:
635
+ self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
636
+
637
+ # validate
638
+ if (
639
+ self.val_reward_fn is not None
640
+ and self.config.trainer.test_freq > 0
641
+ and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
642
+ ):
643
+ with marked_timer("testing", timing_raw, color="green"):
644
+ val_metrics: dict = self._validate()
645
+
646
+ # Update the best validation metrics for pass@k during training
647
+ self._update_best_pass_at(val_metrics, 1)
648
+ val_metrics["best/pass@1"] = self.best_dev_pass_at_k[1]
649
+
650
+ if is_last_step:
651
+ last_val_metrics = val_metrics
652
+ metrics.update(val_metrics)
653
+
654
+ # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
655
+ esi_close_to_expiration = should_save_ckpt_esi(
656
+ max_steps_duration=self.max_steps_duration,
657
+ redundant_time=self.config.trainer.esi_redundant_time,
658
+ )
659
+ # Check if the conditions for saving a checkpoint are met.
660
+ # The conditions include a mandatory condition (1) and
661
+ # one of the following optional conditions (2/3/4):
662
+ # 1. The save frequency is set to a positive value.
663
+ # 2. It's the last training step.
664
+ # 3. The current step number is a multiple of the save frequency.
665
+ # 4. The ESI (Elastic Server Instance)/training plan is close to expiration.
666
+ if self.config.trainer.save_freq > 0 and (
667
+ is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
668
+ ):
669
+ if esi_close_to_expiration:
670
+ print("Force saving checkpoint: ESI instance expiration approaching.")
671
+ with marked_timer("save_checkpoint", timing_raw, color="green"):
672
+ self._save_checkpoint()
673
+
674
+ with marked_timer("stop_profile", timing_raw):
675
+ next_step_profile = (
676
+ self.global_steps + 1 in self.config.global_profiler.steps
677
+ if self.config.global_profiler.steps is not None
678
+ else False
679
+ )
680
+ self._stop_profiling(
681
+ curr_step_profile and not next_step_profile
682
+ if self.config.global_profiler.profile_continuous_steps
683
+ else curr_step_profile
684
+ )
685
+ prev_step_profile = curr_step_profile
686
+ curr_step_profile = next_step_profile
687
+
688
+ steps_duration = timing_raw["step"]
689
+ self.max_steps_duration = max(self.max_steps_duration, steps_duration)
690
+
691
+ # training metrics
692
+ metrics.update(
693
+ {
694
+ "training/global_step": self.global_steps,
695
+ "training/epoch": epoch,
696
+ }
697
+ )
698
+ # collect metrics
699
+ metrics.update(
700
+ compute_data_metrics(
701
+ batch=batch,
702
+ use_critic=self.use_critic,
703
+ elliptical=self.config.reward_model.elliptical.enable,
704
+ )
705
+ )
706
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
707
+ # TODO: implement actual tflops and theoretical tflops
708
+ n_gpus = self.resource_pool_manager.get_n_gpus()
709
+ metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
710
+ # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
711
+
712
+ # this is experimental and may be changed/removed in the future in favor of a general-purpose one
713
+ if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
714
+ self.train_dataloader.sampler.update(batch=batch)
715
+
716
+ # TODO: make a canonical logger that supports various backend
717
+ logger.log(data=metrics, step=self.global_steps)
718
+
719
+ progress_bar.update(1)
720
+ self.global_steps += 1
721
+
722
+ if (
723
+ hasattr(self.config.actor_rollout_ref.actor, "profiler")
724
+ and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
725
+ ):
726
+ self.actor_rollout_wg.dump_memory_snapshot(
727
+ tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
728
+ )
729
+
730
+ if is_last_step:
731
+ pprint(f"Final validation metrics: {last_val_metrics}")
732
+ progress_bar.close()
733
+ return
734
+
735
+ # this is experimental and may be changed/removed in the future
736
+ # in favor of a general-purpose data buffer pool
737
+ if hasattr(self.train_dataset, "on_batch_end"):
738
+ # The dataset may be changed after each training batch
739
+ self.train_dataset.on_batch_end(batch=batch)
ICL/DAPO/verl-recipe/spin/core_algos.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023-2024 SGLang Team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
+ class AdaptiveKLController:
22
+ """
23
+ Adaptive KL controller described in the paper:
24
+ https://arxiv.org/pdf/1909.08593.pdf
25
+ """
26
+
27
+ def __init__(self, init_kl_coef, target_kl, horizon):
28
+ self.value = init_kl_coef
29
+ self.target = target_kl
30
+ self.horizon = horizon
31
+
32
+ def update(self, current_kl, n_steps):
33
+ target = self.target
34
+ proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2)
35
+ mult = 1 + proportional_error * n_steps / self.horizon
36
+ self.value *= mult
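+ # Worked example (illustrative numbers, not from any shipped config): with
+ # target_kl=6.0 and horizon=10000, observing current_kl=9.0 over n_steps=256
+ # gives proportional_error = clip(9/6 - 1, -0.2, 0.2) = 0.2, so the KL
+ # coefficient is scaled by 1 + 0.2 * 256 / 10000 ≈ 1.0051.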
37
+
38
+
39
+ class FixedKLController:
40
+ """Fixed KL controller."""
41
+
42
+ def __init__(self, kl_coef):
43
+ self.value = kl_coef
44
+
45
+ def update(self, current_kl, n_steps):
46
+ pass
47
+
48
+
49
+ def get_kl_controller(kl_ctrl):
50
+ if kl_ctrl.type == "fixed":
51
+ return FixedKLController(kl_coef=kl_ctrl.kl_coef)
52
+ elif kl_ctrl.type == "adaptive":
53
+ assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}"
54
+ return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon)
55
+ else:
56
+ raise NotImplementedError
57
+
58
+
59
+ def compute_onlinedpo_pref(
60
+ token_level_rewards: torch.Tensor,
61
+ response_mask: torch.Tensor,
62
+ ) -> torch.Tensor:
63
+ """
64
+ Computes preferences between pairs of sequences based on summed rewards
65
+ and returns a mask aligned with the interleaved batch.
66
+
67
+ Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...]
68
+
69
+ Args:
70
+ token_level_rewards: Tensor of shape [batch_size * 2, seq_len]
71
+ response_mask: Tensor of shape [batch_size * 2, seq_len]
72
+
73
+ Returns:
74
+ torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates
75
+ the corresponding entry is the chosen response for its pair.
76
+ Example: [True, False, False, True, ...] means for prompt 0,
77
+ response 1 was chosen; for prompt 1, response 2 was chosen.
78
+ """
79
+ # print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----")
80
+ if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0:
81
+ raise ValueError(
82
+ f"Input tensor batch dimension must be even for pair comparison, got shapes: "
83
+ f"{token_level_rewards.shape}, {response_mask.shape}"
84
+ )
85
+ if token_level_rewards.shape != response_mask.shape:
86
+ raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}")
87
+
88
+ # 1. Calculate Sequence Scores
89
+ scores = (token_level_rewards * response_mask).sum(dim=-1)
90
+ # print(f" Calculated sequence scores shape: {scores.shape}") # [batch_size * 2]
91
+
92
+ # 2. Reshape scores to group pairs: [batch_size, 2]
93
+ try:
94
+ score_pairs = scores.view(-1, 2)
95
+ except RuntimeError as e:
96
+ print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}")
97
+ raise e
98
+ print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2]
99
+
100
+ # 3. Compare scores to find which index (0 or 1) is the winner within each pair
101
+ # winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1
102
+ winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max
103
+ # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max)
104
+ # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1]
105
+ # print(f" Winner indices shape: {winner_indices.shape}") # [batch_size]
106
+ # print(f" Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}") # Counts number of 1s
107
+
108
+ # 4. Create the final [batch_size * 2] mask
109
+ num_pairs = score_pairs.shape[0]
110
+ full_batch_size = num_pairs * 2
111
+ # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1]
112
+ # full_indices = torch.arange(full_batch_size, device=scores.device)
113
+ # Create indices corresponding to the winner within each pair's original index
114
+ # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2]
115
+ # winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4]
116
+ pair_indices = torch.arange(num_pairs, device=scores.device)
117
+ winner_global_indices = (pair_indices * 2) + winner_indices
118
+
119
+ # Create boolean mask - True at the winner's position
120
+ output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device)
121
+ output_preference_mask[winner_global_indices] = True
122
+
123
+ # print(f" Output preference mask shape: {output_preference_mask.shape}") # Should be [batch_size * 2]
124
+ # print(f" Output mask True count (Chosen): {output_preference_mask.sum().item()}") # Should be batch_size
125
+ # print(f" Output mask False count (Rejected): {(~output_preference_mask).sum().item()}") # Should be batch_size
126
+ # print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----")
127
+
128
+ return output_preference_mask
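+ # Minimal usage sketch (made-up tensors; 2 prompts x 2 interleaved responses):
+ # rewards = torch.tensor([[1.0, 1.0], [1.0, 0.0], [0.0, 0.0], [2.0, 1.0]])
+ # mask = torch.ones_like(rewards)
+ # compute_onlinedpo_pref(rewards, mask)
+ # -> tensor([True, False, False, True]): response 1 wins for prompt 0
+ # (score 2 vs 1) and response 2 wins for prompt 1 (score 0 vs 3).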
129
+
130
+
131
+ def compute_online_dpo_loss(
132
+ policy_chosen_logps: torch.Tensor,
133
+ policy_rejected_logps: torch.Tensor,
134
+ reference_chosen_logps: torch.Tensor,
135
+ reference_rejected_logps: torch.Tensor,
136
+ beta: float,
137
+ label_smoothing: float = 0.0,
138
+ loss_type: str = "sigmoid",
139
+ reference_free: bool = False,
140
+ ) -> torch.Tensor:
141
+ import torch.nn.functional as F
142
+
143
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
144
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
145
+
146
+ if reference_free:
147
+ ref_logratios = torch.zeros_like(pi_logratios)
148
+
149
+ logits = pi_logratios - ref_logratios
150
+
151
+ if loss_type == "sigmoid":
152
+ losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
153
+ elif loss_type == "ipo":
154
+ losses = (logits - 1 / (2 * beta)) ** 2
155
+ else:
156
+ raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.")
157
+
158
+ return losses.mean()
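+ # Numeric sketch (made-up log-probs): with beta=0.1,
+ # policy_chosen_logps=-10, policy_rejected_logps=-12,
+ # reference_chosen_logps=-11, reference_rejected_logps=-11,
+ # logits = (-10 + 12) - (-11 + 11) = 2.0, and the "sigmoid" loss is
+ # -logsigmoid(0.1 * 2.0) ≈ 0.598 (with label_smoothing=0).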
159
+
160
+
161
+ def get_batch_logps(
162
+ logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False
163
+ ) -> torch.FloatTensor:
164
+ """
165
+ Compute the log probabilities of the given labels under the given logits.
166
+
167
+ Args:
168
+ logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`).
169
+ Shape: (batch_size, sequence_length, vocab_size)
170
+ labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length)
171
+ average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum.
172
+
173
+ Returns:
174
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences.
175
+ """
176
+ if logits.shape[:-1] != labels.shape:
177
+ raise ValueError("Logits and labels must have the same shape[:-1]")
178
+
179
+ # Ensure labels are contiguous and on the same device as logits
180
+ labels = labels.contiguous().to(logits.device)
181
+ # Shift so that tokens < n predict n
182
+ shift_logits = logits[..., :-1, :].contiguous()
183
+ shift_labels = labels[..., 1:].contiguous()
184
+
185
+ # Calculate per token log probability
186
+ loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
187
+ per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
188
+ per_token_logps = per_token_logps.view(
189
+ shift_logits.size(0), shift_logits.size(1)
190
+ ) # Reshape back to (batch_size, seq_len-1)
191
+
192
+ # Create a mask for the labels that are not -100
193
+ loss_mask = shift_labels != -100
194
+
195
+ # Apply the mask to the per token log probabilities
196
+ masked_logps = per_token_logps * loss_mask
197
+
198
+ # Calculate the sum or average log probability per sequence
199
+ sequence_logps = masked_logps.sum(dim=-1)
200
+
201
+ if average_log_prob:
202
+ # Avoid division by zero for sequences with no valid tokens
203
+ num_valid_tokens = loss_mask.sum(dim=-1)
204
+ return sequence_logps / torch.clamp(num_valid_tokens, min=1)
205
+ else:
206
+ return sequence_logps
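+ # Shape note (illustrative): logits [B, T, V] from a causal LM together with
+ # labels [B, T] (prompt/padding positions set to -100) yield a [B] tensor of
+ # summed response log-probs, i.e. the chosen/rejected inputs expected by
+ # compute_online_dpo_loss above.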
ICL/DAPO/verl-recipe/spin/main_spin.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023-2024 SGLang Team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import hydra
19
+ import ray
20
+ from recipe.spin.spin_trainer import RaySPINTrainer
21
+ from recipe.spin.utils import validate_config
22
+
23
+ from verl.trainer.ppo.reward import get_custom_reward_fn
24
+ from verl.trainer.ppo.utils import need_reference_policy
25
+
26
+
27
+ @hydra.main(config_path="config", config_name="spin_trainer", version_base=None)
28
+ def main(config):
29
+ run_ppo(config)
30
+
31
+
32
+ def run_ppo(config) -> None:
33
+ # TODO(linjunrong.ocss884): this env var is a workaround for the SGLang conflict with Ray device
34
+ # isolation; will be resolved in the future
35
+ os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
36
+ if not ray.is_initialized():
37
+ # this is for local ray cluster
38
+ ray.init(
39
+ runtime_env={
40
+ "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
41
+ }
42
+ )
43
+
44
+ runner = TaskRunner.remote()
45
+ ray.get(runner.run.remote(config))
46
+
47
+
48
+ @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
49
+ class TaskRunner:
50
+ def run(self, config):
51
+ # print initial config
52
+ from pprint import pprint
53
+
54
+ from omegaconf import OmegaConf
55
+
56
+ from verl.utils.fs import copy_to_local
57
+
58
+ pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
59
+ OmegaConf.resolve(config)
60
+
61
+ # define worker classes
62
+ if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
63
+ assert config.critic.strategy in {"fsdp", "fsdp2"}
64
+ # from recipe.spin.fsdp_workers import ActorRolloutRefWorker
65
+ from recipe.spin.fsdp_workers import SPINRolloutRefWorker
66
+
67
+ from verl.single_controller.ray import RayWorkerGroup
68
+
69
+ ray_worker_group_cls = RayWorkerGroup
70
+
71
+ elif config.actor_rollout_ref.actor.strategy == "megatron":
72
+ assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
73
+ from verl.single_controller.ray import RayWorkerGroup
74
+
75
+ ray_worker_group_cls = RayWorkerGroup
76
+
77
+ else:
78
+ raise NotImplementedError
79
+
80
+ from recipe.spin.spin_trainer import ResourcePoolManager, Role
81
+
82
+ role_worker_mapping = {
83
+ # Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
84
+ Role.ActorRollout: ray.remote(SPINRolloutRefWorker),
85
+ # Role.Critic: ray.remote(CriticWorker),
86
+ }
87
+
88
+ global_pool_id = "global_pool"
89
+ resource_pool_spec = {
90
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
91
+ }
92
+ mapping = {
93
+ Role.ActorRollout: global_pool_id,
94
+ # Role.Critic: global_pool_id,
95
+ }
96
+
97
+ if config.reward_model.enable:
98
+ if config.reward_model.strategy in {"fsdp", "fsdp2"}:
99
+ from recipe.spin.fsdp_workers import RewardModelWorker
100
+ elif config.reward_model.strategy == "megatron":
101
+ from verl.workers.megatron_workers import RewardModelWorker
102
+ else:
103
+ raise NotImplementedError
104
+ role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
105
+ mapping[Role.RewardModel] = global_pool_id
106
+
107
+ # use reference model
108
+ # if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
109
+ # role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
110
+ role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker)
111
+ mapping[Role.RefPolicy] = global_pool_id
112
+
113
+ # validate config
114
+ validate_config(
115
+ config=config,
116
+ use_reference_policy=need_reference_policy(role_worker_mapping),
117
+ use_critic=False,
118
+ )
119
+
120
+ # download the checkpoint from hdfs
121
+ local_path = copy_to_local(config.actor_rollout_ref.model.path)
122
+
123
+ # instantiate tokenizer
124
+ from verl.utils import hf_processor, hf_tokenizer
125
+
126
+ trust_remote_code = config.data.get("trust_remote_code", False)
127
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
128
+ processor = hf_processor(local_path, use_fast=True) # used for multimodal LLMs; may be None
129
+
130
+ from verl.workers.reward_manager import get_reward_manager_cls
131
+
132
+ # Note(haibin.lin): please make sure custom reward managers are imported and
133
+ # registered via `verl.workers.reward_manager.register`
134
+ reward_manager_name = config.reward_model.get("reward_manager", "naive")
135
+ reward_manager_cls = get_reward_manager_cls(reward_manager_name)
136
+
137
+ compute_score = get_custom_reward_fn(config)
138
+ reward_kwargs = dict(config.reward_model.get("reward_kwargs", {}))
139
+ reward_fn = reward_manager_cls(
140
+ tokenizer=tokenizer,
141
+ num_examine=0,
142
+ compute_score=compute_score,
143
+ reward_fn_key=config.data.reward_fn_key,
144
+ **reward_kwargs,
145
+ )
146
+
147
+ # Note that we always use function-based RM for validation
148
+ val_reward_fn = reward_manager_cls(
149
+ tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key
150
+ )
151
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
152
+
153
+ trainer = RaySPINTrainer(
154
+ config=config,
155
+ tokenizer=tokenizer,
156
+ processor=processor,
157
+ role_worker_mapping=role_worker_mapping,
158
+ resource_pool_manager=resource_pool_manager,
159
+ ray_worker_group_cls=ray_worker_group_cls,
160
+ reward_fn=reward_fn,
161
+ val_reward_fn=val_reward_fn,
162
+ )
163
+ trainer.init_workers()
164
+ trainer.fit_dpo()
165
+
166
+
167
+ if __name__ == "__main__":
168
+ main()
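+ # Launch sketch (hypothetical paths and overrides; standard Hydra-style CLI
+ # for the @hydra.main entry point above):
+ #   python -m recipe.spin.main_spin \
+ #     data.train_files=/path/to/train.parquet \
+ #     data.val_files=/path/to/val.parquet \
+ #     actor_rollout_ref.model.path=/path/to/model \
+ #     trainer.n_gpus_per_node=8 trainer.nnodes=1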
ICL/DAPO/verl-recipe/spin/spin_trainer.py ADDED
@@ -0,0 +1,1312 @@
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023-2024 SGLang Team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import traceback
18
+ import uuid
19
+ from collections import defaultdict
20
+ from contextlib import contextmanager
21
+ from dataclasses import dataclass, field
22
+ from pprint import pprint
23
+ from typing import Any, Optional
24
+
25
+ import numpy as np
26
+ import ray
27
+ import torch
28
+ from codetiming import Timer
29
+ from omegaconf import OmegaConf, open_dict
30
+ from recipe.spin import core_algos
31
+ from torch.utils.data import Dataset, Sampler
32
+ from torchdata.stateful_dataloader import StatefulDataLoader
33
+ from tqdm import tqdm
34
+
35
+ from verl import DataProto
36
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
37
+ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
38
+ from verl.single_controller.ray.base import create_colocated_worker_cls
39
+ from verl.trainer.ppo.metric_utils import compute_throughout_metrics, compute_timing_metrics, process_validation_metrics
40
+ from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
41
+ from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
42
+ from verl.utils.metric import reduce_metrics
43
+ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
44
+ from verl.utils.torch_functional import masked_mean
45
+ from verl.utils.tracking import ValidationGenerationsLogger
46
+
47
+
48
+ @dataclass
49
+ class ResourcePoolManager:
50
+ """
51
+ Define a resource pool specification. Resource pool will be initialized first.
52
+ Mapping
53
+ """
54
+
55
+ resource_pool_spec: dict[str, list[int]]
56
+ mapping: dict[Role, str]
57
+ resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
58
+
59
+ def create_resource_pool(self):
60
+ for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
61
+ # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool
62
+ # For the FSDP backend, we recommend max_colocate_count=1, which merges all WorkerGroups into one.
63
+ # For the Megatron backend, we recommend max_colocate_count>1, which can use a different
64
+ # WorkerGroup for each model.
65
+ resource_pool = RayResourcePool(
66
+ process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name
67
+ )
68
+ self.resource_pool_dict[resource_pool_name] = resource_pool
69
+
70
+ self._check_resource_available()
71
+
72
+ def get_resource_pool(self, role: Role) -> RayResourcePool:
73
+ """Get the resource pool of the worker_cls"""
74
+ return self.resource_pool_dict[self.mapping[role]]
75
+
76
+ def get_n_gpus(self) -> int:
77
+ """Get the number of gpus in this cluster."""
78
+ return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])
79
+
80
+ def _check_resource_available(self):
81
+ """Check if the resource pool can be satisfied in this ray cluster."""
82
+ node_available_resources = ray._private.state.available_resources_per_node()
83
+ node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}
84
+
85
+ # check total required gpus can be satisfied
86
+ total_available_gpus = sum(node_available_gpus.values())
87
+ total_required_gpus = sum(
88
+ [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
89
+ )
90
+ if total_available_gpus < total_required_gpus:
91
+ raise ValueError(
92
+ f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}"
93
+ )
94
+
95
+ # check each resource pool can be satisfied, O(#resource_pools * #nodes)
96
+ for resource_pool_name, process_on_nodes in self.resource_pool_spec.items():
97
+ num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes)
98
+ for node, available_gpus in node_available_gpus.items():
99
+ if available_gpus >= num_gpus:
100
+ node_available_gpus[node] -= num_gpus
101
+ num_nodes -= 1
102
+ if num_nodes == 0:
103
+ break
104
+ if num_nodes > 0:
105
+ raise ValueError(
106
+ f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this "
107
+ f"ray cluster"
108
+ )
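+ # Usage sketch (illustrative names, not the recipe's actual config): one pool
+ # of two 8-GPU nodes shared by all roles:
+ # spec = {"global_pool": [8, 8]}
+ # mapping = {Role.ActorRollout: "global_pool", Role.RefPolicy: "global_pool"}
+ # mgr = ResourcePoolManager(resource_pool_spec=spec, mapping=mapping)
+ # mgr.create_resource_pool()
+ # pool = mgr.get_resource_pool(Role.ActorRollout)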
109
+
110
+
111
+ def _compute_response_info(batch: DataProto) -> dict[str, Any]:
112
+ """Placeholder: Computes prompt and response lengths."""
113
+ try:
114
+ # Assuming 'prompts' and 'responses' keys exist after generation/union
115
+ prompt_len = batch.batch["prompts"].shape[1]
116
+ resp_len = batch.batch["responses"].shape[1]
117
+ # This is simplified - real implementation might use attention masks
118
+ # to get actual lengths per sample.
119
+ batch_size = batch.batch.batch_size[0]
120
+ prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device)
121
+ response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device)
122
+
123
+ # Try getting actual lengths from attention mask if possible (more accurate)
124
+ if "response_mask" in batch.batch:
125
+ response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float()
126
+ # if "attention_mask" in batch.batch and "response_mask" in batch.batch:
127
+ # full_mask = batch.batch["attention_mask"]
128
+ # resp_mask = batch.batch["response_mask"]
129
+ # Infer prompt mask length based on where response mask starts or total length
130
+ # This logic depends heavily on how your masks are constructed.
131
+ # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor
132
+ # Fallback to using prompt shape if mask logic is complex:
133
+ prompt_lengths_tensor = torch.tensor(
134
+ [batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device
135
+ )
136
+
137
+ return {
138
+ "prompt_length": prompt_lengths_tensor,
139
+ "response_length": response_lengths_tensor,
140
+ "max_response_length": resp_len,
141
+ "max_prompt_length": prompt_len, # Or from config if fixed padding
142
+ }
143
+ except KeyError as e:
144
+ print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.")
145
+ # Return default/dummy values if keys are missing
146
+ b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1
147
+ max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0
148
+ max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0
149
+ return {
150
+ "prompt_length": torch.zeros(b_size),
151
+ "response_length": torch.zeros(b_size),
152
+ "max_response_length": max_resp,
153
+ "max_prompt_length": max_prompt,
154
+ }
155
+
156
+
157
+ # --- Modified Metric Function ---
158
+ def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]:
159
+ """
160
+ Computes and returns metrics relevant for the DPO-like process.
161
+ Assumes 'batch' contains results after generation and preference marking,
162
+ potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc.
163
+ Removes PPO-specific advantage/return/critic metrics.
164
+ """
165
+ print("---- [DEBUG] Computing DPO Data Metrics ----")
166
+ metrics = {}
167
+ try:
168
+ # --- Scores and Rewards (from reward_fn) ---
169
+ if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None:
170
+ sequence_score = batch.batch["token_level_scores"].sum(-1)
171
+ metrics.update(
172
+ {
173
+ "reward/score/mean": torch.mean(sequence_score).item(),
174
+ "reward/score/max": torch.max(sequence_score).item(),
175
+ "reward/score/min": torch.min(sequence_score).item(),
176
+ }
177
+ )
178
+ else:
179
+ print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.")
180
+
181
+ if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None:
182
+ sequence_reward = batch.batch["token_level_rewards"].sum(-1)
183
+ metrics.update(
184
+ {
185
+ "reward/rewards/mean": torch.mean(sequence_reward).item(),
186
+ "reward/rewards/max": torch.max(sequence_reward).item(),
187
+ "reward/rewards/min": torch.min(sequence_reward).item(),
188
+ }
189
+ )
190
+ else:
191
+ print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.")
192
+
193
+ # --- DPO Specific Metrics (if stored previously) ---
194
+ if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None:
195
+ metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item()
196
+ else:
197
+ print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.")
198
+
199
+ if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None:
200
+ metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item()
201
+ else:
202
+ print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.")
203
+
204
+ if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None:
205
+ metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item()
206
+ else:
207
+ print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.")
208
+
209
+ # Add metrics based on the 'preferences' mask if available
210
+ # if "preferences" in batch.batch and batch.batch["preferences"] is not None:
211
+ # prefs_mask = batch.batch["preferences"] # Shape [batch_size * n]
212
+ # Calculate accuracy based on RM scores (assuming higher score -> True in mask)
213
+ # Requires chosen/rejected scores to be available or recalculated
214
+ # This is complex here, better calculated in the main loop or update function
215
+
216
+ # --- Length Metrics ---
217
+ response_info = _compute_response_info(batch)
218
+ prompt_length = response_info["prompt_length"]
219
+ response_length = response_info["response_length"]
220
+ max_response_length = response_info["max_response_length"]
221
+ max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config
222
+
223
+ metrics.update(
224
+ {
225
+ "response_length/mean": torch.mean(response_length).item(),
226
+ "response_length/max": torch.max(response_length).item(),
227
+ "response_length/min": torch.min(response_length).item(),
228
+ "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(),
229
+ "prompt_length/mean": torch.mean(prompt_length).item(),
230
+ "prompt_length/max": torch.max(prompt_length).item(),
231
+ "prompt_length/min": torch.min(prompt_length).item(),
232
+ # Prompt clip ratio might need adjustment based on how max_prompt_length is defined
233
+ "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(),
234
+ }
235
+ )
236
+
237
+ except KeyError as e:
238
+ print(f"ERROR in compute_dpo_data_metrics: Missing key {e}")
239
+ except Exception as e:
240
+ print(f"ERROR in compute_dpo_data_metrics: {e}")
241
+ traceback.print_exc()
242
+
243
+ print(f"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----")
244
+ return metrics
245
+
246
+
247
+ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"):
248
+ responses = data.batch["responses"]
249
+ response_length = responses.size(1)
250
+ token_level_scores = data.batch["token_level_scores"]
251
+ batch_size = data.batch.batch_size[0]
252
+ attention_mask = data.batch["attention_mask"]
253
+ response_mask = attention_mask[:, -response_length:]
254
+
255
+ # compute kl between ref_policy and current policy
256
+ # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled.
257
+ kld = core_algos.kl_penalty(
258
+ data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty
259
+ ) # (batch_size, response_length)
260
+ kld = kld * response_mask
261
+ beta = kl_ctrl.value
262
+
263
+ token_level_rewards = token_level_scores - beta * kld
264
+
265
+ current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence
266
+ current_kl = torch.mean(current_kl, dim=0).item()
267
+
268
+ # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837
269
+ kl_ctrl.update(current_kl=current_kl, n_steps=batch_size)
270
+ data.batch["token_level_rewards"] = token_level_rewards
271
+
272
+ metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta}
273
+
274
+ return data, metrics
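+ # Math note: per response token t, reward_t = score_t - beta * KL_t, with
+ # KL_t masked to response positions; e.g. beta=0.05 and KL_t=0.4 subtracts
+ # 0.02 from that token's score before preference computation.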
275
+
276
+
277
+ def compute_response_mask(data: DataProto):
278
+ responses = data.batch["responses"]
279
+ response_length = responses.size(1)
280
+ attention_mask = data.batch["attention_mask"]
281
+ return attention_mask[:, -response_length:]
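+ # e.g. with attention_mask of shape [B, prompt_len + response_len], this
+ # returns the last response_len columns, which mark the valid (non-padding)
+ # response tokens.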
282
+
283
+
284
+ def compute_onlineDPO_pref(data: DataProto):
285
+ """
286
+ Wrapper to compute DPO preference and add it to the DataProto batch.
287
+ Includes debugging prints.
288
+ """
289
+ # print(f"\n---- [DEBUG] Entering compute_onlineDPO_pref ----")
290
+ # print(f" Input batch keys: {list(data.batch.keys())}")
291
+
292
+ # Check inputs
293
+ rewards_tensor = data.batch.get("token_level_rewards")
294
+ mask_tensor = data.batch.get("response_mask")
295
+
296
+ if rewards_tensor is None or mask_tensor is None:
297
+ print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!")
298
+ # Handle error case - maybe return original data or raise?
299
+ # Returning original data for now to potentially allow skipping
300
+ return data
301
+
302
+ try:
303
+ preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor)
304
+ # Store the result
305
+ data.batch["preferences"] = preferences
306
+
307
+ except AttributeError:
308
+ print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!")
309
+ # Assign dummy value or raise error
310
+ data.batch["preferences"] = None # Indicate failure
311
+ except Exception as e_pref:
312
+ print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}")
313
+ import traceback
314
+
315
+ traceback.print_exc()
316
+ data.batch["preferences"] = None # Indicate failure
317
+
318
+ # print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----")
319
+ return data
320
+
321
+
322
+ @contextmanager
323
+ def _timer(name: str, timing_raw: dict[str, float]):
324
+ with Timer(name=name, logger=None) as timer:
325
+ yield
326
+ timing_raw[name] = timer.last
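+ # Usage sketch:
+ # timing_raw = {}
+ # with _timer("gen", timing_raw):
+ #     ...  # timed region
+ # timing_raw["gen"] now holds the elapsed seconds (codetiming Timer.last).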
327
+
328
+
329
+ class RaySPINTrainer:
330
+ """
331
+ Note that this trainer runs on the driver process on a single CPU/GPU node.
332
+ """
333
+
334
+ # TODO: support each role have individual ray_worker_group_cls,
335
+ # i.e., support different backend of different role
336
+ def __init__(
337
+ self,
338
+ config,
339
+ tokenizer,
340
+ role_worker_mapping: dict[Role, WorkerType],
341
+ resource_pool_manager: ResourcePoolManager,
342
+ ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
343
+ processor=None,
344
+ reward_fn=None,
345
+ val_reward_fn=None,
346
+ train_dataset: Optional[Dataset] = None,
347
+ val_dataset: Optional[Dataset] = None,
348
+ collate_fn=None,
349
+ train_sampler: Optional[Sampler] = None,
350
+ device_name=None,
351
+ ):
352
+ # assert get_torch_device().is_available(), 'cuda must be available on driver'
353
+
354
+ self.tokenizer = tokenizer
355
+ self.processor = processor
356
+ self.config = config
357
+ self.reward_fn = reward_fn
358
+ self.val_reward_fn = val_reward_fn
359
+
360
+ self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
361
+ assert self.hybrid_engine, "Currently, only the hybrid engine is supported"
362
+
363
+ if self.hybrid_engine:
364
+ assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
365
+
366
+ self.role_worker_mapping = role_worker_mapping
367
+ self.resource_pool_manager = resource_pool_manager
368
+ self.use_reference_policy = need_reference_policy(role_worker_mapping)
369
+ self.use_rm = need_reward_model(role_worker_mapping)
370
+ self.use_critic = False
371
+ self.ray_worker_group_cls = ray_worker_group_cls
372
+ self.validation_generations_logger = ValidationGenerationsLogger()
373
+ self.async_rollout_mode = False
374
+ self.device_name = device_name if device_name else self.config.trainer.device
375
+
376
+ # define in-reward KL control
377
+ # KL loss control is currently not supported
378
+ if config.algorithm.use_kl_in_reward:
379
+ self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
380
+
381
+ self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
382
+
383
+ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
384
+ """
385
+ Creates the train and validation dataloaders.
386
+ """
387
+ # TODO: we have to make sure the batch size is divisible by the dp size
388
+ from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
389
+
390
+ if train_dataset is None:
391
+ train_dataset = create_rl_dataset(
392
+ self.config.data.train_files,
393
+ self.config.data,
394
+ self.tokenizer,
395
+ self.processor,
396
+ max_samples=self.config.data.get("train_max_samples", -1),
397
+ )
398
+ if val_dataset is None:
399
+ val_dataset = create_rl_dataset(
400
+ self.config.data.val_files,
401
+ self.config.data,
402
+ self.tokenizer,
403
+ self.processor,
404
+ max_samples=self.config.data.get("val_max_samples", -1),
405
+ )
406
+ self.train_dataset, self.val_dataset = train_dataset, val_dataset
407
+
408
+ if train_sampler is None:
409
+ train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
410
+ if collate_fn is None:
411
+ from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn
412
+
413
+ collate_fn = default_collate_fn
414
+
415
+ self.train_dataloader = StatefulDataLoader(
416
+ dataset=self.train_dataset,
417
+ batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
418
+ num_workers=self.config.data.get("dataloader_num_workers", 8),
419
+ drop_last=True,
420
+ collate_fn=collate_fn,
421
+ sampler=train_sampler,
422
+ )
423
+
424
+ val_batch_size = self.config.data.val_batch_size # Prefer config value if set
425
+ if val_batch_size is None:
426
+ val_batch_size = len(self.val_dataset)
427
+
428
+ self.val_dataloader = StatefulDataLoader(
429
+ dataset=self.val_dataset,
430
+ batch_size=val_batch_size,
431
+ num_workers=self.config.data.get("dataloader_num_workers", 8),
432
+ shuffle=False,
433
+ drop_last=False,
434
+ collate_fn=collate_fn,
435
+ )
436
+
437
+ assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
438
+ assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"
439
+
440
+ print(
441
+ f"Size of train dataloader: {len(self.train_dataloader)}, "
442
+ f"Size of val dataloader: {len(self.val_dataloader)}"
443
+ )
444
+
445
+ total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
446
+
447
+ if self.config.trainer.total_training_steps is not None:
448
+ total_training_steps = self.config.trainer.total_training_steps
449
+
450
+ self.total_training_steps = total_training_steps
451
+ print(f"Total training steps: {self.total_training_steps}")
452
+
453
+ try:
454
+ OmegaConf.set_struct(self.config, True)
455
+ with open_dict(self.config):
456
+ if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
457
+ self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
458
+ if OmegaConf.select(self.config, "critic.optim"):
459
+ self.config.critic.optim.total_training_steps = total_training_steps
460
+ except Exception as e:
461
+ print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")
462
+
463
+ def _maybe_log_val_generations(self, inputs, outputs, scores):
464
+ """Log a table of validation samples to the configured logger (wandb or swanlab)"""
465
+
466
+ generations_to_log = self.config.trainer.log_val_generations
467
+
468
+ if generations_to_log == 0:
469
+ return
470
+
471
+ import numpy as np
472
+
473
+ # Create tuples of (input, output, score) and sort by input text
474
+ samples = list(zip(inputs, outputs, scores, strict=True))
475
+ samples.sort(key=lambda x: x[0]) # Sort by input text
476
+
477
+ # Use fixed random seed for deterministic shuffling
478
+ rng = np.random.RandomState(42)
479
+ rng.shuffle(samples)
480
+
481
+ # Take first N samples after shuffling
482
+ samples = samples[:generations_to_log]
483
+
484
+ # Log to each configured logger
485
+ self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)
486
+
487
+ def _validate(self):
488
+ data_source_lst = []
489
+ reward_extra_infos_dict: dict[str, list] = defaultdict(list)
490
+
491
+ # Lists to collect samples for the table
492
+ sample_inputs = []
493
+ sample_outputs = []
494
+ sample_scores = []
495
+
496
+ for test_data in self.val_dataloader:
497
+ test_batch = DataProto.from_single_dict(test_data)
498
+
499
+ # repeat test batch
500
+ test_batch = test_batch.repeat(
501
+ repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
502
+ )
503
+
504
+ # we only do validation on rule-based rm
505
+ if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model":
506
+ return {}
507
+
508
+ # Store original inputs
509
+ input_ids = test_batch.batch["input_ids"]
510
+ # TODO: Can we keep special tokens except for padding tokens?
511
+ input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
512
+ sample_inputs.extend(input_texts)
513
+
514
+ batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
515
+ non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
516
+ if "multi_modal_inputs" in test_batch.non_tensor_batch:
517
+ non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"])
518
+ if "raw_prompt" in test_batch.non_tensor_batch:
519
+ non_tensor_batch_keys_to_pop.append("raw_prompt")
520
+ if "tools_kwargs" in test_batch.non_tensor_batch:
521
+ non_tensor_batch_keys_to_pop.append("tools_kwargs")
522
+ test_gen_batch = test_batch.pop(
523
+ batch_keys=batch_keys_to_pop,
524
+ non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
525
+ )
526
+
527
+ test_gen_batch.meta_info = {
528
+ "eos_token_id": self.tokenizer.eos_token_id,
529
+ "pad_token_id": self.tokenizer.pad_token_id,
530
+ "recompute_log_prob": False,
531
+ "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample,
532
+ "validate": True,
533
+ }
534
+ print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")
535
+
536
+ # pad to be divisible by dp_size
537
+ test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size)
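+ # e.g. with world_size=8 and 30 prompts, 2 padding rows are appended (pad_size=2)
+ # so every dp rank receives an equal slice; unpad_dataproto strips them again below.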
538
+ if not self.async_rollout_mode:
539
+ test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
540
+ else:
541
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
542
+
543
+ # unpad
544
+ test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
545
+ print("validation generation end")
546
+
547
+ # Store generated outputs
548
+ output_ids = test_output_gen_batch.batch["responses"]
549
+ output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
550
+ sample_outputs.extend(output_texts)
551
+
552
+ test_batch = test_batch.union(test_output_gen_batch)
553
+
554
+ # evaluate using reward_function
555
+ result = self.val_reward_fn(test_batch, return_dict=True)
556
+ reward_tensor = result["reward_tensor"]
557
+ scores = reward_tensor.sum(-1).cpu().tolist()
558
+ sample_scores.extend(scores)
559
+
560
+ reward_extra_infos_dict["reward"].extend(scores)
561
+ if "reward_extra_info" in result:
562
+ for key, lst in result["reward_extra_info"].items():
563
+ reward_extra_infos_dict[key].extend(lst)
564
+
565
+ data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0]))
566
+
567
+ self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
568
+
569
+ # dump generations
570
+ val_data_dir = self.config.trainer.get("validation_data_dir", None)
571
+ if val_data_dir:
572
+ sample_gts = [
573
+ item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
574
+ ]
575
+ self._dump_generations(
576
+ inputs=sample_inputs,
577
+ outputs=sample_outputs,
578
+ gts=sample_gts,
579
+ scores=sample_scores,
580
+ reward_extra_infos_dict=reward_extra_infos_dict,
581
+ dump_path=val_data_dir,
582
+ )
583
+
584
+ for key_info, lst in reward_extra_infos_dict.items():
585
+ assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"
586
+
587
+ data_sources = np.concatenate(data_source_lst, axis=0)
588
+ print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print
589
+ print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print
590
+
591
+ data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict)
592
+ print(
593
+ f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}"
594
+ )
595
+ metric_dict = {}
596
+ for data_source, var2metric2val in data_src2var2metric2val.items():
597
+ core_var = "acc" if "acc" in var2metric2val else "reward"
598
+ for var_name, metric2val in var2metric2val.items():
599
+ n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()])
600
+ for metric_name, metric_val in metric2val.items():
601
+ if (
602
+ (var_name == core_var)
603
+ and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"])
604
+ and (f"@{n_max}" in metric_name)
605
+ ):
606
+ metric_sec = "val-core"
607
+ else:
608
+ metric_sec = "val-aux"
609
+ pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
610
+ metric_dict[pfx] = metric_val
611
+
612
+ return metric_dict
613
+
614
+ def init_workers(self):
615
+ """Init resource pool and worker group"""
616
+ self.resource_pool_manager.create_resource_pool()
617
+
618
+ self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
619
+
620
+ # create actor and rollout
621
+ if self.hybrid_engine:
622
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
623
+ actor_rollout_cls = RayClassWithInitArgs(
624
+ cls=self.role_worker_mapping[Role.ActorRollout],
625
+ config=self.config.actor_rollout_ref,
626
+ role="actor_rollout",
627
+ )
628
+ self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
629
+ else:
630
+ raise NotImplementedError
631
+
632
+ # create critic
633
+ if self.use_critic:
634
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
635
+ critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic)
636
+ self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
637
+
638
+ # create reference policy if needed
639
+ if self.use_reference_policy:
640
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
641
+ ref_policy_cls = RayClassWithInitArgs(
642
+ self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref"
643
+ )
644
+ self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
645
+
646
+ # create a reward model if reward_fn is None
647
+ if self.use_rm:
648
+ # we create a RM here
649
+ resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
650
+ rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
651
+ self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
652
+
653
+ # initialize WorkerGroup
654
+ # NOTE: if you want to use a different resource pool for each role, which can support different
655
+ # parallel size,
656
+ # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to
657
+ # different worker groups.
658
+ # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
659
+ all_wg = {}
660
+ self.wg_dicts = []
661
+ wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
662
+ if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
663
+ wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
664
+ wg_kwargs["device_name"] = self.device_name
665
+
666
+ for resource_pool, class_dict in self.resource_pool_to_cls.items():
667
+ worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
668
+ wg_dict = self.ray_worker_group_cls(
669
+ resource_pool=resource_pool,
670
+ ray_cls_with_init=worker_dict_cls,
671
+ **wg_kwargs,
672
+ )
673
+ spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
674
+ all_wg.update(spawn_wg)
675
+ # keep the reference of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699
676
+ self.wg_dicts.append(wg_dict)
677
+
678
+ if self.use_critic:
679
+ self.critic_wg = all_wg["critic"]
680
+ self.critic_wg.init_model()
681
+
682
+ if self.use_reference_policy:
683
+ self.ref_policy_wg = all_wg["ref"]
684
+ self.ref_policy_wg.init_model()
685
+
686
+ if self.use_rm:
687
+ self.rm_wg = all_wg["rm"]
688
+ self.rm_wg.init_model()
689
+
690
+ # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
691
+ self.actor_rollout_wg = all_wg["actor_rollout"]
692
+ self.actor_rollout_wg.init_model()
693
+
694
+ def _save_checkpoint(self):
695
+ # path: given_path + `/global_step_{global_steps}` + `/actor`
696
+ local_global_step_folder = os.path.join(
697
+ self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
698
+ )
699
+
700
+ print(f"local_global_step_folder: {local_global_step_folder}")
701
+ actor_local_path = os.path.join(local_global_step_folder, "actor")
702
+
703
+ actor_remote_path = (
704
+ None
705
+ if self.config.trainer.default_hdfs_dir is None
706
+ else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
707
+ )
708
+
709
+ remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
710
+ if remove_previous_ckpt_in_save:
711
+ print(
712
+ "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and "
713
+ "max_critic_ckpt_to_keep=1 instead"
714
+ )
715
+ max_actor_ckpt_to_keep = (
716
+ self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
717
+ )
718
+ max_critic_ckpt_to_keep = (
719
+ self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
720
+ )
721
+
722
+ self.actor_rollout_wg.save_checkpoint(
723
+ actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep
724
+ )
725
+
726
+ if self.use_critic:
727
+ critic_local_path = os.path.join(local_global_step_folder, "critic")
728
+ critic_remote_path = (
729
+ None
730
+ if self.config.trainer.default_hdfs_dir is None
731
+ else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
732
+ )
733
+ self.critic_wg.save_checkpoint(
734
+ critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
735
+ )
736
+
737
+ # save dataloader
738
+ dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
739
+ dataloader_state_dict = self.train_dataloader.state_dict()
740
+ torch.save(dataloader_state_dict, dataloader_local_path)
741
+
742
+ # latest checkpointed iteration tracker (for atomic usage)
743
+ local_latest_checkpointed_iteration = os.path.join(
744
+ self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
745
+ )
746
+ with open(local_latest_checkpointed_iteration, "w") as f:
747
+ f.write(str(self.global_steps))
748
+
749
+ def _load_checkpoint(self):
750
+ if self.config.trainer.resume_mode == "disable":
751
+ return 0
752
+
753
+ # load from hdfs
754
+ if self.config.trainer.default_hdfs_dir is not None:
755
+ raise NotImplementedError("load from hdfs is not implemented yet")
756
+ else:
757
+ checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path
758
+ if not os.path.isabs(checkpoint_folder):
759
+ working_dir = os.getcwd()
760
+ checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
761
+ global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest
762
+
763
+ # find global_step_folder
764
+ if self.config.trainer.resume_mode == "auto":
765
+ if global_step_folder is None:
766
+ print("Training from scratch")
767
+ return 0
768
+ else:
769
+ if self.config.trainer.resume_mode == "resume_path":
770
+ assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
771
+ assert "global_step_" in self.config.trainer.resume_from_path, (
772
+ "resume ckpt must specify the global_steps"
773
+ )
774
+ global_step_folder = self.config.trainer.resume_from_path
775
+ if not os.path.isabs(global_step_folder):
776
+ working_dir = os.getcwd()
777
+ global_step_folder = os.path.join(working_dir, global_step_folder)
778
+ print(f"Load from checkpoint folder: {global_step_folder}")
779
+ # set global step
780
+ self.global_steps = int(global_step_folder.split("global_step_")[-1])
781
+
782
+ print(f"Setting global step to {self.global_steps}")
783
+ print(f"Resuming from {global_step_folder}")
784
+
785
+ actor_path = os.path.join(global_step_folder, "actor")
786
+ critic_path = os.path.join(global_step_folder, "critic")
787
+ # load actor
788
+ self.actor_rollout_wg.load_checkpoint(
789
+ actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
790
+ )
791
+ # load critic
792
+ if self.use_critic:
793
+ self.critic_wg.load_checkpoint(
794
+ critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
795
+ )
796
+
797
+ # load dataloader,
798
+ # TODO: from remote not implemented yet
799
+ dataloader_local_path = os.path.join(global_step_folder, "data.pt")
800
+ if os.path.exists(dataloader_local_path):
801
+ dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
802
+ self.train_dataloader.load_state_dict(dataloader_state_dict)
803
+ else:
804
+ print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")
805
+
806
+ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"):
807
+ """Reorder the data on single controller such that each dp rank gets similar total tokens"""
808
+ attention_mask = batch.batch["attention_mask"]
809
+ batch_size = attention_mask.shape[0]
810
+ global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
811
+ world_size = self.actor_rollout_wg.world_size
812
+ global_partition_lst = get_seqlen_balanced_partitions(
813
+ global_seqlen_lst, k_partitions=world_size, equal_size=True
814
+ )
815
+ # reorder based on index. The data will be automatically equally partitioned by dispatch function
816
+ global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
817
+ batch.reorder(global_idx)
818
+ global_balance_stats = log_seqlen_unbalance(
819
+ seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix
820
+ )
821
+ metrics.update(global_balance_stats)
822
+
823
+ def fit_dpo(self): # Online DPO counterpart of the standard PPO fit loop
824
+ """
825
+ The training loop of Online DPO using a periodically updated reference model.
826
+ The driver process calls worker groups for computation.
827
+ Advantage computation is replaced by DPO logic.
828
+ """
829
+ import traceback # Ensure traceback is imported
830
+
831
+ from omegaconf import OmegaConf
832
+
833
+ from verl.utils.tracking import Tracking
834
+
835
+ # Initialize logger
836
+ logger = None
837
+ try:
838
+ logger = Tracking(
839
+ project_name=self.config.trainer.project_name,
840
+ experiment_name=self.config.trainer.experiment_name,
841
+ default_backend=self.config.trainer.logger,
842
+ config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False),
843
+ )
844
+ except Exception as e:
845
+ print(f"Warning: Failed to initialize logger: {e}")
846
+
847
+ self.global_steps = 0
848
+ # Load checkpoint before doing anything
849
+ loaded_step = self._load_checkpoint()
850
+ self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1
851
+ print(
852
+ f"Starting Online DPO training from global step {self.global_steps}. "
853
+ f"Total steps: {self.total_training_steps}"
854
+ )
855
+ print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}")
856
+
857
+ # Check if reference policy is configured correctly for this mode
858
+ if not self.use_reference_policy:
859
+ print(
860
+ "WARNING: 'use_reference_policy' is False. Periodic reference model update requires a "
861
+ "reference policy worker. DPO updates might fail or use incorrect logic."
862
+ )
863
+ # Consider raising an error if strict adherence is required:
864
+ # raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True "
865
+ # "and a configured reference worker.")
866
+
867
+ # Perform validation before training
868
+ if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
869
+ print("Running validation before Online DPO training...")
870
+ val_metrics = self._validate()
871
+ pprint(f"Initial validation metrics: {val_metrics}")
872
+ if logger and val_metrics:
873
+ logger.log(data=val_metrics, step=max(0, self.global_steps - 1))
874
+ if self.config.trainer.get("val_only", False):
875
+ print("Validation only mode enabled. Exiting training.")
876
+ if logger and hasattr(logger, "finish"):
877
+ logger.finish()
878
+ return
879
+
880
+ # Add tqdm progress bar
881
+ progress_bar = tqdm(
882
+ total=self.total_training_steps,
883
+ initial=self.global_steps,
884
+ desc="Online DPO Training Progress",
885
+ position=0,
886
+ leave=True,
887
+ )
888
+
889
+ last_val_metrics = None
890
+ should_stop = False
891
+
892
+ for epoch in range(self.config.trainer.total_epochs):
893
+ if should_stop:
894
+ break
895
+ print(f"--- Starting Online DPO Epoch {epoch} ---")
896
+ try:
897
+ train_iterator = iter(self.train_dataloader)
898
+ except TypeError:
899
+ print("Warning: Dataloader is not iterable.")
900
+ train_iterator = self.train_dataloader # Fallback attempt
901
+
902
+ for batch_idx, batch_dict in enumerate(train_iterator):
903
+ if self.global_steps > self.total_training_steps:
904
+ should_stop = True
905
+ break
906
+
907
+ metrics = {}
908
+ timing_raw = {}
909
+ step_timer = Timer(logger=None)
910
+ ref_log_prob_computed = False # Flag to track if ref log probs were computed
911
+
912
+ try: # Outer try-except for the whole step
913
+ step_timer.start()
914
+ with _timer("step", timing_raw):
915
+ batch: DataProto = DataProto.from_single_dict(batch_dict)
916
+ current_batch_size = batch.batch.batch_size[0]
917
+ print(
918
+ f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: "
919
+ f"{current_batch_size}"
920
+ )
921
+
922
+ # --- Reference Model Update ---
923
+ ref_update_freq = self.config.trainer.get("ref_update_freq", -1)
924
+ if (
925
+ self.use_reference_policy
926
+ and ref_update_freq > 0
927
+ and self.global_steps % ref_update_freq == 0
928
+ ):
929
+ print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...")
930
+ try:
931
+ # --- This requires careful implementation with FSDP ---
932
+ # 1. Save actor state dict (potentially to CPU memory or disk)
933
+ # This needs to be done collectively across actor worker ranks.
934
+ # The checkpoint_manager might be adaptable, or use FSDP APIs directly.
935
+ # Example placeholder using a conceptual save/load mechanism:
936
+ actor_state_path = "/tmp/actor_state_mid" # Temporary path
937
+ self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic
938
+
939
+ # 2. Load the state dict onto the reference model worker group
940
+ # This also needs collective loading on the ref worker ranks.
941
+ self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic
942
+
943
+ print(f"[Step {self.global_steps}] Reference Model Weights Updated.")
944
+ # Optionally remove the temporary state file
945
+ # os.remove(actor_state_path) # Needs rank-aware removal or shared storage
946
+
947
+ except Exception as sync_e:
948
+ print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}")
949
+ traceback.print_exc()
950
+
951
+ # Pop keys for generation
952
+ pop_batch_keys = ["input_ids", "attention_mask"]
953
+ if "position_ids" in batch.batch:
954
+ pop_batch_keys.append("position_ids")
955
+ pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else []
956
+ if "multi_modal_inputs" in batch.non_tensor_batch.keys():
957
+ pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"])
958
+ original_non_tensor_data = batch.non_tensor_batch
959
+ gen_batch = batch.pop(
960
+ batch_keys=pop_batch_keys,
961
+ non_tensor_batch_keys=pop_non_tensor_keys,
962
+ )
963
+ gen_batch = gen_batch.repeat(
964
+ repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
965
+ )
966
+ # (Add Debug prints for gen_batch if needed)
967
+
968
+ # Generate sequences (chosen/rejected pairs)
969
+ with _timer("gen", timing_raw):
970
+ try:
971
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
972
+ # (Add Debug prints for gen_batch_output if needed)
973
+ except Exception as gen_e:
974
+ print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!")
975
+ print(gen_e)
976
+ traceback.print_exc()
977
+ print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
978
+ step_timer.stop()
979
+ continue
980
+
981
+ # Combine original prompts with generated sequences
982
+ batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data
983
+ batch.non_tensor_batch["uid"] = np.array(
984
+ [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object
985
+ )
986
+ batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
987
+ batch = batch.union(gen_batch_output)
988
+ # (Add Debug prints after union if needed)
989
+
990
+ # Compute response mask (needed for ref logprob calc and DPO prep)
991
+ batch.batch["response_mask"] = compute_response_mask(batch)
992
+
993
+ if self.config.trainer.balance_batch:
994
+ self._balance_batch(batch, metrics=metrics)
995
+
996
+ batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
997
+
998
+ # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef
999
+ # fallback) ---
1000
+ # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed
1001
+ # unless used for other metrics or a fallback. Keep it for now.
1002
+ with _timer("policy_log_prob", timing_raw):
1003
+ policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch)
1004
+ batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs'
1005
+ # (Debug prints for old_log_probs)
1006
+
1007
+ # --- Compute Log Probs using the EXTERNAL Reference Model ---
1008
+ if self.use_reference_policy:
1009
+ with _timer("ref_log_prob_dpo", timing_raw):
1010
+ # print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----")
1011
+ try:
1012
+ # 'batch' contains interleaved chosen/rejected sequences
1013
+ ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob(
1014
+ batch
1015
+ ) # Returns DataProto with 'ref_log_prob'
1016
+ batch = batch.union(
1017
+ ref_log_prob_output
1018
+ ) # Adds 'ref_log_prob' key [batch_size * n, seq_len]
1019
+ ref_log_prob_computed = True # Mark success
1020
+ # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: "
1021
+ # f"{batch.batch['ref_log_prob'].shape} ----")
1022
+ except Exception as ref_e:
1023
+ print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}")
1024
+ traceback.print_exc()
1025
+ batch.batch["ref_log_prob"] = None # Mark as failed
1026
+ ref_log_prob_computed = False
1027
+ else:
1028
+ print(
1029
+ "Warning: Skipping external reference log prob calculation as use_reference_policy "
1030
+ "is False."
1031
+ )
1032
+ # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor
1033
+
1034
+ # --- Compute Rewards/Scores (used to determine preference) ---
1035
+ with _timer("reward_calc", timing_raw):
1036
+ # (Reward calculation logic using RM or reward_fn as before)
1037
+ # ... Ensure this calculates 'token_level_rewards' or similar ...
1038
+ if self.use_rm:
1039
+ reward_tensor_rm = self.rm_wg.compute_rm_score(batch)
1040
+ batch = batch.union(reward_tensor_rm) # Adds 'rm_scores'
1041
+
1042
+ reward_extra_infos_dict = {}
1043
+ try:
1044
+ if self.reward_fn is None:
1045
+ # print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! "
1046
+ # f"Using dummy rewards. ----")
1047
+ # Use rm_scores if available, otherwise zeros
1048
+ reward_tensor = batch.batch.get(
1049
+ "rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
1050
+ )
1051
+ else:
1052
+ reward_result = self.reward_fn(batch, return_dict=True)
1053
+ reward_tensor = reward_result["reward_tensor"] # Final combined reward
1054
+ reward_extra_infos_dict = reward_result.get("reward_extra_info", {})
1055
+
1056
+ except Exception:
1057
+ # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. '
1058
+ # f'Using dummy rewards. ----')
1059
+ traceback.print_exc()
1060
+ reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32)
1061
+ reward_extra_infos_dict = {}
1062
+
1063
+ # Use 'token_level_rewards' as the key for preference calculation
1064
+ batch.batch["token_level_rewards"] = reward_tensor
1065
+ if reward_extra_infos_dict:
1066
+ batch.non_tensor_batch.update(
1067
+ {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
1068
+ )
1069
+
1070
+ # --- Determine Preferences ---
1071
+ # Uses 'token_level_rewards' to determine chosen/rejected based on score
1072
+ batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key
1073
+
1074
+ # --- Prepare DPO Batch ---
1075
+ dpo_update_batch_proto = None # Initialize
1076
+ with _timer("prepare_dpo_batch", timing_raw):
1077
+ try:
1078
+ if "preferences" not in batch.batch or batch.batch["preferences"] is None:
1079
+ raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.")
1080
+
1081
+ # Check if reference log probs were computed successfully (if needed)
1082
+ if self.use_reference_policy and not ref_log_prob_computed:
1083
+ raise ValueError("Reference log probs required but failed to compute.")
1084
+
1085
+ # Check required base keys
1086
+ required_keys = ["input_ids", "attention_mask", "response_mask"]
1087
+ for rk in required_keys:
1088
+ if rk not in batch.batch or batch.batch[rk] is None:
1089
+ raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.")
1090
+
1091
+ preferences_mask = batch.batch["preferences"] # Shape [batch_size * n]
1092
+ not_preferences_mask = ~preferences_mask
1093
+
1094
+ # Gather Chosen/Rejected Base Tensors
1095
+ chosen_input_ids = batch.batch["input_ids"][preferences_mask]
1096
+ chosen_attention_mask = batch.batch["attention_mask"][preferences_mask]
1097
+ rejected_input_ids = batch.batch["input_ids"][not_preferences_mask]
1098
+ rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask]
1099
+ chosen_position_ids = (
1100
+ batch.batch.get("position_ids")[preferences_mask]
1101
+ if "position_ids" in batch.batch
1102
+ else None
1103
+ )
1104
+ rejected_position_ids = (
1105
+ batch.batch.get("position_ids")[not_preferences_mask]
1106
+ if "position_ids" in batch.batch
1107
+ else None
1108
+ )
1109
+
1110
+ # Create Labels
1111
+ print("WARNING: Creating DPO labels using configured max_prompt_length...")
1112
+ prompt_len = self.config.data.max_prompt_length
1113
+ chosen_labels = chosen_input_ids.clone()
1114
+ chosen_labels[:, :prompt_len] = -100
1115
+ rejected_labels = rejected_input_ids.clone()
1116
+ rejected_labels[:, :prompt_len] = -100
1117
+
1118
+ # Calculate and Gather Reference Log Probs (Sequence Level)
1119
+ if self.use_reference_policy:
1120
+ ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len]
1121
+ response_mask_full = batch.batch[
1122
+ "response_mask"
1123
+ ] # Response mask [bsz * n, seq_len]
1124
+ ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum(
1125
+ dim=-1
1126
+ ) # Sequence level [bsz * n]
1127
+ reference_chosen_logps = ref_sequence_logps[preferences_mask]
1128
+ reference_rejected_logps = ref_sequence_logps[not_preferences_mask]
1129
+ else:
1130
+ # If not using external ref, DPO needs ActorAsRef logic in dp_actor
1131
+ # We won't add the keys here, dp_actor will handle it (or fail if not modified)
1132
+ print(
1133
+ "Info: Not adding explicit reference logps to DPO batch "
1134
+ "(use_reference_policy=False)."
1135
+ )
1136
+ reference_chosen_logps = None # Explicitly None
1137
+ reference_rejected_logps = None
1138
+
1139
+ # Package Tensors
1140
+ dpo_tensors = {
1141
+ "chosen_input_ids": chosen_input_ids,
1142
+ "chosen_attention_mask": chosen_attention_mask,
1143
+ "chosen_labels": chosen_labels,
1144
+ "rejected_input_ids": rejected_input_ids,
1145
+ "rejected_attention_mask": rejected_attention_mask,
1146
+ "rejected_labels": rejected_labels,
1147
+ }
1148
+ # Conditionally add reference logps if computed
1149
+ if reference_chosen_logps is not None:
1150
+ dpo_tensors["reference_chosen_logps"] = reference_chosen_logps
1151
+ if reference_rejected_logps is not None:
1152
+ dpo_tensors["reference_rejected_logps"] = reference_rejected_logps
1153
+ # Add position ids if they exist
1154
+ if chosen_position_ids is not None:
1155
+ dpo_tensors["chosen_position_ids"] = chosen_position_ids
1156
+ if rejected_position_ids is not None:
1157
+ dpo_tensors["rejected_position_ids"] = rejected_position_ids
1158
+
1159
+ # Prepare Meta Info
1160
+ dpo_meta = {
1161
+ "dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1),
1162
+ "dpo_loss_type": OmegaConf.select(
1163
+ self.config.algorithm, "dpo_loss_type", default="sigmoid"
1164
+ ),
1165
+ "dpo_label_smoothing": OmegaConf.select(
1166
+ self.config.algorithm, "dpo_label_smoothing", default=0.0
1167
+ ),
1168
+ "use_reference_policy": self.use_reference_policy,
1169
+ "reference_free": not self.use_reference_policy, # False if using external ref
1170
+ "global_step": self.global_steps,
1171
+ }
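+ # For reference: with dpo_loss_type="sigmoid", the worker-side update minimizes
+ # the standard DPO objective
+ #   L = -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
+ # where pi_*/ref_* are sequence-level log-probs of the chosen/rejected
+ # responses under the policy and reference model; the exact form (including
+ # label smoothing) lives in update_actor_dpo.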
1172
+
1173
+ dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta)
1174
+ # print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----")
1175
+ # print(f" Keys: {list(dpo_update_batch_proto.batch.keys())}")
1176
+ # print(f" Meta Info: {dpo_meta}")
1177
+
1178
+ except Exception as e_prep:
1179
+ print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}")
1180
+ traceback.print_exc()
1181
+ dpo_update_batch_proto = None # Skip update on error
1182
+
1183
+ # --- Actor Update Step ---
1184
+ actor_output = None
1185
+ if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto:
1186
+ with _timer("update_actor", timing_raw):
1187
+ # Pass the batch containing reference log probs (if computed)
1188
+ # The modified update_actor_dpo expects them if reference_free=False
1189
+ actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto)
1190
+ if actor_output and "metrics" in actor_output.meta_info:
1191
+ metrics.update(reduce_metrics(actor_output.meta_info["metrics"]))
1192
+ elif dpo_update_batch_proto is None:
1193
+ print(
1194
+ f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error."
1195
+ )
1196
+
1197
+ # --- Validation and Saving ---
1198
+ test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1)
1199
+ is_last_step = self.global_steps >= self.total_training_steps
1200
+ if (
1201
+ self.val_reward_fn is not None
1202
+ and test_freq > 0
1203
+ and (is_last_step or self.global_steps % test_freq == 0)
1204
+ ):
1205
+ print(f"\nRunning DPO validation at step {self.global_steps}...")
1206
+ val_timing_raw = {}
1207
+ with _timer("testing", val_timing_raw):
1208
+ val_metrics: dict = self._validate()
1209
+ if is_last_step:
1210
+ last_val_metrics = val_metrics
1211
+ if val_metrics:
1212
+ metrics["time/validation_run"] = val_timing_raw.get("testing", 0)
1213
+ metrics.update(val_metrics)
1214
+ else:
1215
+ print("Validation skipped or returned no metrics.")
1216
+
1217
+ save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
1218
+ if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0):
1219
+ print(f"\nSaving DPO checkpoint at step {self.global_steps}...")
1220
+ with _timer("save_checkpoint", timing_raw):
1221
+ self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere)
1222
+ metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0)
1223
+
1224
+ # --- End main step timer context ---
1225
+
1226
+ # --- Metrics calculation AFTER the 'step' timer block ---
1227
+ metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics
1228
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
1229
+ n_gpus = self.resource_pool_manager.get_n_gpus()
1230
+ if "step" in timing_raw:
1231
+ metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
1232
+ else:
1233
+ print(
1234
+ f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. "
1235
+ f"Skipping throughput."
1236
+ )
1237
+
1238
+ step_timer.stop()
1239
+ metrics["time/step"] = step_timer.last
1240
+
1241
+ # Log metrics
1242
+ log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1)
1243
+ if logger and self.global_steps % log_freq == 0:
1244
+ log_payload = metrics.copy()
1245
+ # Add learning rate to log payload
1246
+ if actor_output and "actor/lr" in metrics:
1247
+ log_payload["actor/lr"] = metrics["actor/lr"]
1248
+
1249
+ print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}")
1250
+ try:
1251
+ logger.log(data=log_payload, step=self.global_steps)
1252
+ except Exception as e:
1253
+ print(f"Logging failed at step {self.global_steps}: {e}")
1254
+
1255
+ # Update progress bar
1256
+ postfix_metrics = {
1257
+ k: f"{v:.3f}" if isinstance(v, float) else v
1258
+ for k, v in metrics.items()
1259
+ if isinstance(v, int | float)
1260
+ }
1261
+ progress_bar.set_postfix(postfix_metrics)
1262
+
1263
+ except Exception as step_e:
1264
+ print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!")
1265
+ print(f"Caught Exception: {step_e}")
1266
+ traceback.print_exc()
1267
+ print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
1268
+ step_timer.stop()
1269
+ should_stop = True
1270
+ break
1271
+
1272
+ if is_last_step or should_stop:
1273
+ print(f"Stopping DPO training at step {self.global_steps}.")
1274
+ break
1275
+
1276
+ self.global_steps += 1
1277
+ progress_bar.update(1)
1278
+
1279
+ # End of epoch handling
1280
+ if hasattr(self.train_dataloader, "reset"):
1281
+ try:
1282
+ self.train_dataloader.reset()
1283
+ except Exception as e:
1284
+ print(f"Warning: Failed to reset train dataloader state: {e}")
1285
+ if should_stop:
1286
+ break
1287
+
1288
+ # --- Final cleanup and logging ---
1289
+ progress_bar.close()
1290
+ final_step = max(0, self.global_steps - 1)
1291
+ print(f"Online DPO Training finished at step {final_step}.")
1292
+ # Save final checkpoint
1293
+ save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1)
1294
+ if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0):
1295
+ print(f"Saving final DPO checkpoint at step {final_step}...")
1296
+ self._save_checkpoint()
1297
+
1298
+ # Final validation run
1299
+ if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False):
1300
+ print("Running final validation...")
1301
+ last_val_metrics = self._validate()
1302
+ if last_val_metrics and logger:
1303
+ last_val_metrics["final_validation"] = True
1304
+ try:
1305
+ logger.log(data=last_val_metrics, step=final_step)
1306
+ except Exception as e:
1307
+ print(f"[Final Val Metrics Log Error]: {e}")
1308
+
1309
+ pprint(f"Final validation metrics: {last_val_metrics}")
1310
+ if logger and hasattr(logger, "finish"):
1311
+ logger.finish()
1312
+ print("Online DPO Training Run Complete.")
ICL/LV/code/README.md ADDED
@@ -0,0 +1,66 @@
+ Unified Multi-Model VQA Codebase
+
+ Purpose
+ - This codebase is the model-agnostic layer for evaluation, data handling, and prompt construction; models plug in exclusively through "adapters".
+ - The common input is a flat OpenAI-style content sequence (image→text; demonstrations use [REQUEST]/[RESPONSE]; the query's [RESPONSE] is left empty).
+
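+ For illustration, a query with one retrieved shot renders to a flat sequence roughly like the following (content-part keys follow the OpenAI convention; the values are hypothetical):
+
+   [
+     {"type": "image", "image": "shot1.png"},
+     {"type": "text", "text": "[REQUEST] What is shown? [RESPONSE] A red bus."},
+     {"type": "image", "image": "query.png"},
+     {"type": "text", "text": "[REQUEST] What color is the car? [RESPONSE]"}
+   ]
+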
+ Layout
+ - core/
+   - prompting/openai_segments.py   flat-sequence construction and on-disk serialization
+   - datasets/m3it_reader.py        unified M3IT reading & base64 image caching
+   - metrics/metrics.py             Token-F1, BERTScore-F1, etc.
+ - eval/   model-agnostic evaluation scripts (they call the adapters)
+   - zero_shot_vqa.py / random_k_shot_vqa.py
+   - eval_textual_retriever_vqa.py / eval_visual_retriever_vqa.py / eval_multimodal_retriever_vqa.py
+   - order evaluation (shared cache + standalone metric scripts):
+     - order_eval_core.py (called internally) / _modal_order.py (called internally)
+     - eval_order_caption_bertscore.py / eval_order_caption_cider.py
+     - eval_order_classification_accuracy.py / eval_order_classification_f1.py
+     - eval_order_reasoning_accuracy.py / eval_order_reasoning_ras.py
+     - eval_order_vqa_bertscore.py / eval_order_vqa_tokenf1.py
+ - adapters/
+   - idefics2_adapter.py
+   - qwen_vl_adapter.py
+   - qwen3vl_adapter.py
+   - gemma3_adapter.py
+
+ Usage
+ - Example: zero-shot (Idefics2)
+   python3 -m core.eval.zero_shot_vqa \
+     --adapter idefics2 \
+     --model-path /path/to/idefics2-8b \
+     --dataset-root /path/to/M3IT \
+     --split test --total-samples 500 \
+     --instruction-image "C:\\Users\\you\\instruction.png" --dump-first 2
+
+ - Example: random few-shot (Qwen-VL)
+   python3 -m core.eval.random_k_shot_vqa \
+     --adapter qwen-vl \
+     --model-path /path/to/Qwen-VL \
+     --dataset-root /path/to/M3IT \
+     --split test --k-shots 3 --total-samples 500 \
+     --use-paper-instruction --instruction-image "C:\\Users\\you\\instruction.png"
+
+ - Example: modality-order evaluation (VQA Token-F1 shown here)
+   python3 -m core.eval.eval_order_vqa_tokenf1 \
+     --adapter idefics2 \
+     --model-path /path/to/idefics2-8b \
+     --dataset-root /path/to/M3IT \
+     --retriever-model-path /path/to/BridgeTower-or-CLIP \
+     --orders image-text,text-image,text-image-text \
+     --k-shots 3 --total-samples 500 --split val
+
+ Conventions
+ - The adapter interface is defined in adapters/*.py:
+   - create(model_path: str) -> Adapter
+   - Adapter.generate_from_segments(segs: List[dict], temperature: float, top_p: float, max_new_tokens: int) -> str
+   - Optional: Adapter.generate_single(image_path: str, prompt: str, ...)
+
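+ A minimal adapter sketch following this interface (model loading is elided and the generation is a placeholder, not a real backend):
+
+   # adapters/example_adapter.py -- minimal sketch of the adapter contract
+   from typing import List
+
+   class Adapter:
+       def __init__(self, model_path: str):
+           self.model_path = model_path  # load the real model/processor here
+
+       def generate_from_segments(self, segs: List[dict], temperature: float,
+                                  top_p: float, max_new_tokens: int) -> str:
+           # segs is the flat image→text sequence described above; a real
+           # adapter would feed images and text to the model and decode.
+           texts = [s["text"] for s in segs if s.get("type") == "text"]
+           return " ".join(texts)[:max_new_tokens]  # placeholder generation
+
+   def create(model_path: str) -> Adapter:
+       return Adapter(model_path)
+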
+ Notes
+ - Adapters are fully decoupled from the shared code; to support a new model, only adapters/xxx_adapter.py needs to be swapped in.
+ - Windows paths, base64 payloads, and data: URLs for images are handled automatically by prompting/openai_segments.py.
ICL/LV/code/SFT/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (8.2 kB).
 
ICL/LV/code/SFT/build_icl_eval_sharegpt.py ADDED
@@ -0,0 +1,437 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Build a prompt-only ShareGPT-style eval set for deciding <RET> vs <ANS>.
4
+
5
+ Prompt format is aligned with build_icl_dataset.py:
6
+ instruction + <image> + "Question: ...\\nAction:"
7
+
8
+ But for evaluation we keep ONLY the initial human turn in `conversations` to avoid leaking labels.
9
+ Gold labels are stored outside the prompt:
10
+ - expected_first_tag: "<RET>" or "<ANS>" (NOT included in conversations)
11
+ - answer: used for offline checking (NOT included in conversations)
12
+ - shots: for RET samples only, used for the follow-up step after model outputs <RET>
13
+
14
+ Important:
15
+ - Never use train split for eval; recommend val/test/dev.
16
+ - Optionally excludes any uid already present in an existing training jsonl to avoid overlap.
17
+ """
18
+
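+ # Shape of one emitted record (hypothetical values, shown for orientation;
+ # the category/subdir/instruction/query fields are omitted here):
+ # {
+ #   "id": "vqa/okvqa:00000042",
+ #   "images": ["data/vqa/okvqa/img_42.png"],
+ #   "conversations": [{"from": "human", "value": "...<image>\nQuestion: ...\nAction:"}],
+ #   "expected_first_tag": "<RET>",
+ #   "answer": "a red bus",
+ #   "k_shot": 2,
+ #   "shots": [{"image": "...", "description": "..."}]
+ # }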
19
+ import argparse
20
+ import json
21
+ import random
22
+ import sys
23
+ from dataclasses import dataclass
24
+ from pathlib import Path
25
+ from typing import Dict, List, Optional, Set
26
+
27
+
28
+ # Add code root to PYTHONPATH for core/ imports
29
+ CODE_ROOT = Path(__file__).resolve().parents[1]
30
+ if str(CODE_ROOT) not in sys.path:
31
+ sys.path.insert(0, str(CODE_ROOT))
32
+
33
+ from core.datasets.m3it_reader import iter_m3it_samples, load_instructions # noqa: E402
34
+
35
+
36
+ @dataclass(frozen=True)
37
+ class PoolItem:
38
+ image_path: str
39
+ description: str
40
+ subdir: str
41
+
42
+
43
+ @dataclass(frozen=True)
44
+ class QueryItem:
45
+ image_path: str
46
+ question: str
47
+ answer: str
48
+ subdir: str
49
+ uid: str
50
+
51
+
52
+ def _extract_uid(raw: Dict, fallback: str) -> str:
53
+ if isinstance(raw, dict):
54
+ for k in ("id", "image_id"):
55
+ v = raw.get(k)
56
+ if isinstance(v, (str, int)):
57
+ return str(v)
58
+ meta = raw.get("meta") if isinstance(raw.get("meta"), dict) else {}
59
+ for k in ("img_id", "id", "image_id"):
60
+ v = meta.get(k)
61
+ if isinstance(v, (str, int)):
62
+ return str(v)
63
+ return fallback
64
+
65
+
66
+ def discover_subdirs(dataset_root: Path, category: str) -> List[str]:
67
+ base = dataset_root / "data" / category
68
+ if not base.exists():
69
+ return []
70
+ out: List[str] = []
71
+ for p in sorted(base.iterdir()):
72
+ if p.is_dir():
73
+ out.append(f"{category}/{p.name}")
74
+ return out
75
+
76
+
77
+ def pick_instruction(insts: List[str], rng: random.Random) -> str:
78
+ if insts:
79
+ s = rng.choice(insts)
80
+ if isinstance(s, str) and s.strip():
81
+ return s.strip()
82
+ return "Please answer the question based on the image."
83
+
84
+
85
+ def load_exclude_uids(path: Optional[str]) -> Set[str]:
86
+ if not path:
87
+ return set()
88
+ p = Path(path)
89
+ if not p.exists():
90
+ return set()
91
+
92
+ out: Set[str] = set()
93
+ with p.open("r", encoding="utf-8") as f:
94
+ for line in f:
95
+ line = line.strip()
96
+ if not line:
97
+ continue
98
+ try:
99
+ obj = json.loads(line)
100
+ except Exception:
101
+ continue
102
+ if not isinstance(obj, dict):
103
+ continue
104
+ uid = obj.get("uid")
105
+ if isinstance(uid, (str, int)):
106
+ out.add(str(uid))
107
+ continue
108
+ sid = obj.get("id")
109
+ if isinstance(sid, (str, int)):
110
+ out.add(str(sid))
111
+ return out
112
+
113
+
114
+ def to_rel(path: str, root: Path) -> str:
115
+ try:
116
+ return str(Path(path).relative_to(root))
117
+ except Exception:
118
+ return path
119
+
120
+
121
+ def build_pool_for_subdir(
122
+ *,
123
+ dataset_root: Path,
124
+ subdir: str,
125
+ split: str,
126
+ cache_dir: Path,
127
+ target_n: int,
128
+ max_samples_scan: int,
129
+ ) -> List[PoolItem]:
130
+ items: List[PoolItem] = []
131
+ try:
132
+ iterable = iter_m3it_samples(
133
+ str(dataset_root),
134
+ subdir,
135
+ split=split,
136
+ cache_dir=str(cache_dir),
137
+ max_samples=None,
138
+ )
139
+ except FileNotFoundError:
140
+ return []
141
+
142
+ for idx, smp in enumerate(iterable):
143
+ if max_samples_scan > 0 and idx >= max_samples_scan:
144
+ break
145
+ if not smp.answers:
146
+ continue
147
+ desc = (smp.answers[0] or "").strip()
148
+ if not desc:
149
+ continue
150
+ items.append(PoolItem(smp.image_path, desc, subdir))
151
+ if target_n > 0 and len(items) >= target_n:
152
+ break
153
+ return items
154
+
155
+
156
+ def collect_query_pool(
157
+ *,
158
+ dataset_root: Path,
159
+ subdirs: List[str],
160
+ split: str,
161
+ cache_dir: Path,
162
+ exclude_uids: Set[str],
163
+ target_n: int,
164
+ seed: int,
165
+ max_samples_per_subdir: int,
166
+ ) -> List[QueryItem]:
167
+ rng = random.Random(seed)
168
+ subdirs = list(subdirs)
169
+ rng.shuffle(subdirs)
170
+
171
+ seen: Set[str] = set()
172
+ out: List[QueryItem] = []
173
+
174
+ for subdir in subdirs:
175
+ taken = 0
176
+ try:
177
+ iterable = iter_m3it_samples(
178
+ str(dataset_root),
179
+ subdir,
180
+ split=split,
181
+ cache_dir=str(cache_dir),
182
+ max_samples=None,
183
+ )
184
+ except FileNotFoundError:
185
+ continue
186
+
187
+ for i, smp in enumerate(iterable):
188
+ if max_samples_per_subdir > 0 and taken >= max_samples_per_subdir:
189
+ break
190
+ q = (smp.text or "").strip()
191
+ if not q:
192
+ continue
193
+ if not smp.answers:
194
+ continue
195
+ ans = (smp.answers[0] or "").strip()
196
+ if not ans:
197
+ continue
198
+ uid = _extract_uid(smp.raw, f"{subdir}:{i:08d}")
199
+ if uid in exclude_uids:
200
+ continue
201
+ if uid in seen:
202
+ continue
203
+ seen.add(uid)
204
+ out.append(QueryItem(smp.image_path, q, ans, subdir, uid))
205
+ taken += 1
206
+ if target_n > 0 and len(out) >= target_n:
207
+ return out
208
+ return out
209
+
210
+
211
+ def select_shots(
212
+ pool: List[PoolItem],
213
+ k: int,
214
+ rng: random.Random,
215
+ exclude_image: Optional[str] = None,
216
+ ) -> List[PoolItem]:
217
+ if not pool or k <= 0:
218
+ return []
219
+ cand = [p for p in pool if p.image_path != exclude_image]
220
+ if not cand:
221
+ cand = pool
222
+ if len(cand) >= k:
223
+ return rng.sample(cand, k=k)
224
+ return [rng.choice(cand) for _ in range(k)]
225
+
226
+
227
+ def write_jsonl(path: Path, records: List[Dict]) -> None:
228
+ path.parent.mkdir(parents=True, exist_ok=True)
229
+ with path.open("w", encoding="utf-8") as f:
230
+ for r in records:
231
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
232
+
233
+
234
+ def build_prompt_only_record(
235
+ *,
236
+ uid: str,
237
+ instruction: str,
238
+ image_rel: str,
239
+ question: str,
240
+ expected_first_tag: str,
241
+ answer: str,
242
+ category: str,
243
+ subdir: str,
244
+ shots: List[Dict],
245
+ k_shot: int,
246
+ ) -> Dict:
247
+ human = []
248
+ if instruction:
249
+ human.append(instruction.strip())
250
+ human.append("<image>")
251
+ human.append(f"Question: {question}\nAction:")
252
+ human_value = "\n".join([x for x in human if x]).strip()
253
+
254
+ return {
255
+ "id": uid,
256
+ "images": [image_rel],
257
+ "conversations": [
258
+ {"from": "human", "value": human_value},
259
+ ],
260
+ "expected_first_tag": expected_first_tag,
261
+ "answer": answer,
262
+ "k_shot": k_shot,
263
+ "shots": shots,
264
+ "category": category,
265
+ "subdir": subdir,
266
+ "instruction": instruction,
267
+ "query": {"image": image_rel, "question": question},
268
+ }
269
+
270
+
271
+ def main() -> int:
272
+ ap = argparse.ArgumentParser(description="Build prompt-only eval set (ShareGPT jsonl) for <RET>/<ANS> decision.")
273
+ ap.add_argument("--dataset-root", default="/workspace/M3IT")
274
+ ap.add_argument("--output-dir", default="/workspace/M3IT_new/ICL_eval")
275
+ ap.add_argument("--category", default="vqa")
276
+ ap.add_argument("--split", default="val", help="Never use train; recommend val/test/dev.")
277
+ ap.add_argument("--pool-split", default="val", help="Never use train; recommend val/test/dev.")
278
+ ap.add_argument("--seed", type=int, default=42)
279
+
280
+ ap.add_argument("--total", type=int, default=100)
281
+ ap.add_argument("--ret-ratio", type=float, default=0.5)
282
+ ap.add_argument("--query-pool-size", type=int, default=1000, help="How many queries to collect before sampling.")
283
+ ap.add_argument("--max-samples-per-subdir", type=int, default=2000, help="Scan cap per subdir when collecting queries.")
284
+ ap.add_argument("--pool-size-per-subdir", type=int, default=2000, help="Max pool size to build per subdir (for shots).")
285
+ ap.add_argument("--pool-scan-per-subdir", type=int, default=4000, help="Scan cap per subdir when building pools.")
286
+ ap.add_argument("--shot-k-min", type=int, default=1)
287
+ ap.add_argument("--shot-k-max", type=int, default=3)
288
+
289
+ ap.add_argument(
290
+ "--exclude-uids-from",
291
+ default="/workspace/M3IT_new/ICL/vqa/merged_shuffled_sharegpt.jsonl",
292
+ help="Optional jsonl to exclude uids/ids (to avoid overlap with training).",
293
+ )
294
+ ap.add_argument("--overwrite", action="store_true")
295
+ ap.add_argument("--output", default=None, help="Default: {output_dir}/{category}/eval_sharegpt_{total}.jsonl")
296
+ args = ap.parse_args()
297
+
298
+ if args.split.strip().lower() == "train" or args.pool_split.strip().lower() == "train":
299
+ raise ValueError("split/pool-split=train is not allowed for eval set")
300
+ if args.total <= 0:
301
+ raise ValueError("total must be > 0")
302
+ if not (0.0 <= args.ret_ratio <= 1.0):
303
+ raise ValueError("ret-ratio must be in [0, 1]")
304
+ if args.shot_k_min <= 0 or args.shot_k_max < args.shot_k_min:
305
+ raise ValueError("invalid shot-k range")
306
+
307
+ dataset_root = Path(args.dataset_root)
308
+ output_dir = Path(args.output_dir)
309
+ cache_dir = output_dir / "_image_cache"
310
+ cache_dir.mkdir(parents=True, exist_ok=True)
311
+
312
+ out_path = Path(
313
+ args.output
314
+ if args.output
315
+ else str(output_dir / args.category / f"eval_sharegpt_{args.total}.jsonl")
316
+ )
317
+ if out_path.exists() and not args.overwrite:
318
+ raise FileExistsError(f"Output exists: {out_path} (use --overwrite to replace)")
319
+
320
+ subdirs = discover_subdirs(dataset_root, args.category)
321
+ if not subdirs:
322
+ raise FileNotFoundError(f"No subdirs found under {dataset_root}/data/{args.category}")
323
+
324
+ exclude_uids = load_exclude_uids(args.exclude_uids_from)
325
+
326
+ # Load instructions once per subdir.
327
+ inst_map: Dict[str, List[str]] = {sd: load_instructions(dataset_root, sd) for sd in subdirs}
328
+
329
+ query_pool_target = max(args.total, args.query_pool_size)
330
+ queries = collect_query_pool(
331
+ dataset_root=dataset_root,
332
+ subdirs=subdirs,
333
+ split=args.split,
334
+ cache_dir=cache_dir,
335
+ exclude_uids=exclude_uids,
336
+ target_n=query_pool_target,
337
+ seed=args.seed,
338
+ max_samples_per_subdir=args.max_samples_per_subdir,
339
+ )
340
+ if len(queries) < args.total:
341
+ raise RuntimeError(f"Not enough queries collected: got {len(queries)}/{args.total}. "
342
+ f"Try increasing --max-samples-per-subdir or changing --split.")
343
+
344
+ rng = random.Random(args.seed)
345
+ ret_n = int(round(args.total * args.ret_ratio))
346
+ ret_n = max(0, min(args.total, ret_n))
347
+ ans_n = args.total - ret_n
348
+
349
+ labels = ["RET"] * ret_n + ["ANS"] * ans_n
350
+ rng.shuffle(labels)
351
+
352
+ # Pools are built lazily per subdir for RET samples only.
353
+ pool_map: Dict[str, List[PoolItem]] = {}
354
+
355
+ records: List[Dict] = []
356
+ used_uids: Set[str] = set()
357
+
358
+ for label in labels:
359
+ for _try in range(2000):
360
+ q = rng.choice(queries)
361
+ if q.uid in used_uids:
362
+ continue
363
+
364
+ inst = pick_instruction(inst_map.get(q.subdir, []), rng)
365
+ image_rel = to_rel(q.image_path, output_dir)
366
+
367
+ if label == "ANS":
368
+ used_uids.add(q.uid)
369
+ records.append(
370
+ build_prompt_only_record(
371
+ uid=q.uid,
372
+ instruction=inst,
373
+ image_rel=image_rel,
374
+ question=q.question,
375
+ expected_first_tag="<ANS>",
376
+ answer=q.answer,
377
+ category=args.category,
378
+ subdir=q.subdir,
379
+ shots=[],
380
+ k_shot=0,
381
+ )
382
+ )
383
+ break
384
+
385
+ # RET case: attach hidden shots for the follow-up step.
386
+ if q.subdir not in pool_map:
387
+ pool_map[q.subdir] = build_pool_for_subdir(
388
+ dataset_root=dataset_root,
389
+ subdir=q.subdir,
390
+ split=args.pool_split,
391
+ cache_dir=cache_dir,
392
+ target_n=args.pool_size_per_subdir,
393
+ max_samples_scan=args.pool_scan_per_subdir,
394
+ )
395
+ pool = pool_map.get(q.subdir, [])
396
+ if not pool:
397
+ continue
398
+
399
+ k = rng.randint(args.shot_k_min, args.shot_k_max)
400
+ shots_items = select_shots(pool, k, rng, exclude_image=q.image_path)
401
+ shots = [
402
+ {"image": to_rel(s.image_path, output_dir), "description": s.description}
403
+ for s in shots_items
404
+ ]
405
+ used_uids.add(q.uid)
406
+ records.append(
407
+ build_prompt_only_record(
408
+ uid=q.uid,
409
+ instruction=inst,
410
+ image_rel=image_rel,
411
+ question=q.question,
412
+ expected_first_tag="<RET>",
413
+ answer=q.answer,
414
+ category=args.category,
415
+ subdir=q.subdir,
416
+ shots=shots,
417
+ k_shot=k,
418
+ )
419
+ )
420
+ break
421
+ else:
422
+ raise RuntimeError(f"Failed to sample enough records for label={label}. "
423
+ f"Try increasing --query-pool-size or relaxing --exclude-uids-from.")
424
+
425
+ rng.shuffle(records)
426
+ write_jsonl(out_path, records)
427
+
428
+ # Lightweight summary to stdout.
429
+ ret_cnt = sum(1 for r in records if r.get("expected_first_tag") == "<RET>")
430
+ ans_cnt = len(records) - ret_cnt
431
+ print(f"[OK] wrote={len(records)} ret={ret_cnt} ans={ans_cnt} -> {out_path}")
432
+ print(f"[INFO] image_root (for eval): {output_dir}")
433
+ return 0
434
+
435
+
436
+ if __name__ == "__main__":
437
+ raise SystemExit(main())
ICL/LV/code/SFT/check_kshot_ret_ans.py ADDED
@@ -0,0 +1,319 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Check whether the model outputs <RET> or <ANS> under different shot settings.
5
+
6
+ 0-shot: only query image + question.
7
+ K-shot (K>=1): after the model outputs <RET>, append 1 shot (image+description) and
8
+ ask again. If it still outputs <RET>, append another shot, and so on.
9
+ We record whether it outputs <RET>/<ANS> at each step.
10
+ """
11
+
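+ # Illustrative trace of the K-shot protocol (hypothetical dialogue):
+ #   step 0: [query image] "Question: ...\nAction:"  -> model: "<RET>"
+ #   step 1: + [shot image] "Description: ..."       -> model: "<RET>"
+ #   step 2: + [shot image] "Description: ..."       -> model: "<ANS> a red bus"
+ # The recorded tag sequence for this sample would be ["<RET>", "<RET>", "<ANS>"].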
12
+ import argparse
13
+ import json
14
+ import os
15
+ import random
16
+ import re
17
+ from typing import Any, Dict, List, Optional, Tuple
18
+
19
+ import torch
20
+ from PIL import Image
21
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
22
+
23
+
24
+ TAG_RE = re.compile(r"(<ANS>|<RET>)")
25
+
26
+
27
+ def _extract_tag(text: str) -> Optional[str]:
28
+ match = TAG_RE.search(text)
29
+ return match.group(1) if match else None
30
+
31
+
32
+ def _resolve_path(root: str, maybe_rel: str) -> str:
33
+ if os.path.isabs(maybe_rel):
34
+ return maybe_rel
35
+ return os.path.normpath(os.path.join(root, maybe_rel))
36
+
37
+
38
+ def _split_user_text_with_images(
39
+ text: str, image_paths: List[str]
40
+ ) -> Tuple[List[Dict[str, str]], List[str]]:
41
+ parts = text.split("<image>")
42
+ content: List[Dict[str, str]] = []
43
+ used: List[str] = []
44
+ for i, part in enumerate(parts):
45
+ part = part.strip()
46
+ if part:
47
+ content.append({"type": "text", "text": part})
48
+ if i < len(parts) - 1:
49
+ if not image_paths:
50
+ raise ValueError("用户文本里 <image> 数量 > images 列表长度")
51
+ img_path = image_paths.pop(0)
52
+ used.append(img_path)
53
+ content.append({"type": "image", "image": img_path})
54
+ return content, used
55
+
56
+
57
+ def _append_human_turn(
58
+ *,
59
+ messages: List[Dict[str, Any]],
60
+ pil_images: List[Image.Image],
61
+ image_root: str,
62
+ text: str,
63
+ images_all: List[str],
64
+ image_cursor: int,
65
+ ) -> int:
66
+ n_placeholders = text.count("<image>")
67
+ img_paths = images_all[image_cursor : image_cursor + n_placeholders]
68
+ if len(img_paths) != n_placeholders:
69
+ raise ValueError("images 列表长度 < <image> 占位符数量")
70
+
71
+ user_content, used_paths = _split_user_text_with_images(text, img_paths.copy())
72
+ for p in used_paths:
73
+ p = _resolve_path(image_root, p) if not os.path.isabs(p) else p
74
+ if not os.path.exists(p):
75
+ raise FileNotFoundError(p)
76
+ with Image.open(p) as img:
77
+ pil_images.append(img.convert("RGB"))
78
+
79
+ messages.append({"role": "user", "content": user_content})
80
+ return image_cursor + n_placeholders
81
+
82
+
+ def _append_shot_turn(
+     *,
+     messages: List[Dict[str, Any]],
+     pil_images: List[Image.Image],
+     image_root: str,
+     image_path: str,
+     description: str,
+ ) -> None:
+     img_path = _resolve_path(image_root, image_path)
+     if not os.path.exists(img_path):
+         raise FileNotFoundError(img_path)
+     with Image.open(img_path) as im:
+         pil_images.append(im.convert("RGB"))
+     messages.append(
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": img_path},
+                 {"type": "text", "text": f"Description: {description}"},
+             ],
+         }
+     )
+
+
+ def _build_base_messages(obj: Dict[str, Any], image_root: str) -> Tuple[List[Dict[str, Any]], List[Image.Image]]:
+     conversations = obj.get("conversations")
+     if not isinstance(conversations, list) or not conversations:
+         raise ValueError("sample is missing conversations")
+     images_rel = obj.get("images") or []
+     if not isinstance(images_rel, list):
+         raise ValueError("images field is not a list")
+
+     messages: List[Dict[str, Any]] = []
+     pil_images: List[Image.Image] = []
+     image_cursor = 0
+
+     # Use the first human turn as query prompt.
+     human = None
+     for t in conversations:
+         if t.get("from") == "human":
+             human = t
+             break
+     if human is None:
+         raise ValueError("no human turn found")
+
+     image_cursor = _append_human_turn(
+         messages=messages,
+         pil_images=pil_images,
+         image_root=image_root,
+         text=str(human.get("value", "")),
+         images_all=images_rel,
+         image_cursor=image_cursor,
+     )
+     return messages, pil_images
+
+
+ def _pick_shots_from_pool(
+     pool: List[Dict[str, str]],
+     k: int,
+     rng: random.Random,
+     exclude_image: Optional[str],
+ ) -> List[Dict[str, str]]:
+     if k <= 0 or not pool:
+         return []
+     cand = [p for p in pool if p.get("image") != exclude_image]
+     if not cand:
+         cand = pool
+     if len(cand) >= k:
+         return rng.sample(cand, k=k)
+     return [rng.choice(cand) for _ in range(k)]
+
+
+ def main() -> int:
+     ap = argparse.ArgumentParser(description="Check <RET>/<ANS> outputs under 0/1/2/3-shot settings.")
+     ap.add_argument("--model", required=True, help="HF model dir")
+     ap.add_argument(
+         "--data",
+         default="/workspace/M3IT_new/ICL_eval/vqa/eval_sharegpt_100.jsonl",
+         help="Prompt-only eval jsonl (has conversations/images/shots).",
+     )
+     ap.add_argument("--image-root", default="/workspace/M3IT_new/ICL_eval")
+     ap.add_argument("--num-samples", type=int, default=20)
+     ap.add_argument("--seed", type=int, default=42)
+     ap.add_argument("--k-list", default="0,1,2,3", help="Comma-separated shot counts to report.")
+     ap.add_argument("--max-new-tokens", type=int, default=128)
+     ap.add_argument("--device", default="cuda:0")
+     ap.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
+     ap.add_argument("--print-samples", action="store_true", help="Print each sample input/output.")
+     args = ap.parse_args()
+
+     rng = random.Random(args.seed)
+     k_list = [int(x.strip()) for x in args.k_list.split(",") if x.strip()]
+     if not k_list:
+         raise ValueError("k-list is empty")
+     max_k = max(k_list)
+
+     # Load dataset
+     data: List[Dict[str, Any]] = []
+     with open(args.data, "r", encoding="utf-8") as f:
+         for line in f:
+             line = line.strip()
+             if not line:
+                 continue
+             data.append(json.loads(line))
+     if not data:
+         raise ValueError("empty data")
+
+     # Build global shot pool
+     pool: List[Dict[str, str]] = []
+     for obj in data:
+         shots = obj.get("shots") or []
+         if isinstance(shots, list):
+             for s in shots:
+                 if not isinstance(s, dict):
+                     continue
+                 img = s.get("image")
+                 desc = s.get("description")
+                 if isinstance(img, str) and isinstance(desc, str) and img and desc:
+                     pool.append({"image": img, "description": desc})
+     if not pool:
+         print("[WARN] shot pool is empty, k-shot tests may be skipped")
+
+     # Sample records
+     samples = rng.sample(data, k=min(args.num_samples, len(data)))
+
+     dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
+     device = torch.device(args.device)
+     processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
+     model = Qwen3VLForConditionalGeneration.from_pretrained(
+         args.model, dtype=dtype, trust_remote_code=True
+     ).to(device)
+     model.eval()
+
+     summary: Dict[int, Dict[str, int]] = {
+         k: {"RET": 0, "ANS": 0, "NONE": 0, "REACHED": 0} for k in k_list
+     }
+
+     for obj in samples:
+         uid = obj.get("id") or obj.get("uid") or "unknown"
+         query = obj.get("query") or {}
+         query_image = query.get("image")
+
+         messages, pil_images = _build_base_messages(obj, args.image_root)
+
+         # Pre-sample shots for this sample (use first N as we go).
+         shots_all = _pick_shots_from_pool(pool, max_k, rng, exclude_image=query_image)
+         if max_k > 0 and len(shots_all) < max_k:
+             continue
+
+         step = 0  # step 0 = query only
+         while True:
+             prompt = processor.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
+             )
+             inputs = processor(text=prompt, images=pil_images, padding=True, return_tensors="pt")
+             inputs = {k2: v.to(device) for k2, v in inputs.items()}
+
+             with torch.inference_mode():
+                 out_ids = model.generate(
+                     **inputs, do_sample=False, max_new_tokens=args.max_new_tokens
+                 )
+             in_len = int(inputs["input_ids"].shape[1])
+             pred = processor.batch_decode(out_ids[:, in_len:], skip_special_tokens=True)[0].strip()
+             tag = _extract_tag(pred)
+
+             # Append model output to the conversation
+             messages.append({"role": "assistant", "content": [{"type": "text", "text": pred}]})
+
+             if step in summary:
+                 summary[step]["REACHED"] += 1
+                 if tag == "<RET>":
+                     summary[step]["RET"] += 1
+                 elif tag == "<ANS>":
+                     summary[step]["ANS"] += 1
+                 else:
+                     summary[step]["NONE"] += 1
+
+             if args.print_samples:
+                 print("=" * 80)
+                 print(f"uid={uid} | step={step} | pred_tag={tag}")
+                 for m in messages:
+                     role = m.get("role")
+                     if role == "user":
+                         parts = []
+                         for c in m.get("content", []):
+                             if c.get("type") == "text":
+                                 parts.append(c.get("text", ""))
+                             elif c.get("type") == "image":
+                                 parts.append(f"<image> {c.get('image','')}")
+                         print("[input]")
+                         print("\n".join([p for p in parts if p]).strip())
+                     elif role == "assistant":
+                         parts = []
+                         for c in m.get("content", []):
+                             if c.get("type") == "text":
+                                 parts.append(c.get("text", ""))
+                         print("[input (assistant)]")
+                         print("\n".join([p for p in parts if p]).strip())
+                 print("[output]")
+                 print(pred)
+
+             # Stop if not <RET> or reached max_k shots
+             if tag != "<RET>":
+                 break
+             if step >= max_k:
+                 break
+
+             # Append next shot and ask again.
+             shot = shots_all[step] if step < len(shots_all) else None
+             if not shot:
+                 break
+             _append_shot_turn(
+                 messages=messages,
+                 pil_images=pil_images,
+                 image_root=args.image_root,
+                 image_path=shot["image"],
+                 description=shot["description"],
+             )
+             # Ask for decision again after each shot.
+             messages.append({"role": "user", "content": [{"type": "text", "text": "Action:"}]})
+             step += 1
+
+     print("=== summary ===")
+     for k in k_list:
+         s = summary[k]
+         reached = s["REACHED"]
+         if reached == 0:
+             print(f"k={k}: no samples")
+             continue
+         print(
+             f"k={k} | RET={s['RET']} | ANS={s['ANS']} | NONE={s['NONE']} | reached={reached}"
+         )
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
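
To make the checker's turn structure concrete, here is the conversation state it builds after one `<RET>` -> shot -> "Action:" cycle. Paths and texts are illustrative placeholders, not dataset contents; only the message shapes come from `_build_base_messages`, `_append_shot_turn`, and the loop above.

```python
# Illustrative `messages` state after one <RET> step (placeholder data).
messages = [
    {"role": "user", "content": [                          # query turn
        {"type": "image", "image": "images/query.jpg"},    # placeholder path
        {"type": "text", "text": "What is shown here?"},   # placeholder question
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "<RET>"}]},
    {"role": "user", "content": [                          # appended shot
        {"type": "image", "image": "images/shot_01.jpg"},
        {"type": "text", "text": "Description: a red bus on a street"},
    ]},
    {"role": "user", "content": [{"type": "text", "text": "Action:"}]},
]
# The next generate() call is expected to emit <ANS> ... or another <RET>.
```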
ICL/LV/code/SFT/cuda-keyring_1.1-1_all.deb ADDED
Binary file (4.33 kB).
 
ICL/LV/code/SFT/prepare_dataset.py ADDED
@@ -0,0 +1,56 @@
+ #!/usr/bin/env python3
+ """
+ Pre-generate the dataset cache.
+ Run this once; training then loads the cache directly and avoids timeouts.
+ """
+
+ import os
+ import sys
+ import pickle
+ from pathlib import Path
+
+ # Add the current directory to the import path
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+ from config import get_config
+ from dataset import SFTDataset
+
+
+ def main():
+     config = get_config()
+
+     cache_path = Path(config.training.output_dir) / ".dataset_cache.pkl"
+     ready_flag = Path(config.training.output_dir) / ".dataset_ready"
+
+     # Create the output directory
+     os.makedirs(config.training.output_dir, exist_ok=True)
+
+     print("=" * 60)
+     print("Pre-generating the SFT dataset cache")
+     print("=" * 60)
+     print(f"Data dir:   {config.data.sft_data_dir}")
+     print(f"Cache path: {cache_path}")
+     print("=" * 60)
+
+     # Load the dataset
+     print("\nLoading dataset...")
+     dataset = SFTDataset(config.data, split="train")
+
+     print(f"\nDataset loaded: {len(dataset)} samples in total")
+
+     # Save the cache
+     print(f"\nSaving cache to: {cache_path}")
+     with open(cache_path, "wb") as f:
+         pickle.dump(dataset.samples, f)
+
+     # Create the ready flag
+     ready_flag.touch()
+
+     print(f"Ready flag: {ready_flag}")
+     print("\n" + "=" * 60)
+     print("Cache generation complete! You can now run: bash run_train.sh")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
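
The training side is expected to pick this cache up instead of re-parsing the raw data. Below is a minimal sketch of that loading path, under the assumption that `SFTDataset` only needs its `samples` list restored (which is exactly what gets pickled above); the `__new__` shortcut is an illustration, not the project's actual loader.

```python
# Hedged sketch: restore the pre-built sample cache at training time.
import pickle
from pathlib import Path

from config import get_config
from dataset import SFTDataset

config = get_config()
cache_path = Path(config.training.output_dir) / ".dataset_cache.pkl"
ready_flag = Path(config.training.output_dir) / ".dataset_ready"

if ready_flag.exists() and cache_path.exists():
    # Assumption: a cached SFTDataset is usable once `samples` is restored.
    dataset = SFTDataset.__new__(SFTDataset)  # skip the slow scan in __init__
    with open(cache_path, "rb") as f:
        dataset.samples = pickle.load(f)
else:
    dataset = SFTDataset(config.data, split="train")  # slow path
```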
ICL/LV/code/adapters/gemma3_adapter.py ADDED
@@ -0,0 +1,27 @@
+ from __future__ import annotations
+
+ from typing import List, Dict
+
+ try:
+     from adapters._runners.gemma3_infer import Gemma3Runner
+ except Exception:
+     Gemma3Runner = None  # type: ignore
+
+
+ class Adapter:
+     def __init__(self, model_path: str):
+         if Gemma3Runner is None:
+             raise RuntimeError('Gemma3Runner unavailable. Ensure gemma3-code is on PYTHONPATH or install its runner.')
+         self.runner = Gemma3Runner(model_path)
+
+     def generate_from_segments(self, segs: List[Dict[str, str]], *,
+                                temperature: float, top_p: float, max_new_tokens: int) -> str:
+         gen = getattr(self.runner, 'generate_from_qwen_segs', None)
+         if gen is None:
+             raise RuntimeError('Gemma3Runner missing generate_from_qwen_segs')
+         return gen(segs, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
+
+
+ def create(model_path: str) -> Adapter:
+     return Adapter(model_path)
+
+
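This adapter and the qwen3vl one below expose the same surface: `create(model_path)` returning an object with `generate_from_segments(segs, *, temperature, top_p, max_new_tokens)`. A hedged usage sketch follows; the segment dict schema is inferred from the `List[Dict[str, str]]` hint, and the paths are placeholders.

```python
# Illustrative adapter use; the seg dict keys are an assumption based on
# the type hints, and the model/image paths are placeholders.
from adapters import gemma3_adapter

adapter = gemma3_adapter.create("/path/to/gemma3-checkpoint")
segs = [
    {"type": "text", "value": "Describe the image."},  # hypothetical schema
    {"type": "image", "value": "/path/to/query.jpg"},  # hypothetical schema
]
print(adapter.generate_from_segments(
    segs, temperature=0.2, top_p=0.9, max_new_tokens=128
))
```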
ICL/LV/code/adapters/qwen3vl_adapter.py ADDED
@@ -0,0 +1,27 @@
+ from __future__ import annotations
+
+ from typing import List, Dict
+
+ try:
+     from adapters._runners.qwen3_vl_infer import Qwen3VLRunner
+ except Exception:
+     Qwen3VLRunner = None  # type: ignore
+
+
+ class Adapter:
+     def __init__(self, model_path: str):
+         if Qwen3VLRunner is None:
+             raise RuntimeError('Qwen3VLRunner unavailable. Ensure QWEN3VL-code is on PYTHONPATH or install its runner.')
+         self.runner = Qwen3VLRunner(model_path)
+
+     def generate_from_segments(self, segs: List[Dict[str, str]], *,
+                                temperature: float, top_p: float, max_new_tokens: int) -> str:
+         gen = getattr(self.runner, 'generate_from_segments', None)
+         if gen is None:
+             raise RuntimeError('Qwen3VLRunner missing generate_from_segments')
+         return gen(segs, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens)
+
+
+ def create(model_path: str) -> Adapter:
+     return Adapter(model_path)
+
ICL/LV/code/attn map/attn map/attn map/__pycache__/token_attention_utils.cpython-313.pyc ADDED
Binary file (2.57 kB).