diff --git a/ICL/DAPO/verl-recipe/.pre-commit-config.yaml b/ICL/DAPO/verl-recipe/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fc409e2dec362589952079e2176effc958cf705 --- /dev/null +++ b/ICL/DAPO/verl-recipe/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.14.10" + hooks: + - id: ruff + args: ["--fix", "--show-fixes", "--output-format=full"] + exclude: ^.*\.(ipynb)$ + - id: ruff-format diff --git a/ICL/DAPO/verl-recipe/CODEOWNERS b/ICL/DAPO/verl-recipe/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/DAPO/verl-recipe/LICENSE b/ICL/DAPO/verl-recipe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/ICL/DAPO/verl-recipe/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ICL/DAPO/verl-recipe/README.md b/ICL/DAPO/verl-recipe/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8960f7a2de3791d5f863a7127a2be741801a406a --- /dev/null +++ b/ICL/DAPO/verl-recipe/README.md @@ -0,0 +1,47 @@ +# verl-recipe + +`verl-recipe` hosts recipes based on [verl](https://github.com/volcengine/verl) contributed by the community. + +## Usage + +`verl-recipe` can be used as a submodule of `verl`, keeping backward compatibility as `verl/recipe`: + +```bash +git clone https://github.com/verl-project/verl.git +cd verl +git submodule update --init --recursive recipe +``` + +## Available Recipes + +- [retool](https://github.com/verl-project/verl-recipe/tree/main/retool): Reinforcement Learning for Strategic Tool Use in LLMs +- [langgraph_agent](https://github.com/verl-project/verl-recipe/tree/main/langgraph_agent): A tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/) to solve math expression. +- [spo](https://github.com/verl-project/verl-recipe/tree/main/spo): [Single-stream Policy Optimization](https://arxiv.org/abs/2509.13232). +- TBA... + +## Contribution + +### Version Specification + +Recipes are recommended to specify the verl version required, e.g., + +``` +# release version +verl==0.6.0 + +# dev version +verl@git+https://github.com/volcengine/verl.git@313dfdb2199124a37189e32e6d4a6c654379f2d4 +``` + +### Code Linting and Formatting + +To maximize flexiblility but minimize meaningless changes, we apply `pre-commit` but only force code linting and formatting with `ruff`. Use it as follows: + +```bash +pip install pre-commit +pre-commit install +# for staged changes +pre-commit run +# for all files in the repo +pre-commit run --all-files +``` diff --git a/ICL/DAPO/verl-recipe/collabllm/collabllm_agent_loop.py b/ICL/DAPO/verl-recipe/collabllm/collabllm_agent_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..d6b5d64b32f3e57c764174a91c7fcb44ab895e72 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/collabllm_agent_loop.py @@ -0,0 +1,139 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from copy import deepcopy +from typing import Any +from uuid import uuid4 + +from recipe.collabllm.utils import is_valid_messages + +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput +from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop +from verl.utils.rollout_trace import rollout_trace_op +from verl.workers.rollout.schemas import Message + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class CollabLLMAgentLoop(ToolAgentLoop): + @rollout_trace_op + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + messages = list(kwargs["raw_prompt"]) + image_data = deepcopy(kwargs.get("multi_modal_data", {}).get("image", None)) + metrics = {} + request_id = uuid4().hex + tools_kwargs = kwargs.get("tools_kwargs", {}) + + # Initialize interaction if needed + interaction = None + interaction_kwargs = {} + if self.interaction_config_file: + interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"] + if "name" not in interaction_kwargs: + raise ValueError("'name' key is required in interaction_kwargs") + interaction_name = interaction_kwargs["name"] + if interaction_name not in self.interaction_map: + raise ValueError( + f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: " + f"{list(self.interaction_map.keys())}" + ) + interaction = self.interaction_map[interaction_name] + await interaction.start_interaction(request_id, **interaction_kwargs) + # Create AgentData instance to encapsulate all state + agent_data = AgentData( + messages=messages, + image_data=image_data, + metrics=metrics, + request_id=request_id, + tools_kwargs=tools_kwargs, + interaction=interaction, + interaction_kwargs=interaction_kwargs, + ) + # for collabllm, firstly generate model reponses + await self._handle_pending_state(agent_data, sampling_params) + + status = await self._handle_generating_state(agent_data, sampling_params) + + if status == AgentState.TERMINATED: + # tell reward manager to score -1 and skip future interaction + # to avoid reward hacking with incompleted message + num_repeats = 0 + else: + # then, collect interaction rollouts + num_repeats = self.config.actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts + + interaction_requests = [deepcopy(agent_data) for _ in range(num_repeats)] + + # messages are only used in collabllm reward manager + messages_lst = [] + for _agent_data in interaction_requests: + if not is_valid_messages(_agent_data.messages[-1]): + break + + prev_msg_len = len(_agent_data.messages) + await self.run_agent_data_loop(_agent_data, sampling_params, AgentState.INTERACTING) + messages_lst.append([Message(**msg) for msg in _agent_data.messages]) + + if interaction.config.get("enable_log"): + print(f"Assistant: ...{messages_lst[-1][prev_msg_len - 1].content[-100:]}") + print(f"User: {messages_lst[-1][prev_msg_len].content[:100]}...") + + # Finalize output + response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :] + prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)] + multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {} + + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=agent_data.response_mask[: self.response_length], + multi_modal_data=multi_modal_data, + response_logprobs=agent_data.response_logprobs[: self.response_length] + if agent_data.response_logprobs + else None, + num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, + metrics=agent_data.metrics, + extra_fields={ + "turn_scores": agent_data.turn_scores, + "messages": {"messages": messages_lst}, # compatiable with sglang interaction + }, + ) + return output + + async def run_agent_data_loop(self, agent_data: AgentData, sampling_params: dict[str, Any], state: AgentState): + """ + Run the agent data loop to process the agent data. + + Args: + agent_data (AgentData): The agent data to process. + sampling_params (dict[str, Any]): The sampling parameters. + state (AgentState, optional): The initial state of the agent. Defaults to None. + """ + + while state != AgentState.TERMINATED: + if state == AgentState.PENDING: + state = await self._handle_pending_state(agent_data, sampling_params) + elif state == AgentState.GENERATING: + state = await self._handle_generating_state(agent_data, sampling_params) + elif state == AgentState.PROCESSING_TOOLS: + state = await self._handle_processing_tools_state(agent_data) + elif state == AgentState.INTERACTING: + state = await self._handle_interacting_state(agent_data) + else: + logger.error(f"Invalid state: {state}") + state = AgentState.TERMINATED diff --git a/ICL/DAPO/verl-recipe/collabllm/collabllm_interation.py b/ICL/DAPO/verl-recipe/collabllm/collabllm_interation.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0bdbc0298e007d14892ced1189bd0be9967809 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/collabllm_interation.py @@ -0,0 +1,374 @@ +# Copyright 2024 CollabLLM Ltd. and/or its affiliates +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import copy +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from recipe.collabllm.utils import remove_think_block + +from verl.interactions.base import BaseInteraction +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +TERMINATION_SIGNAL = "[[TERMINATE CHAT]]" +USER_PROMPT_TEMPLATE = """You are role-playing as a human USER interacting with an AI collaborator to complete a specific task. Your goal is to generate realistic, natural responses that a user might give in this scenario. + +## Input Information: +You will be provided with: +- Task Description: The type of task you are trying to accomplish. +- Complete Prompt or Reference Goal: This field may include the complete user request/query or a reference answer to user's request. Use this field to understand the user's intent, requirements, or what would count as a satisfactory outcome. +- Chat History: The ongoing conversation between you (as the user) and the AI + +Inputs: +<|The Start of Task Description (Not visible to the AI)|> +{task_desc} +<|The End of Task Description|> + +<|The Start of Complete Prompt or Reference Goal (Not visible to the AI)|> +{single_turn_prompt} +<|The End of Complete Prompt or Reference Goal|> + +<|The Start of Chat History|> +{chat_history} +<|The End of Chat History|> + + +## Guidelines: +- Stay in Character: Role-play as a human USER. You are NOT an AI. Maintain a consistent personality throughout the chat. +- Minimize Effort: IMPORTANT! As a user, avoid being too detailed in your responses. Provide vague or incomplete demands in the early stages of the conversation to minimize your effort. Let the AI ask for clarification rather than providing everything upfront. +- Knowledge Background: Reflect the user's knowledge level in the role-playing. If the user is less knowledgeable about a task, they might not notice incorrect statements. Ask questions that demonstrate your current understanding and areas of confusion. +- Occasionally Make Mistakes: Real-world users might misspell words, provide incorrect dates, give wrong information, or ask unclear questions. Simulate this behavior to reflect natural interactions. +- Mention Personal Preferences: Include preferences or constraints that might influence your requests or responses. For example, "I prefer short answers," "I need this done quickly," or "I like detailed comments in code." +- Goal-Oriented: Keep the chat focused on your intent. Avoid small talk or digressions. Redirect the chat back to the main objective if it starts to stray. + +## Output Format: +You should output a JSON object with three entries: +- "current_answer" (str): Briefly summerize the AI's current solution to the task. +- "thought" (str): Output your thought process as a user deciding what to say next. Consider: +1. Have you obtained a satisfactory solution from the AI? If yes, you can terminate this chat. +2. If not, what specific part of the problem or solution are you struggling with? +3. Has the AI asked you to perform a task or answer a question? If so, how should you approach it? +4. Are you noticing any patterns or potential misunderstandings that need clarification? +5. If you're stuck, how can you phrase your question to get the most helpful response while demonstrating your current understanding? +- "response" (str): Based on your thought process, respond to the AI as the user you are role-playing. Stop immediately when the user's response is completed. + +## Important Notes: +- Respond Based on Previous Messages: Your responses should be based on the context of the current chat history. Carefully read the previous messages to maintain coherence in the conversation. +- Conversation Flow: If "Current Chat History" is empty, start the conversation from scratch with an initial request. Otherwise, continue based on the existing conversation. +- Don't Copy Input Directly: Use the provided information for understanding context only. Avoid copying target queries or any provided information directly in your responses. +- Completion Signal: Use "{termination_signal}" as your response when you believe your goal has been solved or if you determine the AI cannot help further. +- Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. + +Remember to stay in character as a user throughout your response, and follow the instructions and guidelines carefully.""" # noqa: E501 + + +class CollabLLMInteraction(BaseInteraction): + """A demo interaction for calculating the reward of CollabLLM. + + - `start_interaction`: start a interaction instance for a trajectory. + - `generate_response`: generate the response of the assistant. + - `calculate_score`: calculate the score of the interaction. + - `finalize_interaction`: finalize the interaction instance. + """ + + def __init__(self, config: dict): + super().__init__(config) + _config = copy.deepcopy(config) + + _config.pop("enable_log", None) + + self.name = _config.pop("name") + self.user_model = _config.pop("user_model") + + self.termination_signal = _config.pop("termination_signal", TERMINATION_SIGNAL) + self.num_retries = _config.pop("num_retries", 3) + + self.user_model_kwargs = _config + + self._instance_dict = {} + + async def start_interaction( + self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs + ) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + self.interaction_kwargs = kwargs + assert "single_turn_prompt" in kwargs, "single_turn_prompt is required in interaction_kwargs" + return instance_id + + @rollout_trace_op + async def generate_response( + self, instance_id: str, messages: list[dict[str, Any]], **kwargs + ) -> tuple[bool, str, float, dict]: + assert messages[-1]["role"] in ["system", "assistant"], ( + "Last message input to the user model must be from system or assistant role" + ) + + import litellm + + chat_history = self._parse_messages(messages, strip_sys_prompt=True) + prompt = USER_PROMPT_TEMPLATE.format( + task_desc=self.interaction_kwargs.get("task_desc", "general assistance task"), + single_turn_prompt=self.interaction_kwargs["single_turn_prompt"], + chat_history=chat_history, + termination_signal=self.termination_signal, + ) + response = "" + for i in range(self.num_retries): + try: + full_response = ( + ( + await litellm.acompletion( + model=self.user_model, + messages=[{"role": "user", "content": prompt}], + **self.user_model_kwargs, + ) + ) + .choices[0] + .message.content + ) + except litellm.RateLimitError as e: + logger.warning(f"[CollabLLMInteraction] hit RateLimitError: {e}. Retrying...") + await asyncio.sleep(max(2**i, 60)) + continue + except Exception as e: + logger.exception(f"An unexpected error occurred in CollabLLMAgentLoop: {e}") + continue + + try: + if isinstance(full_response, str): + full_response = extract_json(full_response) + except Exception as e: + logger.warning(f"[CollabLLMInteraction] Error extracting JSON: {e}. Retrying...") + continue + + if isinstance(full_response, dict): + keys = full_response.keys() + if {"current_answer", "thought", "response"}.issubset(keys): + response = full_response.pop("response") + if isinstance(response, str): + break + else: + logger.warning( + f"[CollabLLMInteraction] got an invalid response {response} full_response {full_response}. \ + Retrying..." + ) + continue + else: + logger.warning(f"[CollabLLMInteraction] Keys {keys} do not match expected keys. Retrying...") + continue + + self._instance_dict[instance_id]["response"] = response + logger.debug(f"[CollabLLMInteraction] User: {response}") + should_terminate_sequence = self.termination_signal in response + reward = 0.0 + + return should_terminate_sequence, response, reward, {} + + async def finalize_interaction(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] + + def _parse_messages(self, messages, strip_sys_prompt=True): + if messages is None: + return "" + + if strip_sys_prompt: + messages = [msg for msg in messages if msg["role"] != "system"] + + messages = [remove_think_block(msg) for msg in messages] + + chat = "\n".join(f"**{m['role'].capitalize()}**: {m['content']}" for m in messages) + + return chat + + +def extract_json(s): + def convert_value(value): + true_values = {"true": True, "false": False, "null": None} + value_lower = value.lower() + if value_lower in true_values: + return true_values[value_lower] + try: + if "." in value or "e" in value.lower(): + return float(value) + else: + return int(value) + except ValueError: + return value # Return as string if not a number + + def parse_number(s, pos): + start = pos + while pos < len(s) and s[pos] in "-+0123456789.eE": + pos += 1 + num_str = s[start:pos] + try: + if "." in num_str or "e" in num_str.lower(): + return float(num_str), pos + else: + return int(num_str), pos + except ValueError: + logger.error(f"Invalid number at position {start}: {num_str}") + raise + + def skip_whitespace(s, pos): + while pos < len(s) and s[pos] in " \t\n\r": + pos += 1 + return pos + + def parse_string(s, pos): + quote_char = s[pos] + assert quote_char in ('"', "'") + pos += 1 + result = "" + while pos < len(s): + c = s[pos] + if c == "\\": + pos += 1 + if pos >= len(s): + raise ValueError("Invalid escape sequence") + c = s[pos] + escape_sequences = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", quote_char: quote_char} + result += escape_sequences.get(c, c) + elif c == quote_char: + pos += 1 + # Attempt to convert to a number if possible + converted_value = convert_value(result) + return converted_value, pos + else: + result += c + pos += 1 + raise ValueError("Unterminated string") + + def parse_key(s, pos): + pos = skip_whitespace(s, pos) + if s[pos] in ('"', "'"): + key, pos = parse_string(s, pos) + return key, pos + else: + raise ValueError(f"Expected string for key at position {pos}") + + def parse_object(s, pos): + obj = {} + assert s[pos] == "{" + pos += 1 + pos = skip_whitespace(s, pos) + while pos < len(s) and s[pos] != "}": + pos = skip_whitespace(s, pos) + key, pos = parse_key(s, pos) + pos = skip_whitespace(s, pos) + if pos >= len(s) or s[pos] != ":": + raise ValueError(f'Expected ":" at position {pos}') + pos += 1 + pos = skip_whitespace(s, pos) + value, pos = parse_value(s, pos) + obj[key] = value + pos = skip_whitespace(s, pos) + if pos < len(s) and s[pos] == ",": + pos += 1 + pos = skip_whitespace(s, pos) + elif pos < len(s) and s[pos] == "}": + break + elif pos < len(s) and s[pos] != "}": + raise ValueError(f'Expected "," or "}}" at position {pos}') + if pos >= len(s) or s[pos] != "}": + raise ValueError(f'Expected "}}" at position {pos}') + pos += 1 + return obj, pos + + def parse_array(s, pos): + lst = [] + assert s[pos] == "[" + pos += 1 + pos = skip_whitespace(s, pos) + while pos < len(s) and s[pos] != "]": + value, pos = parse_value(s, pos) + lst.append(value) + pos = skip_whitespace(s, pos) + if pos < len(s) and s[pos] == ",": + pos += 1 + pos = skip_whitespace(s, pos) + elif pos < len(s) and s[pos] == "]": + break + elif pos < len(s) and s[pos] != "]": + raise ValueError(f'Expected "," or "]" at position {pos}') + if pos >= len(s) or s[pos] != "]": + raise ValueError(f'Expected "]" at position {pos}') + pos += 1 + return lst, pos + + def parse_triple_quoted_string(s, pos): + if s[pos : pos + 3] == "'''": + quote_str = "'''" + elif s[pos : pos + 3] == '"""': + quote_str = '"""' + else: + raise ValueError(f"Expected triple quotes at position {pos}") + pos += 3 + result = "" + while pos < len(s): + if s[pos : pos + 3] == quote_str: + pos += 3 + # Attempt to convert to a number if possible + converted_value = convert_value(result) + return converted_value, pos + else: + result += s[pos] + pos += 1 + raise ValueError("Unterminated triple-quoted string") + + def parse_value(s, pos): + pos = skip_whitespace(s, pos) + if pos >= len(s): + raise ValueError("Unexpected end of input") + if s[pos] == "{": + return parse_object(s, pos) + elif s[pos] == "[": + return parse_array(s, pos) + elif s[pos : pos + 3] in ("'''", '"""'): + return parse_triple_quoted_string(s, pos) + elif s[pos] in ('"', "'"): + return parse_string(s, pos) + elif s[pos : pos + 4].lower() == "true": + return True, pos + 4 + elif s[pos : pos + 5].lower() == "false": + return False, pos + 5 + elif s[pos : pos + 4].lower() == "null": + return None, pos + 4 + elif s[pos] in "-+0123456789.": + return parse_number(s, pos) + else: + raise ValueError(f"Unexpected character at position {pos}: {s[pos]}") + + json_start = s.index("{") + json_end = s.rfind("}") + s = s[json_start : json_end + 1] + + s = s.strip() + result, pos = parse_value(s, 0) + pos = skip_whitespace(s, pos) + if pos != len(s): + raise ValueError(f"Unexpected content at position {pos}") + return result diff --git a/ICL/DAPO/verl-recipe/collabllm/metrics/accuracy.py b/ICL/DAPO/verl-recipe/collabllm/metrics/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..a81d1b8d18ce5ea1d815a361ca0912bdbc248e2b --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/metrics/accuracy.py @@ -0,0 +1,104 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from recipe.collabllm.utils import extract_json, parse_messages + +ACCURACY_PROMPT = '''You are a helpful and meticulous evaluator. Your task is to \ +evaluate the *accuracy* of an AI model's answer to a target question. \ +You will be given the target question, the ground truth answer, and the conversation between the AI and the user. + +Provided Information: + +<|The Start of Target Question and Ground Truth Answer|> +Target Question: {single_turn_prompt} +Ground Truth Answer: {ground_truth} +<|The End of Target Question and Ground Truth Answer|> + +<|The Start of The Conversation|> +{chat_history} +<|The End of The Conversation|> + +You should determine whether the model's final response to the target question is \ +factually correct and consistent with the provided ground truth. + +Rating criteria (binary): + • 1 = Correct — the response matches the ground truth. + • 0 = Incorrect — the response contradicts or misses the ground truth. + +Output format (JSON): +{{ + "thought": "", + "accuracy": <0 or 1> +}} + +Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \ +Use " or """ to wrap up the thought and use single quotes inside the "thought" field to avoid JSON escape issues. + +Your evaluation: +''' + + +async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs): + # Check if litellm is available, fallback to openai if not + try: + import litellm + + use_litellm = True + except ImportError: + # litellm not found, falling back to openai + import openai + + use_litellm = False + + chat_history = parse_messages(messages, strip_sys_prompt=True) + prompt = ACCURACY_PROMPT.format( + single_turn_prompt=extra_info["interaction_kwargs"]["single_turn_prompt"], + ground_truth=ground_truth, + chat_history=chat_history, + ) + + if use_litellm: + full_response = ( + ( + await litellm.acompletion( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + else: + client = openai.AsyncOpenAI() # Assumes API key is set in environment + full_response = ( + ( + await client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + + full_response = extract_json(full_response) + + assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}" + assert {"accuracy", "thought"}.issubset(full_response.keys()), ( + f"Expected keys not found from {full_response.keys()}" + ) + + accuracy = full_response.pop("accuracy") + return float(accuracy) diff --git a/ICL/DAPO/verl-recipe/collabllm/metrics/bleu_score.py b/ICL/DAPO/verl-recipe/collabllm/metrics/bleu_score.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5df973058021cc0283043d76daa85b368f6e02 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/metrics/bleu_score.py @@ -0,0 +1,115 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nltk.translate.bleu_score import sentence_bleu +from recipe.collabllm.utils import extract_json, parse_messages + +EXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \ +Your task is to extract the final and complete version of a document that was generated during \ +a multiturn conversation between a user and a chat assistant. \ +The extracted content should reflect the final and comprehensive response provided by the assistant \ +based on the user’s request. + +You will be provided with the conversation: + +<|The Start of The Conversation|> +{chat_history} +<|The End of The Conversation|> + +Instructions for Extraction: + +1. Identify the Most Update-to-Date Contents: Review the entire conversation to identify the most updated parts \ +of the content provided by the assistant. This may include: + - Different sections of text (e.g., an essay, report, or article). + +2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \ +ensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \ +output that incorporates all modifications and additions made during the conversation. For example, if the assistant \ +writes an introducation at the beginning and move on to the conclusion, the final output should include both the \ +introduction and the conclusion. + +3. Focus on Completeness: + - For text-based documents: Ensure that the extracted content is comprehensive and represents the full document \ + or section as discussed in the conversation. + +You should output a JSON object with two entries: +- "thought" (str): Output your thought process when extracting the final content. + 1. How do different parts of the conversation contribute to the final output? + 2. How do you make sure you included the most updated and complete information? + 3. How do you make sure you did not include any information that is not necessary? +- "final_completion" (str): The final and complete version of the document extracted from the conversation. + +Note: +1. If there are multiple lines, you should use triple quotes (""") to wrap the content. For example, \ + "final_completion": """first line. + second line.""" or "thought": """first line; + second line.""". +2. In the "final_completion" entry, replace all double quotes (") with single quotes (') to prevent JSON formatting \ +issues. For example, you can output "final_completion": "'Hello World' is a common phrase." + +Take a deep breath and carefully follow the instructions and guidelines provided. +''' + + +async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs): + # Check if litellm is available, fallback to openai if not + try: + import litellm + + use_litellm = True + except ImportError: + # litellm not found, falling back to openai + import openai + + use_litellm = False + + chat_history = parse_messages(messages, strip_sys_prompt=True) + prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format(chat_history=chat_history) + + if use_litellm: + full_response = ( + ( + await litellm.acompletion( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + else: + client = openai.AsyncOpenAI() # Assumes API key is set in environment + full_response = ( + ( + await client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + + full_response = extract_json(full_response) + + assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}" + assert {"final_completion", "thought"}.issubset(full_response.keys()), ( + f"Expected keys not found from {full_response.keys()}" + ) + + final_completion = full_response.pop("final_completion") + + bleu = sentence_bleu([ground_truth], final_completion) + return float(bleu) diff --git a/ICL/DAPO/verl-recipe/collabllm/metrics/interactivity.py b/ICL/DAPO/verl-recipe/collabllm/metrics/interactivity.py new file mode 100644 index 0000000000000000000000000000000000000000..a7ef69ef59660934853bda0a34335702d3f811f4 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/metrics/interactivity.py @@ -0,0 +1,108 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from recipe.collabllm.utils import extract_json, parse_messages + +INTERACTIVITY_PROMPT = '''You are a helpful and meticulous conversation evaluator. \ +Your task is to evaluate the interactivity of the responses provided by an AI assistant \ +to user questions in a given conversation: + +<|The Start of the Conversation to be Evaluated|> +{chat_history} +<|The End of the Conversation to be Evaluated|> + +You should assess the assistant's engagement, clarity, and ability to understand the user's needs. \ +Give a float number between 0 and 1. + +Scoring Criteria: +- Let U = user understanding & response clarity ∈ [0,1] + - 1.0 = Fully understands the user's intent and gives a clear answer. + - 0.7 = Mostly understands and the answer is generally clear. + - 0.3 = Partially misunderstands or the answer is hard to follow. + - 0.0 = Misunderstands the intent and gives an unclear or irrelevant answer. +- Let Q = clarification in [0,1] + - 1.0 = Asks precise, necessary clarifying questions when needed. + - 0.7 = Asks somewhat helpful but incomplete clarifications. + - 0.3 = Only asks generic questions (e.g., “Does that help?”). + - 0.0 = Asks no clarifying questions when needed. +- Let S = suggestion helpfulness in [0,1] + - 1.0 = Provides useful, actionable suggestions. + - 0.7 = Suggestions are somewhat helpful but limited. + - 0.3 = Suggestions are vague or generic. + - 0.0 = No suggestions when they would clearly help. +score = average([U, Q, S]) + +Output format (JSON): +{{ + "thought": "", + "interactivity": +}} + +Double check if the JSON object is formatted correctly. Ensure that all fields are present and properly structured. \ +Use " or """ to wrap up the thought. You should not use other triple quotes inside the "thought" field. \ +Instead you should use single quotes to avoid JSON escape issues. + +Your evaluation: +''' + + +async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs): + # Check if litellm is available, fallback to openai if not + try: + import litellm + + use_litellm = True + except ImportError: + # litellm not found, falling back to openai + import openai + + use_litellm = False + + chat_history = parse_messages(messages, strip_sys_prompt=True) + prompt = INTERACTIVITY_PROMPT.format(chat_history=chat_history) + + if use_litellm: + full_response = ( + ( + await litellm.acompletion( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + else: + client = openai.AsyncOpenAI() # Assumes API key is set in environment + full_response = ( + ( + await client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + + full_response = extract_json(full_response) + + assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}" + assert {"interactivity", "thought"}.issubset(full_response.keys()), ( + f"Expected keys not found from {full_response.keys()}" + ) + + interactivity = full_response.pop("interactivity") + return float(interactivity) diff --git a/ICL/DAPO/verl-recipe/collabllm/metrics/pass_rate.py b/ICL/DAPO/verl-recipe/collabllm/metrics/pass_rate.py new file mode 100644 index 0000000000000000000000000000000000000000..df422b4b194fd00f5a3a38b083146636ebb46590 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/metrics/pass_rate.py @@ -0,0 +1,138 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from bigcodebench.eval import untrusted_check +from recipe.collabllm.utils import extract_json, parse_messages + +EXTRACT_MULTITURN_COMPLETION_PROMPT = '''You are a thorough and diligent conversation analyzer. \ +Your task is to extract the final and complete version of a code function {entry_point} that was generated \ +during a multiturn conversation between a user and a chat assistant. \ +The extracted content should reflect the final and comprehensive response provided by the \ +assistant based on the user’s request. + +You will be provided with the task and the conversation: + +<|The Start of The Task|> +{single_turn_prompt} +<|The End of The Task|> + +<|The Start of The Conversation|> +{chat_history} +<|The End of The Conversation|> + +Instructions for Extraction: + +1. Identify the Most Update-to-Date Contents: Review the entire conversation to identify the most updated parts of \ +the content provided by the assistant. This may include: + - Different parts of the code snippet, function, class, or script. + +2. Integrate Revisions: If the assistant made revisions, updates, or added sections throughout the conversation, \ +ensure that these changes are fully integrated into the final content. The goal is to extract a single, cohesive \ +output that incorporates all modifications and additions made during the conversation. For example, if the assistant \ +writes a function at the beginning and changes a part, the final output should take the modification into account. + +3. Focus on Completeness: + - For code: Extract a complete and functional code snippet, including all necessary components such as imports, \ + functions, classes, and any other essential elements. The code should be runnable, but you do not need to \ + include any testing examples including the contents after `if __name__ == "__main__":`. Only the function code \ + is required. + +You should output a JSON object with two entries: +- "thought" (str): Output your thought process when extracting the final content. + 1. How do different parts of the conversation contribute to the final output? + 2. How do you make sure you included the most updated and complete information? + 3. How do you make sure you did not include any information that is not necessary? +- "final_completion" (str): The final and complete version of the code extracted from the conversation. \ +Rename main function name for the task to {entry_point} if needed. Remove any comments wrapped by """. + +Note: +1. If there are multiple lines, you should use triple quotes (""") to wrap the content. For example, \ + "final_completion": """first line. + second line.""" or "thought": """first line; + second line.""". You should not use other triple quotes inside. +2. In the "final_completion" entry, replace all double quotes (") with single quotes (') to prevent JSON formatting \ + issues. For example, you can output "final_completion": "'Hello World' is a common phrase." + +Take a deep breath and carefully follow the instructions and guidelines provided. +''' + + +async def compute_score(data_source, messages, ground_truth, extra_info, **kwargs): + # Check if litellm is available, fallback to openai if not + try: + import litellm + + use_litellm = True + except ImportError: + # litellm not found, falling back to openai + import openai + + use_litellm = False + + chat_history = parse_messages(messages, strip_sys_prompt=True) + + prompt = EXTRACT_MULTITURN_COMPLETION_PROMPT.format( + chat_history=chat_history, + single_turn_prompt=extra_info["interaction_kwargs"]["single_turn_prompt"], + entry_point=extra_info["single_turn_metadata"]["entry_point"], + ) + + if use_litellm: + full_response = ( + ( + await litellm.acompletion( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + else: + client = openai.AsyncOpenAI() # Assumes API key is set in environment + full_response = ( + ( + await client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + ) + .choices[0] + .message.content + ) + + full_response = extract_json(full_response) + + assert isinstance(full_response, dict), f"Expected a dict, got {type(full_response)}" + assert {"final_completion", "thought"}.issubset(full_response.keys()), ( + f"Expected keys not found from {full_response.keys()}" + ) + + final_completion = full_response.pop("final_completion") + metadata = extra_info["single_turn_metadata"] + res = untrusted_check( + final_completion, + metadata["test"], + metadata["entry_point"], + max_as_limit=300 * 1024, + max_data_limit=300 * 1024, + max_stack_limit=300 * 1024, + min_time_limit=60, + gt_time_limit=60, + ) + passed = res[0] == "pass" + + # info = res[1] # for printing extra info + return float(passed) diff --git a/ICL/DAPO/verl-recipe/collabllm/process_dataset.py b/ICL/DAPO/verl-recipe/collabllm/process_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cb04a2a9080b721101d2b2ab73eaedd4a2462e0b --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/process_dataset.py @@ -0,0 +1,239 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +""" +# available datasets: +# math-hard(-large), medium(-large), bigcodebench(-large) +# to create your own dataset, refer to https://github.com/Wuyxin/collabllm + +DATASET=math-hard-large + +python recipe/collabllm/process_dataset.py \ + --dataset collabllm/collabllm-multiturn-$DATASET \ + --local_dir $HOME/data/collabllm-$DATASET \ + --dataset_type sft + +python recipe/collabllm/process_dataset.py \ + --dataset collabllm/collabllm-multiturn-$DATASET \ + --local_dir $HOME/data/collabllm-$DATASET \ + --dataset_type rl + + +Preprocess collabllm/collabllm-multiturn-math-hard into (ground_truth, extra_info). + +- ground_truth: picked from --prefer_field (default: single_turn_completion), + falling back to --fallback_field (default: completion) +- extra_info: a shallow copy of the original example plus bookkeeping fields +- reward_model: {"style": "rule", "ground_truth": ground_truth} + +Saves one parquet per split into --local_dir and a small JSON preview. +""" + +import argparse +import json +import os +import uuid +from typing import Any, Optional + +from datasets import Dataset, concatenate_datasets, load_dataset + +SYSTEM_PROMPT = """The assistant is designed to be helpful, proactive, and highly interactive. + +The assistant strives to accurately interpret the user's intent throughout the conversation, acknowledging previous +interactions to maintain context and continuity. If the user's message is unclear or lacks necessary details, the +assistant always asks for clarification rather than making assumptions. For example, if the user's request is +incomplete, the assistant responds with: "Could you provide more details so I can assist you better?" + +The assistant asks specific follow-up questions and offers suggestions based on the user's needs, avoiding vague or +generic prompts. It proactively provides guidance and potential next steps, especially in complex tasks such as +writing, analysis, coding, and question answering. + +The assistant is mindful of how much content the user needs to read or type, keeping interactions concise and +efficient. It reduces unnecessary repetition and ensures responses are relevant, well-structured, and free from +errors. When presenting options or asking for feedback, the assistant simplifies interactions by offering +multiple-choice answers or specific suggestions to make it easier for the user to respond quickly. + +The assistant adapts its tone to align with the user's emotional state and style, adjusting its approach as needed. +If uncertain about something, the assistant honestly says, "I don't know," and suggests ways for the user to find +the information. + +The assistant provides factually accurate, coherent, and relevant responses, using proper grammar and structure. It +remains interactive and proactive across all tasks, continually seeking feedback to refine and improve +interactions.""" + + +# Required fields: "prompt", "ground_truth", "extra_info" +# In "extra_info" dict: +# (1) Rquired: "single_turn_prompt", which is the specific problem used to inform the user simulator, +# (2) Optional: "task_desc" (a short task description), +# (3) Optional: other fields for customized reward computation +def collapse_example(example: dict[str, Any]) -> dict[str, Any]: + if "prompt" not in example: + raise ValueError("Missing required 'prompt' field.") + + ground_truth = ( + example.get("ground_truth") or example.get("single_turn_completion") or example.get("completion") or "" + ) + + extra_info = {} + for k, v in example.items(): + if k in ("prompt", "ground_truth", "extra_info"): + continue + extra_info.setdefault(k, v) # keep extra_info values if keys overlap + + # make sure extra_info has the required fields + assert "single_turn_prompt" in extra_info, "Missing 'single_turn_prompt' in extra_info." + + # add system prompt as the beginning of the list + example["prompt"] = [{"role": "system", "content": SYSTEM_PROMPT}] + example["prompt"] + + extra_info.setdefault("prompt", example["prompt"]) # save the original prompt + extra_info.setdefault( + "interaction_kwargs", + { + "name": "collabllm", + "single_turn_prompt": extra_info.pop("single_turn_prompt"), + "task_desc": extra_info.pop("task_desc", "general ask-for-assistance task"), + }, + ) + return { + "prompt": example["prompt"], + "ground_truth": ground_truth, + "raw_prompt": example["prompt"], # save the original prompt + "extra_info": extra_info, + "reward_model": {"style": "rule", "ground_truth": ground_truth}, + "data_source": "collabllm", + "agent_name": "collabllm_agent", + "index": str(uuid.uuid4()), + } + + +# ---------- IO helpers ---------- +def save_parquet(ds_split: Dataset, filename: str, out_dir: str) -> None: + os.makedirs(out_dir, exist_ok=True) + path = os.path.join(out_dir, f"{filename}.parquet") + ds_split.to_parquet(path) + print(f"[OK] Wrote {filename}.parquet → {path} ({len(ds_split)} rows)") + + +def maybe_copy_to_hdfs(local_dir: str, hdfs_dir: Optional[str]) -> None: + if not hdfs_dir: + return + try: + from verl.utils.hdfs_io import copy, makedirs # type: ignore + except Exception as e: + print(f"[WARN] Skipping HDFS copy (verl not available): {e}") + return + makedirs(hdfs_dir) + copy(src=local_dir, dst=hdfs_dir) + print(f"[OK] Copied {local_dir} → {hdfs_dir}") + + +# ---------- Main ---------- +def main(): + ap = argparse.ArgumentParser() + ap.add_argument( + "--dataset", default="collabllm/collabllm-multiturn-math-hard", help="HF dataset path or local dir/file." + ) + ap.add_argument("--task_desc", default="solving math problems", help="Task description for the dataset.") + ap.add_argument("--local_dir", default="~/data/collabllm-math-hard", help="Output directory.") + ap.add_argument("--hdfs_dir", default=None, help="Optional HDFS destination (requires verl).") + ap.add_argument( + "--validation_size", type=float, default=0.1, help="Validation split size (fraction or absolute int)." + ) + ap.add_argument("--seed", type=int, default=42, help="Random seed for splitting.") + ap.add_argument("--num_proc", type=int, default=1, help="Parallel workers for map().") + ap.add_argument("--dataset_type", default="rl", choices=["rl", "sft"], help="Type of dataset (e.g., 'rl', 'sft').") + args = ap.parse_args() + + out_dir = os.path.expanduser(args.local_dir) + os.makedirs(out_dir, exist_ok=True) + + print(f"[INFO] Loading dataset: {args.dataset}") + ds_dict = load_dataset(args.dataset) + parts = list(ds_dict.values()) + ds_all: Dataset = parts[0] if len(parts) == 1 else concatenate_datasets(parts) + # Dataset({ + # features: ['prompt', 'completion', 'conv_id', 'score', 'single_turn_prompt', + # 'single_turn_completion', 'single_turn_metadata', 'turn_id', 'sessions', 'rewards'], + # num_rows: xxx + # }) + + if args.dataset_type == "rl": + # If multiple splits exist, merge them before collapsing/splitting. + ds_all = ds_all.map(lambda x: {"task_desc": args.task_desc}, num_proc=args.num_proc) + + print(f"[INFO] Collapsing to formatted fields on {len(ds_all)} rows…") + ds_all = ds_all.map( + function=collapse_example, + remove_columns=ds_all.column_names, + num_proc=args.num_proc, + ) + + def dedup_by_prompt(dataset): + seen = set() + unique_rows = [] + for ex in dataset: + prompt_key = json.dumps(ex["prompt"], sort_keys=True, ensure_ascii=False) + if prompt_key not in seen: + seen.add(prompt_key) + unique_rows.append(ex) + return Dataset.from_list(unique_rows) + + ds_all = dedup_by_prompt(ds_all) + + elif args.dataset_type == "sft": + df = ds_all.to_pandas() + + # Sort so that within each conv_id the highest turn_id is first, + # and if multiple rows share the same turn_id, the highest score comes first + df = df.sort_values(["conv_id", "turn_id", "score"], ascending=[True, False, False]) + + # Keep only the top row per conv_id + df = df.drop_duplicates(subset="conv_id", keep="first") + + # Back to HF Dataset + ds_all = Dataset.from_pandas(df, preserve_index=False) + + # Append assistant response into prompt list + def append_completion(example): + example["prompt"] = ( + [{"role": "system", "content": SYSTEM_PROMPT}] + + example["prompt"] + + [{"role": "assistant", "content": example["completion"]}] + ) + return example + + ds_all = ds_all.map(append_completion) + + # Keep only prompt column + cols_to_remove = [col for col in ds_all.column_names if col != "prompt"] + ds_all = ds_all.remove_columns(cols_to_remove) + + print(f"[INFO] Splitting with validation_size={args.validation_size}, seed={args.seed}") + split = ds_all.train_test_split(test_size=args.validation_size, seed=args.seed, shuffle=True) + train_ds, val_ds = split["train"], split["test"] + print(train_ds, val_ds) + + save_parquet(train_ds, f"{args.dataset_type}_train", out_dir) + save_parquet(val_ds, f"{args.dataset_type}_validation", out_dir) + + maybe_copy_to_hdfs(local_dir=out_dir, hdfs_dir=args.hdfs_dir) + print(f"[DONE] {args.dataset_type}_train.parquet and {args.dataset_type}_validation.parquet written.") + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/collabllm/reward_function.py b/ICL/DAPO/verl-recipe/collabllm/reward_function.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ec29ef884e0818f3033fe9709184afcbc5cafa --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/reward_function.py @@ -0,0 +1,227 @@ +# Copyright 2025 CollabLLM team and/or its affiliates +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import importlib.util +import os +import sys +from typing import Any, Callable, Optional + +import litellm +import torch +from transformers import PreTrainedTokenizer + +from verl import DataProto +from verl.utils.reward_score import default_compute_score +from verl.workers.reward_manager import register +from verl.workers.reward_manager.abstract import AbstractRewardManager + +TERMINATION_SIGNAL = "[[TERMINATE CHAT]]" + + +async def conversation_level_reward_func( + data_source, messages, ground_truth, extra_info, metrics, **kwargs +) -> torch.Tensor: + """ + Async version of conversation-level reward function. + + Apply conversation-level reward function to the future interactions between the user simulator + and policy model, which are generated from `verl/interactions/collabllm_interation.py` + """ + num_retries = kwargs.get("num_retries", 6) + + rewards = {} + for metric in metrics: + current_dir = os.path.dirname(os.path.abspath(__file__)) + metric_file_path = os.path.join(current_dir, f"metrics/{metric}.py") + + if not os.path.exists(metric_file_path): + print(f"Error: Metric file '{metric_file_path}' not found. Assigning 0 to metric '{metric}'.") + rewards[metric] = 0.0 + continue + + spec = importlib.util.spec_from_file_location(f"metric_{metric}", metric_file_path) + if spec is None: + print(f"Error: Could not create spec for metric '{metric}'. Assigning 0 to metric '{metric}'.") + rewards[metric] = 0.0 + continue + + module = importlib.util.module_from_spec(spec) + + try: + sys.modules[f"metric_{metric}"] = module + assert spec.loader is not None + spec.loader.exec_module(module) + except Exception as e: + print(f"Error loading metric module from '{metric_file_path}': {e}. Assigning 0 to metric '{metric}'.") + rewards[metric] = 0.0 + continue + + # Assume each metric file has a compute_score function + if not hasattr(module, "compute_score"): + print( + f"Error: Function 'compute_score' not found in '{metric_file_path}'. Assigning 0 to metric '{metric}'." + ) + rewards[metric] = 0.0 + continue + + compute_score_fn = module.compute_score + + # Retry mechanism for calling the metric function + for attempt in range(num_retries): + try: + # Call the metric function (await if it's async) + if asyncio.iscoroutinefunction(compute_score_fn): + rewards[metric] = await compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs) + else: + rewards[metric] = compute_score_fn(data_source, messages, ground_truth, extra_info, **kwargs) + break # Success, exit retry loop + except Exception as e: + if attempt == num_retries - 1: # Last attempt + print( + f"Error: Failed to compute metric '{metric}' after {num_retries} attempts. " + f"Last error: {e}. Assigning 0 to metric '{metric}'." + ) + rewards[metric] = 0.0 + else: + print(f"Attempt {attempt + 1} failed for metric '{metric}': {e}. Retrying...") + if isinstance(e, litellm.RateLimitError): + await asyncio.sleep(max(2**attempt, 60)) # Exponential backoff + + # Return dict with metric names as keys + return {metric: torch.tensor(reward, dtype=torch.float32) for metric, reward in rewards.items()} + + +@register("collabllm") +class CollabLLMRewardManager(AbstractRewardManager): + """ + The Reward Manager used in https://github.com/Wuyxin/collabllm/ + """ + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + num_examine: int, + metric_weights: dict, + llm_judge_kwargs: dict, + reward_fn_key: str = "data_source", + compute_score: Optional[Callable] = None, + normalize_by_data_source=False, + ) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + self.compute_score = compute_score or default_compute_score + self.reward_fn_key = reward_fn_key + + self.metric_weights = metric_weights + self.llm_judge_kwargs = llm_judge_kwargs + self.normalize_by_data_source = normalize_by_data_source + + self.metrics = list(self.metric_weights.keys()) + + def __call__(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]: + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if "rm_scores" in data.batch.keys(): + if return_dict: + return {"reward_tensor": data.batch["rm_scores"]} + else: + return data.batch["rm_scores"] + # Use thread-compatible async loop management instead of asyncio.run() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(self._compute_rewards_async(data, return_dict)) + finally: + loop.close() + + async def _compute_rewards_async(self, data: DataProto, return_dict: bool = False) -> torch.Tensor | dict[str, Any]: + # batched scoring + prompt_ids = data.batch["prompts"] + prompt_length = prompt_ids.shape[-1] + valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1) + + data_source = data.non_tensor_batch["data_source"] + ground_truth = data.non_tensor_batch["ground_truth"] + extra_info = data.non_tensor_batch["extra_info"] + message_lst = data.non_tensor_batch["messages"] + + # batch the messages into multiple + num_repeat_rollouts = len(message_lst[0]["messages"]) + batch_size = len(data_source) + + grouped_messages = [ + [message_lst[i]["messages"][j] for i in range(len(message_lst))] for j in range(num_repeat_rollouts) + ] + + # Flatten lists for all batch items across all rollouts + flattened_data_sources = [data_source[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)] + flattened_ground_truths = [ground_truth[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)] + flattened_extra_infos = [extra_info[i] for _ in range(num_repeat_rollouts) for i in range(batch_size)] + flattened_messages = [grouped_messages[j][i] for j in range(num_repeat_rollouts) for i in range(batch_size)] + + if num_repeat_rollouts > 0: + tasks = [ + self.compute_score( + flattened_data_sources[i], + flattened_messages[i], + flattened_ground_truths[i], + flattened_extra_infos[i], + self.metrics, + **self.llm_judge_kwargs, + ) + for i in range(len(flattened_data_sources)) + ] + score_dicts = await asyncio.gather(*tasks) + + # Aggregate scores for each metric across repeated rollouts + scores_by_metrics = { + metric: torch.stack([score_dict[metric] for score_dict in score_dicts]) + .view(num_repeat_rollouts, -1) + .sum(dim=0) + for metric in self.metrics + } + + # Apply metric-specific weights + weighted_scores_by_metrics = { + metric: torch.clamp( + scores_by_metrics[metric] * self.metric_weights[metric] / num_repeat_rollouts, + min=-1.0, + max=1.0, + ) + for metric in self.metrics + } + # Compute mean of weighted scores for each metric + mean_weighted_scores_by_metrics = { + metric: weighted_scores_by_metrics[metric].mean(dim=0) for metric in self.metrics + } + + # Combine weighted scores from all metrics into a single tensor + scores = torch.stack([weighted_scores_by_metrics[metric] for metric in self.metrics]).sum(dim=0) + else: + score_dicts = [] + scores = torch.full((batch_size,), 0.0, dtype=torch.float32, device=prompt_ids.device) + mean_weighted_scores_by_metrics = {metric: 0.0 for metric in self.metrics} + + print("Scores:", scores, mean_weighted_scores_by_metrics) + + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + + for i in range(len(data)): + reward_tensor[i, valid_response_length[i].item() - 1] = scores[i] + + if return_dict: + return {"reward_tensor": reward_tensor} + else: + return reward_tensor diff --git a/ICL/DAPO/verl-recipe/collabllm/train_rl_collabllm.sh b/ICL/DAPO/verl-recipe/collabllm/train_rl_collabllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0595296c926b2dec178ae0114cd4dd73159ce3e --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/train_rl_collabllm.sh @@ -0,0 +1,76 @@ +# Usage: sh recipe/collabllm/train_rl_collabllm.sh + +set -x + +PROJECT_DIR="$(pwd)" +export VLLM_USE_V1=1 + +RESUME_PATH="${1:-}" + +if [ -z "$RESUME_PATH" ]; then + RESUME_PATH=null +fi + +DATASET=math-hard-large +PROJECT_DIR="$(pwd)" +AGENTLOOP_CONFIG_PATH="$PROJECT_DIR/recipe/collabllm/config/agent.yaml" + + +python3 -m verl.trainer.main_ppo \ + trainer.val_before_train=False \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/collabllm-$DATASET/rl_train.parquet \ + data.val_files=$HOME/data/collabllm-$DATASET/rl_validation.parquet \ + reward_model.reward_manager=collabllm \ + +reward_model.reward_kwargs.metric_weights.accuracy=1 \ + +reward_model.reward_kwargs.metric_weights.interactivity=1 \ + +reward_model.reward_kwargs.metric_weights.token_amount=-0.0001 \ + +reward_model.reward_kwargs.llm_judge_kwargs.model=gpt-4o-mini \ + +reward_model.reward_kwargs.llm_judge_kwargs.max_tokens=2048 \ + +reward_model.reward_kwargs.llm_judge_kwargs.temperature=0 \ + data.train_batch_size=16 \ + data.max_prompt_length=8196 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path="Qwen/Qwen2.5-7B-Instruct" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=8 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.multi_turn.enable=true \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=2 \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=3 \ + actor_rollout_ref.rollout.multi_turn.num_repeat_rollouts=3 \ + actor_rollout_ref.rollout.agent.agent_loop_config_path=$AGENTLOOP_CONFIG_PATH \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=verlxcollabllm \ + trainer.experiment_name=collabllm-qwen2.5-7B-$DATASET \ + trainer.nnodes=1 \ + trainer.n_gpus_per_node=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + trainer.save_freq=100 \ + trainer.test_freq=10 \ + trainer.total_epochs=20 \ + custom_reward_function.path=recipe/collabllm/reward_function.py \ + custom_reward_function.name=conversation_level_reward_func \ + actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/recipe/collabllm/config/collabllm_interaction_config.yaml" \ + trainer.resume_from_path=$RESUME_PATH diff --git a/ICL/DAPO/verl-recipe/collabllm/train_sft_collabllm.sh b/ICL/DAPO/verl-recipe/collabllm/train_sft_collabllm.sh new file mode 100644 index 0000000000000000000000000000000000000000..f2328687a1185b5b7b69ccae13f2708f9eed9461 --- /dev/null +++ b/ICL/DAPO/verl-recipe/collabllm/train_sft_collabllm.sh @@ -0,0 +1,32 @@ +#!/bin/bash +set -x + +if [ "$#" -lt 1 ]; then + echo "Usage: sft_train_collabllm.sh [ other_configs...]" + exit 1 +fi + +nproc_per_node=$1 + +# Shift the arguments so $@ refers to the rest +shift 1 + +DATASET=math-hard-large + +torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/collabllm-$DATASET/sft_train.parquet \ + data.val_files=$HOME/data/collabllm-$DATASET/sft_validation.parquet \ + data.multiturn.enable=true \ + data.multiturn.messages_key=prompt \ + optim.lr=1e-6 \ + data.train_batch_size=64 \ + data.micro_batch_size_per_gpu=2 \ + data.max_length=8196 \ + model.partial_pretrain=Qwen/Qwen2.5-7B-Instruct \ + trainer.project_name=collabllm-sft-$DATASET \ + trainer.experiment_name=collabllm-sft-qwen2.5-7B-$DATASET \ + trainer.logger=console \ + trainer.total_epochs=3 $@ \ + ulysses_sequence_parallel_size=1 \ + use_remove_padding=true $@ \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/README.md b/ICL/DAPO/verl-recipe/dapo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5545be1acd14ac0c2a7d3a44ac11fa37e9138e94 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/README.md @@ -0,0 +1,192 @@ +# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) + +> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) + +> [!IMPORTANT] +> +> **🔥 News!!!** +> +> - [2025/04] We reproduced the results of two versions of DAPO ([Full](./run_dapo_qwen2.5_32b.sh) & [w/o Dynamic Sampling](./run_dapo_wo_ds_qwen2.5_32b.sh)), achieving 52% and 50% on AIME 2024 respectively, based on [the latest codebase on `recipe/dapo`](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo). Please check the details in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). +> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). + +🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) + +> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. +> +> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png) + +## Quickstart + +1. Prepare the datasets **on the Ray cluster**: + +```bash +bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default +``` + +2. Submit the job to the Ray cluster **from any machine**: + +```bash +cd verl # Repo root +export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to +export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster +# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml +export RUNTIME_ENV="./recipe/dapo/runtime_env.yaml" # This sets environment variables for the Ray cluster +bash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts +``` + +## Reproduction Runs + +| Setup | AIME 2024 Acc. | Hardware | Image | Commit | Environment Variables | Training Script | Training Record | +| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | +| DAPO | 52% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| DAPO w/o Dynamic Sampling | 50% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | +| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | 16x8xH20 | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | + +> [!IMPORTANT] +> +> **📢 Call for Contribution!** +> +> Welcome to submit your reproduction runs and setups! + +## Configuration + +### Separated Clip Epsilons (-> Clip-Higher) + +An example configuration: + +```yaml +actor_rollout_ref: + actor: + clip_ratio_low: 0.2 + clip_ratio_high: 0.28 +``` + +`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text {low }}$ and $\varepsilon_{\text {high }}$ in the DAPO objective. + +Core relevant code: + +```python +pg_losses1 = -advantages * ratio +pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) +pg_losses = torch.maximum(pg_losses1, pg_losses2) +``` + +### Dynamic Sampling (with Group Filtering) + +An example configuration: + +```yaml +data: + gen_batch_size: 1536 + train_batch_size: 512 +algorithm: + filter_groups: + enable: True + metric: acc # score / seq_reward / seq_final_reward / ... + max_num_gen_batches: 10 # Non-positive values mean no upper limit +``` + +Setting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0. + +The trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size` or reaching the upper limit specified by `max_num_gen_batches`. + +Core relevant code: + +```python +prompt_bsz = self.config.data.train_batch_size +if num_prompt_in_batch < prompt_bsz: + print(f'{num_prompt_in_batch=} < {prompt_bsz=}') + num_gen_batches += 1 + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...') + continue + else: + raise ValueError( + f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' + ) +else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] +``` + +### Flexible Loss Aggregation Mode (-> Token-level Loss) + +An example configuration: + +```yaml +actor_rollout_ref: + actor: + loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" + # NOTE: "token-mean" is the default behavior +``` + +Setting `loss_agg_mode` to `token-mean` will mean the (policy gradient) loss across all the tokens in all the sequences in a mini-batch. + +Core relevant code: + +```python +if loss_agg_mode == "token-mean": + loss = verl_F.masked_mean(loss_mat, loss_mask) +elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum + loss = torch.mean(seq_losses) # seq-mean +elif loss_agg_mode == "seq-mean-token-mean": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean + loss = torch.mean(seq_losses) # seq-mean +else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") +``` + +### Overlong Reward Shaping + +An example configuration: + +```yaml +data: + max_response_length: 20480 # 16384 + 4096 +reward_model: + overlong_buffer: + enable: True + len: 4096 + penalty_factor: 1.0 +``` + +Setting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit. + +Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length - overlong_buffer.len` by `0` to `overlong_buffer.len` tokens. + +Core relevant code: + +```python +if self.overlong_buffer_cfg.enable: + overlong_buffer_len = self.overlong_buffer_cfg.len + expected_len = self.max_resp_len - overlong_buffer_len + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward +``` + +## FAQ + +### Where is the "Overlong Filtering" in the paper? + +Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. + +### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl-recipe/tree/main/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)? + +[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features. + +[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl-recipe/tree/main/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, which will be maintained with new features. + +### Why can't I produce similar results after modifications? + +RL infrastructures nowadays still have inherent unrobustness, on which we are still working hard to improve. + +We strongly recommend to only modify one thing at a time. + +We also list some known problems here: + +1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation. diff --git a/ICL/DAPO/verl-recipe/dapo/dapo_ray_trainer.py b/ICL/DAPO/verl-recipe/dapo/dapo_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e52851caf9ad091af3aee870cb2193b3917795 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/dapo_ray_trainer.py @@ -0,0 +1,418 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import os +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint + +import numpy as np +import torch +from tqdm import tqdm + +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.trainer.ppo.reward import compute_reward +from verl.utils.metric import reduce_metrics +from verl.utils.profiler import marked_timer +from verl.utils.rollout_skip import RolloutSkip + + +class RayDAPOTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + def compute_kl_related_metrics(self, batch: DataProto, metrics: dict, timing_raw: dict): + batch.batch["response_mask"] = compute_response_mask(batch) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, "blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, "olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + return batch + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + self.gen_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + self.gen_steps += 1 + last_val_metrics = None + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + num_gen_batches += 1 + gen_batch = self._get_gen_batch(new_batch) + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, "red"): + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with marked_timer("gen_max", timing_raw, "red"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + + new_batch = new_batch.union(gen_baseline_output) + # compute reward model score on new_batch + rm_scores = None + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + new_batch.pop(batch_keys=list(keys_to_pop)) + + new_batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.union(gen_batch_output) + + if self.config.algorithm.use_kl_in_reward: + # We need these metrics for apply_kl_penalty if using kl in reward + new_batch = self.compute_kl_related_metrics(new_batch, metrics, timing_raw) + # otherwise, we will compute those after dynamic sampling + + with marked_timer("reward", timing_raw, "yellow"): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn) + + new_batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update( + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches + else: + new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: # NOTE: When prompts after filtering is less than train batch size, + # we skip to the next generation batch + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + # Turn to numpy for easier filtering + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) + + # Collect the sequence reward for each trajectory + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + ): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) + + kept_prompt_uids = [ + uid + for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] + num_prompt_in_batch += len(kept_prompt_uids) + + kept_traj_idxs = [] + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): + if traj_from_prompt_uid in kept_prompt_uids: + kept_traj_idxs.append(idx) + + new_batch = new_batch[kept_traj_idxs] + batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + + prompt_bsz = self.config.data.train_batch_size + if num_prompt_in_batch < prompt_bsz: + print(f"{num_prompt_in_batch=} < {prompt_bsz=}") + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f"{num_gen_batches=}. Keep generating...") + self.gen_steps += 1 + is_last_step = self.global_steps >= self.total_training_steps + continue + else: + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + + # === Updating === + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + if not self.config.algorithm.use_kl_in_reward: + batch = self.compute_kl_related_metrics(batch, metrics, timing_raw) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, "cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + # Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer) + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + with marked_timer("adv", timing_raw, "brown"): + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, "pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, "red"): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, "green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + timing_raw = defaultdict(float) # clear timing + + metrics["train/num_gen_batches"] = num_gen_batches + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 + self.gen_steps += 1 + # check if last step checkpint exists + checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}") + if not os.path.exists(checkpoint_dir): + # save last step checkpoint + timing_raw = defaultdict(float) + with marked_timer("save_checkpoint", timing_raw, "green"): + self._save_checkpoint() + metrics = {f"timing/{k}": v for k, v in timing_raw.items()} + logger.log(data=metrics, step=self.global_steps) diff --git a/ICL/DAPO/verl-recipe/dapo/main_dapo.py b/ICL/DAPO/verl-recipe/dapo/main_dapo.py new file mode 100644 index 0000000000000000000000000000000000000000..47d8cf3edd652c788f1d03b58a50b7c6c3d744a0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/main_dapo.py @@ -0,0 +1,185 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.reward import load_reward_manager +from verl.utils.device import auto_set_device, is_cuda_available + +from .dapo_ray_trainer import RayDAPOTrainer + + +@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None) +def main(config): + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + + run_ppo(config) + + +def run_ppo(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + try: + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and OmegaConf.select(config.global_profiler, "steps") is not None + and len(OmegaConf.select(config.global_profiler, "steps")) > 0 + ): + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + finally: + if ray.is_initialized(): + ray.shutdown() + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # used for multimodal LLM, could be none + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + from verl.single_controller.ray import RayWorkerGroup + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + + from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.workers.megatron_workers import AsyncActorRolloutRefWorker, CriticWorker + + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(AsyncActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_fn = load_reward_manager( + config, + tokenizer, + 0, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) + + # Note that we always use function-based RM for validation + val_reward_fn = load_reward_manager( + config, + tokenizer, + 1, + max_resp_len=config.data.max_response_length, + overlong_buffer_cfg=config.reward_model.overlong_buffer, + ) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayDAPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) + trainer.init_workers() + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/dapo/prepare_dapo_data.sh b/ICL/DAPO/verl-recipe/dapo/prepare_dapo_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..b5dbb25a7dd3f0826eb435bb32ee317bff029322 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/prepare_dapo_data.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -uxo pipefail + +export VERL_HOME=${VERL_HOME:-"${HOME}/verl"} +export TRAIN_FILE=${TRAIN_FILE:-"${VERL_HOME}/data/dapo-math-17k.parquet"} +export TEST_FILE=${TEST_FILE:-"${VERL_HOME}/data/aime-2024.parquet"} +export OVERWRITE=${OVERWRITE:-0} + +mkdir -p "${VERL_HOME}/data" + +if [ ! -f "${TRAIN_FILE}" ] || [ "${OVERWRITE}" -eq 1 ]; then + wget -O "${TRAIN_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true" +fi + +if [ ! -f "${TEST_FILE}" ] || [ "${OVERWRITE}" -eq 1 ]; then + wget -O "${TEST_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true" +fi diff --git a/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_32b_fsdp2_npu.sh b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_32b_fsdp2_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..9fcfdc15a0cd633626cc4c03fbbae23d420a42cc --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_32b_fsdp2_npu.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_CONNECT_TIMEOUT=5400 +export VLLM_ASCEND_ENABLE_NZ=0 +export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2 +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-vl-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=4 +train_prompt_bsz=64 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=16 + +# Ray +PWD=./ +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-VL-32B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/geo3k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/geo3k/test.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +gen_tp=4 +fsdp_size=-1 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.ref.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=console \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes=1 \ + trainer.val_before_train=True \ + trainer.test_freq=1 \ + trainer.save_freq=20 \ + trainer.resume_mode=auto \ + trainer.device=npu \ + trainer.total_epochs=30 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" diff --git a/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_3b_fsdp2_npu.sh b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_3b_fsdp2_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..0abad8075da246ca44829d78250671b1a20905f5 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_3b_fsdp2_npu.sh @@ -0,0 +1,154 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_CONNECT_TIMEOUT=5400 +export VLLM_ASCEND_ENABLE_NZ=0 +export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2 +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-vl-3B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=4 +train_prompt_bsz=64 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=16 + +# Ray +PWD=./ +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-VL-3B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/geo3k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/geo3k/test.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=1 +fsdp_size=-1 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.ref.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=console \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.val_before_train=True \ + trainer.test_freq=1 \ + trainer.save_freq=20 \ + trainer.device=npu \ + trainer.resume_mode=auto \ + trainer.total_epochs=30 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" + diff --git a/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_7b_fsdp2_npu.sh b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_7b_fsdp2_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..840b919922a82838792cc784d5cf20771b1c5c60 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen2.5_vl_7b_fsdp2_npu.sh @@ -0,0 +1,153 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_CONNECT_TIMEOUT=5400 +export VLLM_ASCEND_ENABLE_NZ=0 +export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2 + +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-vl-7B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=4 +train_prompt_bsz=128 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=16 + +# Ray +PWD=./ +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-VL-7B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/geo3k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/geo3k/test.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +gen_tp=1 +fsdp_size=-1 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.ref.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=console \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.val_before_train=True \ + trainer.test_freq=1 \ + trainer.save_freq=20 \ + trainer.resume_mode=auto \ + trainer.device=npu \ + trainer.total_epochs=30 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/run dapo_qwen3_vl_30b_fsdp2_npu.sh b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen3_vl_30b_fsdp2_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..d209797c10506d39ecb5f5f9b20a743f87473265 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run dapo_qwen3_vl_30b_fsdp2_npu.sh @@ -0,0 +1,152 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_CONNECT_TIMEOUT=5400 +export VLLM_ASCEND_ENABLE_NZ=0 +export LD_PRELOAD=/usr/local/lib/libjemalloc.so.2 +# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, +# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. +export USE_OPTIMIZED_MODEL=0 + +project_name='DAPO' +exp_name='DAPO-Qwen3-vl-30B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=4 +train_prompt_bsz=64 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=8 +train_prompt_mini_bsz=16 + +# Ray +PWD=./ +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/geo3k/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/geo3k/test.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +gen_tp=8 +fsdp_size=16 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.70 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.expert_parallel_size=8 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.ref.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=console \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=2 \ + trainer.val_before_train=True \ + trainer.test_freq=1 \ + trainer.save_freq=20 \ + trainer.resume_mode=auto \ + trainer.device=npu \ + trainer.total_epochs=30 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_early_qwen2.5_32b.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_early_qwen2.5_32b.sh new file mode 100644 index 0000000000000000000000000000000000000000..517e5cefcacc993551fb3bdaabecfe8947e8366a --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_early_qwen2.5_32b.sh @@ -0,0 +1,129 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Early-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# An early version for DAPO +loss_agg_mode="seq-mean-token-mean" + +enable_filter_groups=False +gen_prompt_bsz=512 # NOTE: no filtering here +train_prompt_bsz=512 +train_prompt_mini_bsz=32 +n_resp_per_prompt=16 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b.sh new file mode 100644 index 0000000000000000000000000000000000000000..0ec1047a1712b7d09ab59618e8d3371e113513ac --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_fsdp2_20k_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_fsdp2_20k_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..f0c442ae41c63d08ed9841819dbe5b5228db31f5 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_fsdp2_20k_npu.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_ASCEND_ENABLE_FLASHCOMM=1 +export HCCL_EXEC_TIMEOUT=3600 +export HCCL_CONNECT_TIMEOUT=3600 + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=4 +gen_dp=1 +enable_chunked_prefill=True + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.rollout.expert_parallel_size=1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=100 \ + trainer.save_freq=100 \ + trainer.total_epochs=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.device='npu' \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False $@ + diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_fsdp2_4k_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_fsdp2_4k_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..ae4968ffbe5280fb9dc4b7ac51d89fce9f35c7bb --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_fsdp2_4k_npu.sh @@ -0,0 +1,155 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_ASCEND_ENABLE_FLASHCOMM=1 +export HCCL_EXEC_TIMEOUT=3600 +export HCCL_CONNECT_TIMEOUT=3600 + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 4)) +min_response_length=$((1024 * 4)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 +gen_dp=1 +enable_chunked_prefill=True + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + +data.min_response_length=${min_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.ignore_eos=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.rollout.expert_parallel_size=1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=100 \ + trainer.save_freq=100 \ + trainer.total_epochs=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.device='npu' \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False $@ + diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..bce3ab8eca6c5b5392f2376a673c620d1253d4c6 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_npu.sh @@ -0,0 +1,140 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO-Qwen2.5-32B' +exp_name='Qwen2.5-32B-npu-32rank-gbs128' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 +clip_ratio_low=0.2 +clip_ratio_high=0.28 +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 +loss_agg_mode="token-mean" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 + +NNODES=2 + +train_prompt_bsz=128 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +PWD=./ +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=4 +enable_chunked_prefill=True + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger="['console','wandb']" \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=20 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh new file mode 100644 index 0000000000000000000000000000000000000000..b46feb9ba254f3cb2f9f41339399fb4a89dde4bd --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh @@ -0,0 +1,176 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# Rollout Correction Example +# References: +# - Rollout Correction Docs: https://github.com/volcengine/verl/blob/main/docs/algo/rollout_corr.md +# - Rollout Correction Math: https://github.com/volcengine/verl/blob/main/docs/algo/rollout_corr_math.md +# - When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch: https://richardli.xyz/rl-collapse +# - Off-policy RL: https://fengyao.notion.site/off-policy-rl + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-32B-RolloutCorr' # Rollout Correction + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +# Rollout Correction parameters (sequence-level TIS + geometric RS) +rollout_is=sequence +rollout_is_threshold=2.0 +rollout_is_batch_normalize=true +rollout_rs=geometric +rollout_rs_threshold=1.01 +rollout_rs_threshold_lower=0.99 +rollout_token_veto_threshold=1e-4 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 + + +# Rollout Correction (corrects distribution mismatch between rollout and training) +# +# Configuration: DAPO with Rollout Correction: +# - Self-normalized sequence-level TIS (Truncated Importance Sampling) +# - Geometric rejection sampling for outlier filtering +# - Token veto for catastrophic distribution shifts +# +# Please note that server mode (agent loop) hasn't returned rollout_log_probs for now, +# so currently server mode is not supported for Rollout Correction. +# +# Rollout Correction parameters (configured at top of script): +# algorithm.rollout_correction.rollout_is=sequence +# algorithm.rollout_correction.rollout_is_threshold=2.0 +# algorithm.rollout_correction.rollout_is_batch_normalize=true +# algorithm.rollout_correction.rollout_rs=geometric +# algorithm.rollout_correction.rollout_rs_threshold=1.01 +# algorithm.rollout_correction.rollout_rs_threshold_lower=0.99 +# algorithm.rollout_correction.rollout_token_veto_threshold=1e-4 +# actor_rollout_ref.rollout.calculate_log_probs=True # Required! + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_is_batch_normalize=${rollout_is_batch_normalize} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_7b_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_7b_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..834ab21fa6de055365d1932e51dbcc2869b7e266 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen2.5_7b_npu.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO-Qwen2.5-7B-Instruct' +exp_name='DAPO-Qwen2.5-7B-Instruct' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 +clip_ratio_low=0.2 +clip_ratio_high=0.28 +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 +loss_agg_mode="token-mean" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 + +NNODES=1 + +train_prompt_bsz=16 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=1 + +# Ray +PWD=./ +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-7B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=1 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.50 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger="['console']" \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=20 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.entropy_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_14b_base_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_14b_base_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..9e0fdae374c0979110a49f5e000bfeb12e229dd6 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_14b_base_npu.sh @@ -0,0 +1,139 @@ +#!/bin/bash +project_name='DAPO' +exp_name='DAPO-Qwen3-14B-Base' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=16 +gen_prompt_bsz=$((train_prompt_bsz * 2)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=1 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-14B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Performance Related Parameter +sp_size=2 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=2 + +ray job submit --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=8 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=20 \ + trainer.total_epochs=1 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + data.shuffle=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.entropy_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=True diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_30b_fsdp_6k_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_30b_fsdp_6k_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..4bce41a0199dc5de74f1f4b1c8c9b1f9a534d987 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_30b_fsdp_6k_npu.sh @@ -0,0 +1,161 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export VLLM_USE_V1=1 +export HCCL_OP_EXPANSION_MODE="AIV" +export VLLM_ASCEND_ENABLE_FLASHCOMM=1 +export HCCL_EXEC_TIMEOUT=3600 +export HCCL_CONNECT_TIMEOUT=3600 + +project_name='DAPO' +exp_name='DAPO-Qwen3-30B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 6)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 +max_num_seqs=1024 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=True +log_prob_micro_batch_size_per_gpu=1 +ppo_micro_batch_size_per_gpu=1 +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +max_num_batched_tokens=$(((max_prompt_length + max_response_length) * 4)) +offload=True +gen_tp=2 +gen_dp=1 +enable_chunked_prefill=True + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.max_num_seqs=${max_num_seqs} \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.strategy=fsdp \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.expert_parallel_size=1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.device='npu' \ + trainer.val_before_train=False \ + trainer.test_freq=200 \ + trainer.save_freq=50 \ + trainer.total_epochs=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${log_prob_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${log_prob_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \ + ++actor_rollout_ref.nccl_timeout=7200 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False $@ \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_base_fsdp_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_base_fsdp_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..52fb0b4e6a5ebb04d8bccc437115e91e8b4f3561 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_base_fsdp_npu.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +set -euxo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen3-MOE-30B-FSDP-128rank-gbs512' + +NNODES=8 +NPUS_PER_NODE=16 + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 +loss_agg_mode="token-mean" +ppo_mini_batch_size=32 + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=16 + +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=16 # For load-balance. For smaller cluster this can be set to as less as 2. +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / 2)) +offload=True +recompute=True +max_num_seqs=128 +gen_tp=2 + + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.max_num_seqs=${max_num_seqs} \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=${recompute} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.ref.fsdp_config.forward_prefetch=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.save_freq=-1 \ + trainer.total_epochs=1 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False + diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_megatron_npu.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_megatron_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..24624275929ed5eccd604ae633b6cddce090ca21 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_megatron_npu.sh @@ -0,0 +1,170 @@ +#!/bin/bash + +project_name='DAPO' +exp_name='DAPO-Qwen3-30B-megatron' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=16 +gen_prompt_bsz=$((train_prompt_bsz * 2)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=2 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B"} +# MCORE_MODEL_PATH points to the converted checkpoint. +# To avoid loading these weights, set actor_rollout_ref.actor.megatron.use_dist_checkpointing=False. +MCORE_MODEL_PATH=${MCORE_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-dist_ckpt"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +offload=True + +max_num_batched_tokens=$((max_prompt_length + max_response_length)) + +# vllm +gen_tp=4 + +# Megatron backen +train_tp=4 +train_ep=2 +train_pp=2 +train_cp=1 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + -- python3 -m recipe.dapo.main_dapo \ + --config-name="dapo_megatron_trainer" \ + data.filter_overlong_prompts=False \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.shuffle=False \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_epochs=1 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${train_cp} \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${train_cp} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.enable_prefix_caching=False \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \ + actor_rollout_ref.rollout.max_model_len=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=-1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + actor_rollout_ref.nccl_timeout=14400 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 + diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh new file mode 100644 index 0000000000000000000000000000000000000000..b9b96aa753deca0c0e4cd7ee2bad67a3be145d5d --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_qwen3_moe_30b_vllm_fp8_rollout.sh @@ -0,0 +1,171 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO-FP8-ROLLOUT' +exp_name='DAPO-Qwen3-MOE-30B-VLLM-FP8-ROLLOUT' + + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Rollout Correction parameters for FP8 rollout +rollout_is=token +rollout_is_threshold=2.0 +rollout_rs=null +rollout_rs_threshold=null +rollout_rs_threshold_lower=null +rollout_token_veto_threshold=null + +max_prompt_length=$((1024)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +gen_prompt_bsz=96 + +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +echo "WORKING_DIR: ${WORKING_DIR}" +# For vllm 0.11.x, DEEP_GEMM is enabled by default. +# For vllm 0.10.x, please set VLLM_USE_DEEP_GEMM=1 in runtime_env.yaml +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +echo "RUNTIME_ENV: ${RUNTIME_ENV}" +NNODES=${NNODES:-2} +echo "NNODES: ${NNODES}" + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH="Qwen/Qwen3-30B-A3B-Base" +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=1.0 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=true +gen_tp=1 +train_tp=1 +train_pp=1 + +# Set Flash-RL environment variables +export VERL_LOGGING_LEVEL=DEBUG +export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_CONFIGURE_LOGGING=1 +export VLLM_USE_V1=1 +export VLLM_USE_DEEP_GEMM=1 +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 + +RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --runtime-env=${RUNTIME_ENV} \ +-- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=1800 \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$(( 1024 * 32 )) \ + actor_rollout_ref.rollout.max_num_seqs=256 \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + +actor_rollout_ref.rollout.quantization=fp8 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + reward_model.overlong_buffer.log=False \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=1 \ + trainer.total_training_steps=500 \ + trainer.max_actor_ckpt_to_keep=5 \ + actor_rollout_ref.rollout.enforce_eager=False diff --git a/ICL/DAPO/verl-recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh b/ICL/DAPO/verl-recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh new file mode 100644 index 0000000000000000000000000000000000000000..50c18eadb12ad24fdab9ad68efb50be2fa450341 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash +set -euxo pipefail +# DAPO (w/o Dynamic Sampling) + +project_name='DAPO-verl' +exp_name='DAPO-wo-DS-Qwen2.5-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=False +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-16} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto diff --git a/ICL/DAPO/verl-recipe/dapo/runtime_env.yaml b/ICL/DAPO/verl-recipe/dapo/runtime_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..13f4b2ba230b892a277026d53a98cb42afc4ae4d --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/runtime_env.yaml @@ -0,0 +1,5 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + VLLM_USE_V1: "1" diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_7b.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..a1201dc32f7afeaa4d645ef083c6291440f54753 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7B-Math-Test' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 2)) +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=512 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=16 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m recipe.dapo.main_dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=2 \ + trainer.save_freq=2 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..e7fa99268689f636f7988127ae4627195116a8b0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +fsdp_size=32 + +# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361 + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math_lora.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..06c66baa42f005de27d43651632d08646a8df8e5 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math_lora.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +fsdp_size=32 + +# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.lora_rank=8 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math_megatron.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..a866e968508d3184150549228a1b6fc746f20728 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_7b_math_megatron.sh @@ -0,0 +1,132 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-megatron-0519a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +train_tp=4 +train_pp=2 + +# TODO: support dynamic_bsz for megatron +# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ +# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ +# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ +# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_8b_megatron_fp16.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_8b_megatron_fp16.sh new file mode 100644 index 0000000000000000000000000000000000000000..0dfd77854cb3f547fd8311301da9b9d140db29c3 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_8b_megatron_fp16.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -xeuo pipefail + + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi +dtype="float16" # ["bfloat16", "float16"] + +project_name='DAPO-fp16' +exp_name='fp16' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +gen_tp=1 +train_tp=2 +train_pp=1 + +# TODO: support dynamic_bsz for megatron + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.return_raw_chat=$return_raw_chat \ + data.truncation='left' \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.dtype=${dtype} \ + actor_rollout_ref.actor.megatron.dtype=${dtype} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + reward_model.reward_manager=dapo \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 + diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_8b_megatron_fp8train.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_8b_megatron_fp8train.sh new file mode 100644 index 0000000000000000000000000000000000000000..5827abdd8793ccc7a4634c62f609aed0b9f9140f --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_8b_megatron_fp8train.sh @@ -0,0 +1,201 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# need cuda12.9 or higher +# use docker://verlai/verl:dev.vllm_nightly-243ed7d32e94f00a9a32fbbc51be932f6277a55d or self build + + +# this env var is required for TE fp8 training +# if you are running multiple nodes, you need to set this env var in RUNTIME_ENV +export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 + +################################################### quick config ################################################### + + +rollout_mode="sync" +rollout_name="vllm" # sglang or vllm +return_raw_chat="False" +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi +dtype="bfloat16" # ["bfloat16", "float16"] + +project_name='DAPO' +exp_name='fp8train' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-8B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +gen_tp=1 +train_tp=2 +train_pp=1 + +################################################### start of config ################################################### + +FP8=( + +actor_rollout_ref.actor.megatron.override_transformer_config.fp8="e4m3" # e4m3 or hybrid + +actor_rollout_ref.actor.megatron.override_transformer_config.fp8_recipe="blockwise" + +actor_rollout_ref.actor.optim.override_optimizer_config.fp8_recipe="blockwise" +) + +DATA=( + # dddd + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + data.prompt_key=prompt + data.return_raw_chat=$return_raw_chat + data.truncation='left' + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.train_batch_size=${train_prompt_bsz} +) + +REWARD_MODEL=( + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False + +reward_model.reward_kwargs.max_resp_len=${max_response_length} + reward_model.reward_manager=dapo +) + +PERF_OPT=( + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True + actor_rollout_ref.model.use_fused_kernels=False +) + +ACTOR=( + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=10 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.optim.clip_grad=1.0 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.megatron.param_offload=${offload} + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} + actor_rollout_ref.actor.megatron.grad_offload=${offload} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} + actor_rollout_ref.actor.megatron.use_mbridge=True +) + +ROLLOUT=( + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.mode=${rollout_mode} + actor_rollout_ref.rollout.dtype=${dtype} + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) + actor_rollout_ref.rollout.temperature=${temperature} + actor_rollout_ref.rollout.top_p=${top_p} + actor_rollout_ref.rollout.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.rollout.n=${n_resp_per_prompt} +) + +TRAINER=( + trainer.logger=['console','wandb'] + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + trainer.n_gpus_per_node=8 + trainer.nnodes="${NNODES}" + trainer.val_before_train=False + trainer.test_freq=10 + trainer.save_freq=-1 + trainer.total_epochs=10 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.resume_mode=auto + trainer.log_val_generations=10 +) + +FORWARD_ONLY_SETS=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} +) + +MODEL=( + actor_rollout_ref.model.path="${MODEL_PATH}" +) + +ALGORITHM=( + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} +) +################################################### start script ################################################### +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + -- python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REWARD_MODEL[@]}" \ + "${FP8[@]}" \ + "${PERF_OPT[@]}" \ + "${TRAINER[@]}" \ + "${FORWARD_ONLY_SETS[@]}" \ \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_dspk_671b_megatron_96gb.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_dspk_671b_megatron_96gb.sh new file mode 100644 index 0000000000000000000000000000000000000000..a62b68c66a539a5aa74ac7d6641368a728ebc2c2 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_dspk_671b_megatron_96gb.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# 0. download the config +# only need to download the configuration_deepseek.py and config.json +# remove the `quantization_config` in the `config.json` +# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported +huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json + +project_name='DAPO' +exp_name='DAPO-DeepSeek-671b-megatron' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=0.1 + +loss_agg_mode="token-mean" + +train_prompt_bsz=256 # must be > n_gpus. need to fix +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # mini_bsz * n >= micro_bsz * pp * dp + +NNODES=${NNODES:-64} + +# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main +# change the MODEL_PATH and MCORE_MODEL_PATH to your own path +# Paths +MODEL_PATH="" +MCORE_MODEL_PATH="" +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +aime24_test_path=${RAY_DATA_HOME}/data/aime-2024.parquet +# TEST_FILE="['$math500_test_path', '$aime24_test_path']" + +TEST_FILE="['$aime24_test_path']" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=32 +train_tp=1 +train_ep=32 +train_pp=16 + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_glm_air_megatron.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_glm_air_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..2e7d91c07a5746d11c1be60f60c7c177d7af0c9b --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_glm_air_megatron.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +NNODES=${NNODES:-8} +NGPUS_PER_NODES=${NGPUS_PER_NODES:-8} + +project_name='DAPO' +exp_name='DAPO-GLM-AIR-MATH-megatron' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=128 +train_ppo_micro_batch_size_per_gpu=2 +infer_ppo_micro_batch_size_per_gpu=2 +# Paths +MODEL_PATH=/models/zai-org/GLM-4.5-Air-Base +# GLM Base model can use chat_template.jinja from instruct models +cp /models/zai-org/GLM-4.5-Air/chat_template.jinja ${MODEL_PATH}/chat_template.jinja + +TRAIN_FILE=/data/dapo/dapo-math-17k.parquet +aime24_test_path=/data/dapo/aime-2024.parquet +# math500_test_path=/data/rlhf/math500/test.parquet + +# TEST_FILE="['$math500_test_path', '$aime24_test_path']" + +TEST_FILE="['$aime24_test_path']" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length))) +offload=True + +COMMON_PP=${COMMON_PP:-2} +COMMON_VPP=${COMMON_VPP:-null} +COMMON_CP=${COMMON_CP:-4} +COMMON_TP=${COMMON_TP:-2} +COMMON_EP=${COMMON_EP:-8} +COMMON_ETP=${COMMON_ETP:-1} + +TRAIN_TP=${TRAIN_TP:-$COMMON_TP} +INFER_TP=${INFER_TP:-8} + +ACTOR_PP=${ACTOR_PP:-$COMMON_PP} +ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} +ACTOR_CP=${ACTOR_CP:-$COMMON_CP} +ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} +ACTOR_EP=${ACTOR_EP:-$COMMON_EP} +ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} +ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} +REF_PP=${REF_PP:-$COMMON_PP} +REF_VPP=${REF_VPP:-$COMMON_VPP} +REF_CP=${REF_CP:-$COMMON_CP} +REF_TP=${REF_TP:-$TRAIN_TP} +REF_EP=${REF_EP:-$COMMON_EP} +REF_ETP=${REF_ETP:-$COMMON_ETP} +CRITIC_PP=${CRITIC_PP:-$COMMON_PP} +CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} +CRITIC_CP=${CRITIC_CP:-$COMMON_CP} +CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} +CRITIC_EP=${CRITIC_EP:-$COMMON_EP} +CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} +RM_PP=${RM_PP:-$COMMON_PP} +RM_VPP=${RM_VPP:-$COMMON_VPP} +RM_CP=${RM_CP:-$COMMON_CP} +RM_TP=${RM_TP:-$TRAIN_TP} +RM_EP=${RM_EP:-$COMMON_EP} +RM_ETP=${RM_ETP:-$COMMON_ETP} + +USE_MBRIDGE=True +USE_DIST_CKPT=False + +# Install the latest mbridge +# pip install --no-cache-dir git+https://github.com/ISEEKYAN/mbridge.git + +python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + +actor_rollout_ref.model.override_config.model_config.max_position_embeddings=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.lr_decay_style='constant' \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.megatron.use_mbridge=$USE_MBRIDGE \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=$USE_DIST_CKPT \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} \ + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} \ + actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} \ + actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity="selective" \ + actor_rollout_ref.actor.megatron.override_transformer_config.recompute_modules=["core_attn","moe_act","layernorm","mlp","moe"] \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.masked_softmax_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_activation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.bias_dropout_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.deallocate_pipeline_outputs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.persist_layer_norm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_shared_expert_overlap=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type="flex" \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=False \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.name='vllm' \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${INFER_TP} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_ppo_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} \ + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} \ + actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} \ + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} \ + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODES}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=10 \ + trainer.save_freq=100 \ + trainer.total_epochs=10 \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_gptoss_20b_megatron.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_gptoss_20b_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..6ca432aed51c4ca44f2fc94f0314536263693690 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_gptoss_20b_megatron.sh @@ -0,0 +1,248 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +################################################### document for gptoss ################################################### + +####################### running environment: ####################### +# option 1: use pre-built images verlai/verl:vll012.exp or verlai/verl:sgl056.exp +# +# option 2: self build TE>=2.8 with CUDNN>=9.13.1, megatron with branch `core_dev_r0.15.0`, latest vllm or sglang +# you can modify the dockerfile to build the image, see Dockerfile at https://github.com/volcengine/verl/blob/main/docker/Dockerfile.stable.vllm or https://github.com/volcengine/verl/blob/main/docker/Dockerfile.stable.sglang + + +####################### before training: ####################### +# # install matched mbridge version +# pip uninstall -y mbridge && pip install git+https://github.com/ISEEKYAN/mbridge@gpt-oss + +# # convert gptoss to bf16 +cat > get_model.py << EOF +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, Mxfp4Config + +model_id = "openai/gpt-oss-20b" +output_dir = "$HOME/models/gpt-oss-20b-bf16" + +quantization_config = Mxfp4Config(dequantize=True) +model_kwargs = dict( + attn_implementation="eager", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, + use_cache=False, + device_map="auto", +) + +model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + +# Patch config with custom attribute before saving +model.config.attn_implementation = "eager" + +model.save_pretrained(output_dir) +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.save_pretrained(output_dir) +EOF + +python get_model.py + +####################### specific training config: ####################### + +GPT_OSS_CONFIG=( + # only support mbridge for gptoss + actor_rollout_ref.actor.megatron.use_mbridge=True + # for now (latest TE=2.10), gptoss's optimized attn kernel is not supported for thd format, so we use bshd format here + # when bshd format is used, we need to pad the input_ids to the longest sequence length + # so we recommend to disable dynamic batch size and set micro batch size to 1 to avoid paddings + # but it is ok to try with micro_batch_size>1 + actor_rollout_ref.actor.megatron.use_remove_padding=False +) +use_dynamic_bsz=False # recommended but not necessary + +################################################### quick config ################################################### + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +export VLLM_USE_V1=1 +return_raw_chat="True" +dtype="bfloat16" # ["bfloat16", "float16"] + +project_name='DAPO' +exp_name='gptoss' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/gpt-oss-20b"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +gen_tp=4 +train_tp=4 +EP=8 +ETP=1 +train_pp=1 + +################################################### start of config ################################################### + + +DATA=( + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + data.prompt_key=prompt + data.return_raw_chat=$return_raw_chat + data.truncation='left' + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.train_batch_size=${train_prompt_bsz} +) + +REWARD_MODEL=( + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False + +reward_model.reward_kwargs.max_resp_len=${max_response_length} + reward_model.reward_manager=dapo +) + +PERF_OPT=( + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True + actor_rollout_ref.model.use_fused_kernels=False + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 + actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True +) + +ACTOR=( + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=10 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.optim.clip_grad=1.0 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.megatron.param_offload=${offload} + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} + actor_rollout_ref.actor.megatron.grad_offload=${offload} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} +) + +ROLLOUT=( + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.mode=${rollout_mode} + actor_rollout_ref.rollout.dtype=${dtype} + actor_rollout_ref.rollout.gpu_memory_utilization=0.70 + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) + actor_rollout_ref.rollout.temperature=${temperature} + actor_rollout_ref.rollout.top_p=${top_p} + actor_rollout_ref.rollout.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.rollout.n=${n_resp_per_prompt} +) + +TRAINER=( + trainer.logger=['console','wandb'] + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + trainer.n_gpus_per_node=8 + trainer.nnodes="${NNODES}" + trainer.val_before_train=False + trainer.test_freq=10 + trainer.save_freq=-1 + trainer.total_epochs=10 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.resume_mode=auto + trainer.log_val_generations=10 +) + +FORWARD_ONLY_SETS=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} +) + +MODEL=( + actor_rollout_ref.model.path="${MODEL_PATH}" +) + +ALGORITHM=( + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} +) +################################################### start script ################################################### +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + -- python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REWARD_MODEL[@]}" \ + "${PERF_OPT[@]}" \ + "${TRAINER[@]}" \ + "${GPT_OSS_CONFIG[@]}" \ + "${FORWARD_ONLY_SETS[@]}" \ \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_30b_math.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_30b_math.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6956635f2e9705744eb3c4918b86468f89e494d --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_30b_math.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +fsdp_size=32 + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=300 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_30b_math_single_node.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_30b_math_single_node.sh new file mode 100644 index 0000000000000000000000000000000000000000..5af2822ea267a601f2d7bde7dd6dd40e67dab876 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_30b_math_single_node.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0719a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 4)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=0.1 + +loss_agg_mode="token-mean" + +train_prompt_bsz=64 +n_resp_per_prompt=16 +train_prompt_mini_bsz=16 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +fsdp_size=8 + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=300 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh new file mode 100644 index 0000000000000000000000000000000000000000..f5c85ca22b28d34dc5c2a4785dd566970e95deeb --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3_moe_30b_megatron_fp16.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +else + return_raw_chat="False" +fi + +dtype="float16" # ["bfloat16", "float16"] + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +NNODES=4 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:6379"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} + +project_name='DAPO-moe-fp16' +exp_name='qwen3moe_30b_a3b_fp16' + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/checkpoints/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +# Performance Related Parameter +use_dynamic_bsz=False +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +gen_tp=8 +train_tp=4 +train_pp=4 +train_ep=8 + + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + reward_model.reward_manager=dapo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.return_raw_chat=$return_raw_chat \ + data.truncation='left' \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.dtype=${dtype} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.actor.megatron.dtype=${dtype} \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.optim.clip_grad=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.actor.megatron.use_mbridge=True \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=5 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3next_80b_megatron.sh b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3next_80b_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..d74d98f3a685966f1d99d3932b07bdb59ae96083 --- /dev/null +++ b/ICL/DAPO/verl-recipe/dapo/test_dapo_qwen3next_80b_megatron.sh @@ -0,0 +1,232 @@ +#!/usr/bin/env bash +set -xeuo pipefail + + +################################################### document for qwen3next ################################################### + +####################### running environment: ####################### + +# option 1: use pre-built docker images verlai/verl:vll012.exp or verlai/verl:sgl056.exp + +# option 2: self build TE>=2.8, megatron with dev branch and megatron-bridge with main branch + +####################### how we support qwen3next? ####################### +# we support qwen3next with megatron-bridge, which is enabled by set `vanilla_mbridge=False` + +####################### limitations: ####################### +# 1. context parallel(CP) is not supported until this PR is merged: https://github.com/NVIDIA/Megatron-LM/pull/2614 +# 2. sequence packing(aka thd) is not supported, we must set `actor_rollout_ref.actor.megatron.use_remove_padding=False`, until this PR is merged: https://github.com/NVIDIA/Megatron-LM/pull/2644 + +## if sequence packing is disabled, we recommend to set `use_dynamic_bsz=False` and set micro batchsize to 1, +## otherwise the data will be padded to the max length of the batch, which is not efficient. But it's not mandatory + + + + +################################################### quick config ################################################### + +# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@dev # install megatron from dev branch +# pip install --no-deps git+https://github.com/NVIDIA-Nemo/Megatron-Bridge.git # install megatron-bridge from main branch + + +rollout_mode="async" +return_raw_chat="True" +export VLLM_USE_V1=1 +rollout_name="vllm" # sglang or vllm +dtype="bfloat16" + + +project_name='DAPO-test' +exp_name='qwen3next' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=32 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-Next-80B-A3B-Instruct"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=False +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 10 / 10)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1)) +offload=True +gen_tp=16 +train_tp=2 +EP=32 +ETP=1 +train_pp=1 + +################################################### start of config ################################################### + +FP8=( + # # train + # +actor_rollout_ref.actor.megatron.override_transformer_config.fp8="e4m3" # e4m3 or hybrid + # +actor_rollout_ref.actor.megatron.override_transformer_config.fp8_recipe="blockwise" + # +actor_rollout_ref.actor.optim.override_optimizer_config.fp8_recipe="blockwise" + # # rollout + # +actor_rollout_ref.rollout.quantization="fp8" +) + +DATA=( + # dddd + data.train_files="${TRAIN_FILE}" + data.val_files="${TEST_FILE}" + data.prompt_key=prompt + data.return_raw_chat=$return_raw_chat + data.truncation='left' + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.train_batch_size=${train_prompt_bsz} +) + +REWARD_MODEL=( + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False + +reward_model.reward_kwargs.max_resp_len=${max_response_length} + reward_model.reward_manager=dapo +) + +PERF_OPT=( + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True + actor_rollout_ref.actor.megatron.use_remove_padding=False + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 + actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=auto + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True +) + +ACTOR=( + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=10 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.optim.clip_grad=1.0 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.megatron.param_offload=${offload} + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} + actor_rollout_ref.actor.megatron.grad_offload=${offload} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} + actor_rollout_ref.actor.megatron.use_mbridge=True + actor_rollout_ref.actor.megatron.vanilla_mbridge=False + actor_rollout_ref.model.use_remove_padding=False +) + +ROLLOUT=( + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.mode=${rollout_mode} + actor_rollout_ref.rollout.dtype=${dtype} + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) + actor_rollout_ref.rollout.temperature=${temperature} + actor_rollout_ref.rollout.top_p=${top_p} + actor_rollout_ref.rollout.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.rollout.n=${n_resp_per_prompt} +) + +TRAINER=( + trainer.logger=['console','wandb'] + trainer.project_name="${project_name}" + trainer.experiment_name="${exp_name}" + trainer.n_gpus_per_node=8 + trainer.nnodes="${NNODES}" + trainer.val_before_train=False + trainer.test_freq=5 + trainer.save_freq=-1 + trainer.total_epochs=10 + trainer.default_local_dir="${CKPTS_DIR}" + trainer.resume_mode=auto + trainer.log_val_generations=10 +) + +FORWARD_ONLY_SETS=( + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} +) + +MODEL=( + actor_rollout_ref.model.path="${MODEL_PATH}" +) + +ALGORITHM=( + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} +) +################################################### start script ################################################### + +python3 -m verl.trainer.main_ppo \ + --config-path=config \ + --config-name='ppo_megatron_trainer.yaml' \ + "${DATA[@]}" \ + "${ALGORITHM[@]}" \ + "${MODEL[@]}" \ + "${ROLLOUT[@]}" \ + "${ACTOR[@]}" \ + "${REWARD_MODEL[@]}" \ + "${FP8[@]}" \ + "${PERF_OPT[@]}" \ + "${TRAINER[@]}" \ + "${FORWARD_ONLY_SETS[@]}" \ \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/deepeyes/README.md b/ICL/DAPO/verl-recipe/deepeyes/README.md new file mode 100644 index 0000000000000000000000000000000000000000..64f420fc88e2cfea9c0426fd20e77226c8faa6b4 --- /dev/null +++ b/ICL/DAPO/verl-recipe/deepeyes/README.md @@ -0,0 +1,49 @@ +# DeepEyes: Incentivizing "Thinking with Images" via Reinforcement Learning + +This directory contains the implementation for reproducing the DeepEyes paper within the verl framework, supporting multi-turn visual tool calls. This implementation is based on the original [DeepEyes paper](https://arxiv.org/abs/2505.14362) and its [official implementation](https://github.com/Visual-Agent/DeepEyes), integrated with the multi-modal and multi-turn capabilities of the verl framework. + +## Reproducing the Experiment + +> **Note on the 'Chart' Dataset:** +> +> The provided preprocessing script intentionally excludes `data_v0.8_visual_toolbox_v2.parquet`, which contains the 'Chart' data. This subset consists of very high-resolution images, often resembling large figures composed of multiple sub-plots, much like those found in academic papers. +> +> Consequently, even after using the zoom-in tool, the resulting cropped images remain large. This poses a significant risk of causing Out-of-Memory (OOM) errors, which can abruptly terminate the training process. +> +> **We strongly recommend against training on the 'Chart' dataset on a single node.** + +> **Note on the 'thinklite' Dataset:** +> Many images in the `thinklite` dataset have a very low resolution, with either a height or width below 28 pixels. This fails to meet the minimum input size required by the Qwen-2.5VL image processor and would cause errors during data loading. +> +> To mitigate this, we upscale these low-resolution images to satisfy the processor's requirements. However, please be aware that because the original resolution is low, subsequent `crop` operations by the zoom-in tool might frequently trigger exceptions, which could in turn affect the model's tool-use performance. + +First, launch an inference service to act as a judge for reward calculation. You can use the following script as a reference: + +```bash +python -m sglang.launch_server --model-path /path/to/Qwen2.5-72B-Instruct \ + --port 18901 \ + --tp-size 8 \ + --context-length 32768 \ + --trust-remote-code \ + --log-requests false +``` + +Next, you can start the training: + +```bash +bash recipe/deepeyes/run_deepeyes_grpo.sh +``` + +## Performance + +See [Comment](https://github.com/volcengine/verl/pull/2398#issuecomment-3157142856) for more details. + +Note: AgentLoop does not directly record num_tool_calls, but records num_turns. In our scenario, you can calculate the number of tool calls by num_tool_calls = num_turns / 2 - 1. + +## References and Acknowledgements + +- [DeepEyes Paper](https://arxiv.org/abs/2505.14362) +- [DeepEyes Official Implementation](https://github.com/Visual-Agent/DeepEyes) + +--- +If you need further details for reproduction or encounter any issues, feel free to open an issue or contact the maintainers. diff --git a/ICL/DAPO/verl-recipe/deepeyes/configs/image_zoom_in_tool_config.yaml b/ICL/DAPO/verl-recipe/deepeyes/configs/image_zoom_in_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2802f094aefbebb736f5624c0ea5704e9cc36b7 --- /dev/null +++ b/ICL/DAPO/verl-recipe/deepeyes/configs/image_zoom_in_tool_config.yaml @@ -0,0 +1,26 @@ +tools: + - class_name: "verl.tools.image_zoom_in_tool.ImageZoomInTool" + config: + num_workers: 256 + rate_limit: 256 + timeout: 60 + type: native + tool_schema: + type: "function" + function: + name: "image_zoom_in_tool" + description: "Zoom in on a specific region of an image by cropping it based on a bounding box (bbox) and an optional object label." + parameters: + type: "object" + properties: + bbox_2d: + type: "array" + items: + type: "number" + minItems: 4 + maxItems: 4 + description: "The bounding box of the region to zoom in, as [x1, y1, x2, y2], where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner." + label: + type: "string" + description: "The name or label of the object in the specified bounding box (optional)." + required: ["bbox_2d"] \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/deepeyes/run_deepeyes_grpo.sh b/ICL/DAPO/verl-recipe/deepeyes/run_deepeyes_grpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..3a25332817a7513536e73a68b2afd79a67599a76 --- /dev/null +++ b/ICL/DAPO/verl-recipe/deepeyes/run_deepeyes_grpo.sh @@ -0,0 +1,73 @@ +#!/bin/bash + +set -x + +export LLM_AS_A_JUDGE_BASE="your llm-as-a-judge server/v1" +export WANDB_API_KEY="your wandb key" + +PROJECT_NAME="your_project_name" +EXPERIMENT_NAME="your_experiment_name" + +BASEDIR=base_dir +SAVE_CHECKPOINT_DIR=${BASEDIR}/verl_checkpoints +DATASET_TRAIN=${BASEDIR}/dataset/train.parquet +DATASET_VAL=${BASEDIR}/dataset/val.parquet + +REF_MODEL_PATH=ref_model_path + +PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ + --config-path=${BASEDIR}/recipe/deepeyes/configs \ + --config-name='deepeyes_multiturn_grpo' \ + data.train_files=${DATASET_TRAIN} \ + data.val_files=[${DATASET_VAL}] \ + data.train_batch_size=128 \ + data.max_prompt_length=8192 \ + data.max_response_length=16384 \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + algorithm.adv_estimator=grpo \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.model.path=${REF_MODEL_PATH} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.use_fused_kernels=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','hf_model','optimizer','extra'] \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.max_num_batched_tokens=32768 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5 \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=5 \ + actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \ + actor_rollout_ref.rollout.multi_turn.tool_config_path=recipe/deepeyes/configs/image_zoom_in_tool_config.yaml \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb','tensorboard'] \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=8 \ + trainer.test_freq=80 \ + trainer.project_name=${PROJECT_NAME} \ + trainer.experiment_name=${EXPERIMENT_NAME} \ + trainer.default_local_dir=${SAVE_CHECKPOINT_DIR}/${PROJECT_NAME}/${EXPERIMENT_NAME} \ + +trainer.tensorboard_dir=${SAVE_CHECKPOINT_DIR}/logs/tensorboard \ + +trainer.rl_logging_board_dir=${SAVE_CHECKPOINT_DIR}/logs/rl_logging_board \ + trainer.total_epochs=1 2>&1 | tee ./logs/${EXPERIMENT_NAME}.log diff --git a/ICL/DAPO/verl-recipe/entropy/32b_clip_cov.sh b/ICL/DAPO/verl-recipe/entropy/32b_clip_cov.sh new file mode 100644 index 0000000000000000000000000000000000000000..addbb9128c55ce5123ef2c9c18fd3bda9ee988bc --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/32b_clip_cov.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-32B' +exp_name='clipcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=1 +clip_ratio_high=1 +clip_cov_ratio=0.0002 +clip_cov_lb=1.0 +clip_cov_ub=5.0 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="clip_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=20480 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.02 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.name=vllm \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.lr_scheduler_type=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \ + actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \ + actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/ICL/DAPO/verl-recipe/entropy/32b_kl_cov.sh b/ICL/DAPO/verl-recipe/entropy/32b_kl_cov.sh new file mode 100644 index 0000000000000000000000000000000000000000..ad640d64561b3697b467a9d42056ba083a0b3979 --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/32b_kl_cov.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-32B' +exp_name='klcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="kl_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=20480 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.0002 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.name=vllm \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.lr_scheduler_type=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/ICL/DAPO/verl-recipe/entropy/32b_kl_cov_mininbsz.sh b/ICL/DAPO/verl-recipe/entropy/32b_kl_cov_mininbsz.sh new file mode 100644 index 0000000000000000000000000000000000000000..10dd223bf1792122270e55a857de3d771afe6f86 --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/32b_kl_cov_mininbsz.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-32B' +exp_name='klcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="kl_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=16 +n_resp_per_prompt=8 +max_token=20480 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.0002 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.name=vllm \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.lr_scheduler_type=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/ICL/DAPO/verl-recipe/entropy/7b_clip_cov.sh b/ICL/DAPO/verl-recipe/entropy/7b_clip_cov.sh new file mode 100644 index 0000000000000000000000000000000000000000..8fd7d7f6cd52287df2a2ad61d45917ffa4399d5c --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/7b_clip_cov.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-7B' +exp_name='clipcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=1 +clip_ratio_high=1 +clip_cov_ratio=0.0002 +clip_cov_lb=1.0 +clip_cov_ub=5.0 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="clip_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=30720 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.2 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \ + actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.name=vllm \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.lr_scheduler_type=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/ICL/DAPO/verl-recipe/entropy/7b_kl_cov.sh b/ICL/DAPO/verl-recipe/entropy/7b_kl_cov.sh new file mode 100644 index 0000000000000000000000000000000000000000..1ac215f2b58f51434b321a4ee95953b5422e683e --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/7b_kl_cov.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# export VLLM_USE_V1=1 + +project_name='Qwen2.5-7B' +exp_name='klcov' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 2)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +loss_mode="kl_cov" +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +train_prompt_mini_bsz=32 +n_resp_per_prompt=8 +max_token=30720 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} +CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} +TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} +TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +ppo_kl_coef=1 +kl_cov_ratio=0.002 + +# Mathematically equivalent +use_dynamic_bsz=True +infer_micro_batch_size=null +train_micro_batch_size=null +offload=False + +HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.filter_overlong_prompts=False \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ + actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.name=vllm \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0 \ + actor_rollout_ref.actor.optim.lr_scheduler_type=constant \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=False \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=4 \ + trainer.save_freq=32 \ + trainer.total_epochs=1000 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=disable diff --git a/ICL/DAPO/verl-recipe/entropy/README.md b/ICL/DAPO/verl-recipe/entropy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5238cec84bbe9b28fac75e97a100341bfa2e1267 --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/README.md @@ -0,0 +1,110 @@ +
+ +# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning. + +[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue +)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861) + + + + +
+ + +# 🎉News + +- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29). +- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. + + + +# ✨Getting started + +After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run: + +``` +cd verl +conda activate your_env +bash recipe/dapo/7b_kl_cov.sh +``` + +While for training Qwen2.5-32B on multi nodes, you can run the following commands: + +``` +cd verl +conda activate your_env +bash recipe/dapo/32b_kl_cov.sh +``` + +# 📖Introduction + +
+ issue +
+ +This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. + +
+ issue +
+ +Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. + +# 📃Evaluation + +
+ issue +
+ + +Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. +| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** | +| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: | +| *Qwen2.5-7B* | | | | | | | | | +| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 | +| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 | +| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 | +| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** | +| *Qwen2.5-32B* | | | | | | | | | +| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 | +| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 | +| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 | +| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** | + +Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively. + + +# 🎈Citation +If you find this paper or repo helpful, please cite us. + +```bibtex +@article{cui2025entropy, + title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models}, + author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others}, + journal={arXiv preprint arXiv:2505.22617}, + year={2025} +} +``` +# 🌻Acknowledgement +We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! + +# 📬 Contact + +For questions, discussion, or collaboration opportunities, feel free to contact: +- Ganqu Cui: cuiganqu@pjlab.org.cn +- Yuchen Zhang: yuchen.zhang2003@gmail.com +- Jiacheng Chen: jackchan9345@gmail.com +- Ning Ding: ningding.cs@gmail.com + diff --git a/ICL/DAPO/verl-recipe/entropy/entropy_ray_trainer.py b/ICL/DAPO/verl-recipe/entropy/entropy_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa712794e4ae0ef994b51362abd4d4f28df578d --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/entropy_ray_trainer.py @@ -0,0 +1,357 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint + +import numpy as np +import torch +from tqdm import tqdm + +from verl import DataProto +from verl.trainer.ppo.metric_utils import compute_data_metrics, compute_throughout_metrics, compute_timing_metrics +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.trainer.ppo.reward import compute_reward +from verl.utils.metric import reduce_metrics +from verl.utils.profiler import simple_timer + + +class RayEntropyTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + def compute_kl_related_metrics(self, batch: DataProto, timing_raw: dict): + batch.batch["response_mask"] = compute_response_mask(batch) + + # recompute old_log_probs + with simple_timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with simple_timer("ref", timing_raw): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + return batch + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + timing_raw = defaultdict(float) + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + + new_batch: DataProto = DataProto.from_single_dict(batch_dict) + num_gen_batches += 1 + # pop those keys for generation + if "multi_modal_inputs" in new_batch.non_tensor_batch.keys(): + gen_batch = new_batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], + ) + else: + gen_batch = new_batch.pop( + batch_keys=["input_ids", "attention_mask", "position_ids"], + non_tensor_batch_keys=["raw_prompt_ids"], + ) + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with simple_timer("step", timing_raw): + # generate a batch + with simple_timer("gen", timing_raw): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with simple_timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + new_batch = new_batch.union(gen_baseline_output) + # compute reward model score on new_batch + rm_scores = None + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + new_batch.pop(batch_keys=list(keys_to_pop)) + + new_batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + + new_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + new_batch = new_batch.union(gen_batch_output) + + if self.config.algorithm.use_kl_in_reward: + # We need these metrics for apply_kl_penalty if using kl in reward + new_batch = self.compute_kl_related_metrics(new_batch, timing_raw) + # otherwise, we will compute those after dynamic sampling + + with simple_timer("reward", timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm and "rm_scores" not in new_batch.batch.keys(): + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(new_batch) + new_batch = new_batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn) + + new_batch.batch["token_level_scores"] = reward_tensor + + print(f"{list(reward_extra_infos_dict.keys())=}") + if reward_extra_infos_dict: + new_batch.non_tensor_batch.update( + {k: np.array(v) for k, v in reward_extra_infos_dict.items()} + ) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + new_batch, kl_metrics = apply_kl_penalty( + new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update( + kl_metrics + ) # TODO: This will be cleared if we use multiple genenration batches + else: + new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] + + if not self.config.algorithm.filter_groups.enable: + batch = new_batch + else: # NOTE: When prompts after filtering is less than train batch size, + # we skip to the next generation batch + metric_name = self.config.algorithm.filter_groups.metric + if metric_name == "seq_final_reward": + # Turn to numpy for easier filtering + new_batch.non_tensor_batch["seq_final_reward"] = ( + new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() + ) + elif metric_name == "seq_reward": + new_batch.non_tensor_batch["seq_reward"] = ( + new_batch.batch["token_level_scores"].sum(dim=-1).numpy() + ) + + # Collect the sequence reward for each trajectory + prompt_uid2metric_vals = defaultdict(list) + for uid, metric_val in zip( + new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True + ): + prompt_uid2metric_vals[uid].append(metric_val) + + prompt_uid2metric_std = {} + for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): + prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) + + kept_prompt_uids = [ + uid + for uid, std in prompt_uid2metric_std.items() + if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 + ] + num_prompt_in_batch += len(kept_prompt_uids) + + kept_traj_idxs = [] + for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): + if traj_from_prompt_uid in kept_prompt_uids: + kept_traj_idxs.append(idx) + + new_batch = new_batch[kept_traj_idxs] + batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) + + prompt_bsz = self.config.data.train_batch_size + if num_prompt_in_batch < prompt_bsz: + print(f"{num_prompt_in_batch=} < {prompt_bsz=}") + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: + print(f"{num_gen_batches=}. Keep generating...") + continue + else: + raise ValueError( + f"{num_gen_batches=} >= {max_num_gen_batches=}." + + " Generated too many. Please check if your data are too difficult." + + " You could also try set max_num_gen_batches=0 to enable endless trials." + ) + else: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + print( + f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. " + f"Collecting finished." + ) + batch = batch[:traj_bsz] + + # === Updating === + # balance the number of valid tokens on each dp rank. + # Note that this breaks the order of data inside the batch. + # Please take care when you implement group based adv computation such as GRPO and rloo + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + if not self.config.algorithm.use_kl_in_reward: + batch = self.compute_kl_related_metrics(batch, timing_raw) + + # compute values + if self.use_critic: + with simple_timer("values", timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with simple_timer("adv", timing_raw): + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + ) + + # update critic + if self.use_critic: + with simple_timer("update_critic", timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with simple_timer("update_actor", timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with simple_timer("testing", timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with simple_timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + timing_raw = defaultdict(float) # clear timing + + metrics["train/num_gen_batches"] = num_gen_batches + batch = None + num_prompt_in_batch = 0 + num_gen_batches = 0 + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 diff --git a/ICL/DAPO/verl-recipe/entropy/main_entropy.py b/ICL/DAPO/verl-recipe/entropy/main_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..f47ef70f0722a9133f82185962473e0a8f09ebeb --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/main_entropy.py @@ -0,0 +1,259 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import hydra +import ray +from omegaconf import OmegaConf + +from .entropy_ray_trainer import RayEntropyTrainer +from .reward import load_reward_manager + + +@hydra.main(config_path="config", config_name="entropy_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + if not ray.is_initialized(): + # this is for local ray cluster + default_runtime_env = { + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + } + } + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +def merge_dict(a: dict, b: dict) -> dict: + """Return a new dict that has `a` updated with `b` (b wins on conflicts). + + Example:: + + >>> d1 = {"x": 1, "y": 2} + >>> d2 = {"y": 20, "z": 3} + >>> new_dict = merge_dict(d1, d2) + >>> print(new_dict) # {'x': 1, 'y': 20, 'z': 3} + >>> print(d1) # {"x": 1, "y": 2} (unchanged) + >>> print(d2) # {"y": 20, "z": 3} (unchanged) + """ + return a | b + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + print(f"{config.actor_rollout_ref.model.path}") + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + Role.Critic: ray.remote(CriticWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + Role.Critic: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # use reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + reward_kwargs = { + "max_resp_len": config.data.max_response_length, + "overlong_buffer_cfg": config.reward_model.overlong_buffer, + } + cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {}) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs) + ) + val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.utils.dataset.rl_dataset import collate_fn + + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1) + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + trainer = RayEntropyTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + trainer.init_workers() + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, max_samples: int = -1): + """Create a dataset. + + Arguments: + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + from verl.utils.import_utils import load_extern_object + + dataset_cls = load_extern_object(data_config.custom_cls.path, data_config.custom_cls.name) + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' " + f"must inherit from torch.utils.data.Dataset" + ) + else: + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + max_samples=max_samples, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + # use sampler for better ckpt resume + if data_config.shuffle: + train_dataloader_generator = torch.Generator() + seed = data_config.get("seed") + if seed is not None: + train_dataloader_generator.manual_seed(seed) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/entropy/reward.py b/ICL/DAPO/verl-recipe/entropy/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..36b8b65a4d2d7aa2e5977a1214e3ef5c4f9e4b4a --- /dev/null +++ b/ICL/DAPO/verl-recipe/entropy/reward.py @@ -0,0 +1,86 @@ +# Copyright 2025 Individual Contributor: Thibaut Barroyer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +from functools import partial + +import ray + +from verl import DataProto +from verl.trainer.ppo.reward import compute_reward, get_custom_reward_fn + +from .reward_score import _default_compute_score + + +def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): + """ + Load and initialize a reward manager based on the configuration. + + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. + + Returns: + An instance of the specified reward manager class. + """ + from verl.workers.reward_manager import get_reward_manager_cls + + # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: + # naive: NaiveRewardManager + # prime: PrimeRewardManager + # batch: BatchRewardManager + # dapo: DAPORewardManager + # Note(haibin.lin): For custom reward managers, please make sure they are imported and + # registered via `verl.workers.reward_manager.register` + # By default reward_manager is set to naive (NaiveRewardManager) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + reward_manager_cls = get_reward_manager_cls(reward_manager_name) + + # Try to get a custom reward function based on the configuration + compute_score = get_custom_reward_fn(config) + final_compute_score = compute_score + + if compute_score is None: + sandbox_config = config.reward_model.get("sandbox_fusion") + sandbox_url = sandbox_config.get("url") if sandbox_config else None + if sandbox_url: + sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox + _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + final_compute_score = partial( + _default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore + ) + else: + final_compute_score = _default_compute_score + + # Instantiate and return the reward manager with the specified parameters + return reward_manager_cls( + tokenizer=tokenizer, + num_examine=num_examine, + compute_score=final_compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + +@ray.remote(num_cpus=1) +def compute_reward_async(data: DataProto, config, tokenizer): + """ + Load the reward manager and compute the reward for a batch of data. + This is meant to be run in a separate Ray worker. + """ + reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) + return compute_reward(data, reward_fn) diff --git a/ICL/DAPO/verl-recipe/fapo/README.md b/ICL/DAPO/verl-recipe/fapo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4401bbc4f7a573e608ca54a920aecc6f0e9a1607 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/README.md @@ -0,0 +1,89 @@ +

+

FAPO: Flawed-Aware Policy Optimization for Efficient and Reliable Reasoning

+ +

+ Project Page + Infra Design + Resources + Paper + Code +

+ +- **Algorithm Insights:** Visit our [Project Page](https://fapo-rl.github.io/) for an overview; comprehensive details are available in the [Paper](). +- **Infrastructure Design:** Refer to the [Reward Loop](https://verl.readthedocs.io/en/latest/advance/reward_loop.html) document for architectural insights. +- **Open-Source Software:** Explore the [Huggingface Collections](https://huggingface.co/collections/dyyyyyyyy/fapo) for datasets and models. + + +![fapo-result](https://fapo-rl.github.io/_astro/intro_main.DKe72RHX_1Us2HB.webp) + +## Step 1: Train FAPO-GenRM-4B (Generative Reward Model) + +We provide our training and evaluation datasets [here](https://huggingface.co/datasets/dyyyyyyyy/FAPO-Critic). +Directly download them to `${RAY_DATA_HOME}/data/`. + +Then, submit the training job to the ray cluster: + +```bash +cd verl # Repo root +export RAY_ADDRESS="..." # The Ray cluster address to connect to +export RAY_DATA_HOME="..." # The directory to store the data +export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster +# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml +export RUNTIME_ENV="./recipe/fapo/runtime_env.yaml" # This sets environment variables for the Ray cluster +bash recipe/fapo/run_fapo_genrm_train.sh +``` + +You can skip this step if you want to use the pre-trained FAPO-GenRM-4B model available [here](https://huggingface.co/dyyyyyyyy/FAPO-GenRM-4B). + +## Step 2: Integrate the GRM into the Final Training + +Our training data is identical to that of DAPO-Math-17K, except that we replace the instruction with "Put the final answer in \boxed{}", which is a common practice for current instruct models. + +You can construct the training and evaluation datasets by: +```bash +python recipe/fapo/prepare_fapo_data.py --local_dir ${RAY_DATA_HOME}/data/ +``` + +Or you can directly use the data available [here](https://huggingface.co/datasets/dyyyyyyyy/FAPO-Reasoning-Dataset). + +To integrate the GRM into the final training, we provide two options: + +1. **Launch GRM as an external service:** Launch multiple model servers and a router in advance to handle and dispatch incoming requests. Refer to `verl/recipe/genrm_remote` for more details. The scripts is `verl/recipe/fapo/run_fapo_{7b/32b}_remote.sh`. +2. **Launch GRM in verl single controller:** Start the GRM model directly inside the verl single controller with an integrated router. (Note: this feature is still unstable for large-scale training scenarios.) + +```bash +cd verl # Repo root +export RAY_ADDRESS="..." # The Ray cluster address to connect to +export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster +# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml +export RUNTIME_ENV="./recipe/fapo/runtime_env.yaml" # This sets environment variables for the Ray cluster + +# run Baseline Models +bash recipe/fapo/run_baseline_7b.sh # 7b baseline model +bash recipe/fapo/run_baseline_32b.sh # 32b baseline model + +# run FAPO Models (with external GRM service) +# Note that you should launch the external GRM service first, +# and specify the router address in the compute_score function +bash recipe/fapo/run_fapo_7b_remote.sh # 7b fapo model +bash recipe/fapo/run_fapo_32b_remote.sh # 32b fapo model + +# run FAPO Models (single controller mode) +bash recipe/fapo/run_fapo_7b.sh # 7b fapo model +bash recipe/fapo/run_fapo_32b.sh # 32b fapo model +``` + +## Infrastructure Design + +We implement RewardLoop to enable efficient and flexible reward computation. +The core implementation can be found in `verl/experimental/reward/`. +Refer to [this official document](https://verl.readthedocs.io/en/latest/advance/reward_loop.html) for more implementation details. + +```bibtex +@article{ding2025fapo, + title={FAPO: Flawed-Aware Policy Optimization for Efficient and Reliable Reasoning}, + author={Ding, Yuyang and Zhang, Chi and Li, Juntao and Lin, Haibin and Liu, Xin and Zhang, Min}, + journal={arXiv preprint arXiv:2510.22543}, + year={2025} +} +``` \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/fapo/prepare_fapo_data.py b/ICL/DAPO/verl-recipe/fapo/prepare_fapo_data.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae9b5152785dadf119225e6e7a2b84898a34256 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/prepare_fapo_data.py @@ -0,0 +1,153 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the dataset to parquet format +""" + +import argparse +import os +from functools import partial + +from datasets import concatenate_datasets, load_dataset + +from verl.utils.hdfs_io import copy, makedirs + + +def example_map_fn(example, idx, process_fn, data_source, ability, split): + question, prompt, ground_truth = process_fn(example) + data = { + "data_source": data_source, + "prompt": [{"role": "user", "content": prompt}], + "ability": ability, + "reward_model": {"style": "rule", "ground_truth": ground_truth}, + "extra_info": {"split": split, "index": idx, "question": question}, + } + return data + + +def build_aime2024_dataset(): + def process_aime2024(example): + question, ground_truth = example["Problem"], str(example["Answer"]) + prompt = question.strip() + "\n\n" + "Please reason step by step, and put your final answer within \\boxed{}." + return question, prompt, ground_truth + + data_source = "Maxwell-Jia/AIME_2024" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = load_dataset(data_source, split="train") + map_fn = partial(example_map_fn, process_fn=process_aime2024, data_source="aime24", ability="Math", split="test") + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) + return dataset + + +def build_aime2025_dataset(): + def process_aime2025(example): + question, ground_truth = example["problem"], str(example["solution"]) + prompt = question.strip() + "\n\n" + "Please reason step by step, and put your final answer within \\boxed{}." + return question, prompt, ground_truth + + data_source = "yentinglin/aime_2025" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = load_dataset(data_source, split="train") + map_fn = partial(example_map_fn, process_fn=process_aime2025, data_source="aime25", ability="Math", split="test") + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) + return dataset + + +def build_gpqa_diamond_dataset(): + import random + + GPQA_QUERY_TEMPLATE = ( + "{Question}\n" + "A. {A}\nB. {B}\nC. {C}\nD. {D}\n\n" + "Please reason step by step, and put your final answer (only the choice letter) within \\boxed{{}}." + ) + + def process_gpqa_diamond(example): + choices = [ + example["Incorrect Answer 1"].strip(), + example["Incorrect Answer 2"].strip(), + example["Incorrect Answer 3"].strip(), + ] + random.shuffle(choices) + gold_index = random.randint(0, 3) + choices.insert(gold_index, example["Correct Answer"].strip()) + question = example["Question"] + query_prompt = GPQA_QUERY_TEMPLATE.format( + A=choices[0], + B=choices[1], + C=choices[2], + D=choices[3], + Question=question, + ) + gold_choice = "ABCD"[gold_index] + return question, query_prompt, gold_choice + + data_source = "Idavidrein/gpqa" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + + dataset = load_dataset(data_source, "gpqa_diamond", split="train") + map_fn = partial( + example_map_fn, process_fn=process_gpqa_diamond, data_source="gpqa-diamond", ability="General", split="test" + ) + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) + return dataset + + +def build_dapo_train_dataset(): + def process_dapo(example): + question, ground_truth = example["prompt"], example["solution"] + prompt = question.strip() + "\n\n" + "Please reason step by step, and put your final answer within \\boxed{}." + return question, prompt, ground_truth + + data_source = "open-r1/DAPO-Math-17k-Processed" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = load_dataset(data_source, "all", split="train") + map_fn = partial(example_map_fn, process_fn=process_dapo, data_source="math-dapo", ability="Math", split="train") + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) + return dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/genrm") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--tasks", default="all") + + args = parser.parse_args() + + train_dataset = build_dapo_train_dataset() + train_dataset = concatenate_datasets([train_dataset for _ in range(20)]) + + test_datasets = [] + # AIME 2024 + aime24_dataset = build_aime2024_dataset() + test_datasets.extend([aime24_dataset for _ in range(32)]) + # AIME 2025 + aime25_dataset = build_aime2025_dataset() + test_datasets.extend([aime25_dataset for _ in range(32)]) + # GPQA Diamond + gpqa_dataset = build_gpqa_diamond_dataset() + test_datasets.extend([gpqa_dataset for _ in range(4)]) + test_dataset = concatenate_datasets(test_datasets) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "fapo-train-boxed.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "fapo-test-full-boxed.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/ICL/DAPO/verl-recipe/fapo/reward_fn_genrm.py b/ICL/DAPO/verl-recipe/fapo/reward_fn_genrm.py new file mode 100644 index 0000000000000000000000000000000000000000..cd2c095ae650ba871c6d82c5633a027fc6962776 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/reward_fn_genrm.py @@ -0,0 +1,68 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + + +from verl.utils.reward_score.math_dapo import last_boxed_only_string, remove_boxed + + +def parse_ans( + solution_str: str, + total_steps: int, +) -> tuple[bool, str]: + try: + boxed_answer = last_boxed_only_string(solution_str[-300:]) + extracted_answer = int(remove_boxed(boxed_answer)) + if extracted_answer == -1 or 0 <= extracted_answer < total_steps: + return extracted_answer + else: + return None + except Exception: + return None + + +def compute_score_fapo_genrm( + solution_str: str, + ground_truth: int, + extra_info: dict, + **kwargs, +) -> float: + # Verify the solution + total_steps = extra_info["total_steps"] + extracted_answer = parse_ans(solution_str, total_steps) + gt = "correct" if ground_truth == -1 else "incorrect" + pred = "correct" if extracted_answer == -1 else "incorrect" + if extracted_answer is None: + pred = "[INVALID]" + acc = gt == pred + # reward = 1.0 if acc else -1.0 + if extracted_answer is None: + reward = -1.0 + elif ground_truth == -1: + reward = 1.0 if extracted_answer == -1 else -1.0 + else: + # ground truth != -1 + if extracted_answer == -1: + reward = -1.0 + else: + # gt != -1, pred != -1 + reward = 1.0 + reward -= abs(extracted_answer - ground_truth) / total_steps + + return { + "score": reward, + "acc": acc, + "pred": extracted_answer, + "gt": ground_truth, + } diff --git a/ICL/DAPO/verl-recipe/fapo/reward_fn_reasoning.py b/ICL/DAPO/verl-recipe/fapo/reward_fn_reasoning.py new file mode 100644 index 0000000000000000000000000000000000000000..ad20a00e26e83994afcce4b9e9f0d812b01455bf --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/reward_fn_reasoning.py @@ -0,0 +1,149 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os + +import aiohttp +from transformers import PreTrainedTokenizer + +from verl.utils.ray_utils import get_event_loop +from verl.utils.reward_score.math_dapo import last_boxed_only_string, normalize_final_answer, remove_boxed + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def verify( + solution_str: str, + gt: str, +) -> tuple[bool, str]: + solution_str = solution_str[-300:] + boxed_answer = last_boxed_only_string(solution_str) + if boxed_answer is not None: + extracted_answer = remove_boxed(boxed_answer) + else: + extracted_answer = "[INVALID]" + + pred = normalize_final_answer(extracted_answer) + gt = normalize_final_answer(gt) + return (pred == gt), pred + + +async def compute_score_baseline( + solution_str: str, + ground_truth: str, + **kwargs, +): + loop = get_event_loop() + """Compute the reward score for Baseline.""" + correct, pred = await loop.run_in_executor(None, lambda: verify(solution_str, ground_truth)) + reward_score = 1.0 if correct else -1.0 + return {"score": reward_score, "acc": correct, "pred": pred} + + +# FAPO Hyper-parameters +FAPO_GENRM_TEMPLATE = ( + "The following is a math problem with its ground truth answer, along with an AI solution (split into steps):\n\n" + "[Math Problem]\n\n" + "{problem}\n\n" + "[Ground Truth]\n\n" + "{ground_truth}\n\n" + "[AI Solution]\n\n" + "{solution}\n\n" + "Your task is to review and critique the solution step by step. " + "Once you identify an error in a step, return the index of the step where the earliest error occurs. " + "Otherwise, return the index of -1 (which typically denotes 'not found').\n\n" + "Please reason step by step, put your final answer (i.e., the index) in \\boxed{{}}." +) +GRM_SAMPLING_PARAMS = { + "max_new_tokens": 16384, +} +FLAWED_REWARD_PENALTY = 1.0 + + +async def generate_aiohttp(router_address: str, prompt_ids: list[int], sampling_params: dict): + payload = { + "input_ids": prompt_ids, + "sampling_params": sampling_params, + } + url = f"http://{router_address}/generate" + try: + session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) + async with session.post(url, json=payload) as resp: + output = await resp.text() + try: + output = json.loads(output) + return output + except Exception: + logger.error(f"Failed to parse JSON response: {output}") + return {} + finally: + await session.close() + + +async def compute_score_fapo( + data_source: str, + solution_str: str, + ground_truth: str, + extra_info: dict, + reward_router_address: str, + reward_model_tokenizer: PreTrainedTokenizer, +): + """Compute the reward score for FAPO.""" + loop = get_event_loop() + + question, split = extra_info["question"], extra_info["split"] + correct, pred = await loop.run_in_executor(None, lambda: verify(solution_str, ground_truth)) + reward_score = 1.0 if correct else -1.0 + is_flawed_positive = False + + # for test set or incorrect solution, directly return the reward score + if split == "test" or not correct: + return {"score": reward_score, "acc": correct, "pred": pred, "is_flawed_positive": is_flawed_positive} + + grm_prompt = FAPO_GENRM_TEMPLATE.format( + problem=question, + ground_truth=ground_truth, + solution=solution_str, + ) + grm_prompt_ids = await loop.run_in_executor( + None, + lambda: reward_model_tokenizer.apply_chat_template( + [{"role": "user", "content": grm_prompt}], + tokenize=True, + add_generation_prompt=True, + ), + ) + grm_outputs = await generate_aiohttp( + router_address=reward_router_address, + prompt_ids=grm_prompt_ids, + sampling_params=GRM_SAMPLING_PARAMS, + ) + grm_response_ids = grm_outputs.get("output_ids", None) + if grm_response_ids is not None: + grm_response = await loop.run_in_executor( + None, lambda: reward_model_tokenizer.decode(grm_response_ids, skip_special_tokens=True) + ) + try: + err_location = remove_boxed(last_boxed_only_string(grm_response)) + is_flawed_positive = int(err_location) != -1 + except Exception: + is_flawed_positive = False + + if is_flawed_positive: + reward_score -= FLAWED_REWARD_PENALTY + + return {"score": reward_score, "acc": correct, "pred": pred, "is_flawed_positive": is_flawed_positive} diff --git a/ICL/DAPO/verl-recipe/fapo/reward_fn_reasoning_remote.py b/ICL/DAPO/verl-recipe/fapo/reward_fn_reasoning_remote.py new file mode 100644 index 0000000000000000000000000000000000000000..153cd1bee94fff326477a6a836ca377dc1bc04d8 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/reward_fn_reasoning_remote.py @@ -0,0 +1,134 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import aiohttp + +from verl.utils.reward_score.math_dapo import last_boxed_only_string, normalize_final_answer, remove_boxed + + +def verify( + solution_str: str, + gt: str, +) -> tuple[bool, str]: + boxed_answer = last_boxed_only_string(solution_str) + if boxed_answer is not None: + extracted_answer = remove_boxed(boxed_answer) + else: + extracted_answer = "[INVALID]" + + pred = normalize_final_answer(extracted_answer) + gt = normalize_final_answer(gt) + return (pred == gt), pred + + +def compute_score_baseline( + solution_str: str, + ground_truth: str, + **kwargs, +) -> float: + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + correct, pred = verify(solution_str, ground_truth) + + reward = 1.0 if correct else -1.0 + acc = correct + + return { + "score": reward, + "acc": acc, + "pred": pred, + } + + +ADDRESS = "xx.xx.xx.xx:xxxx" +MODEL_NAME = "FAPO-4B-GenRM" +FAPO_GENRM_TEMPLATE = ( + "The following is a math problem with its ground truth answer, along with an AI solution (split into steps):\n\n" + "[Math Problem]\n\n" + "{problem}\n\n" + "[Ground Truth]\n\n" + "{ground_truth}\n\n" + "[AI Solution]\n\n" + "{solution}\n\n" + "Your task is to review and critique the solution step by step. " + "Once you identify an error in a step, return the index of the step where the earliest error occurs. " + "Otherwise, return the index of -1 (which typically denotes 'not found').\n\n" + "Please reason step by step, put your final answer (i.e., the index) in \\boxed{{}}." +) + + +async def chat_completions_aiohttp(address, **chat_complete_request): + try: + request_url = f"http://{address}/v1/chat/completions" + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + async with session.post( + url=request_url, + json=chat_complete_request, + ) as resp: + output = await resp.text() + try: + output = json.loads(output) + return output["choices"][0]["message"]["content"] + except Exception as e: + print(f"Error: {e}. Output: {output}") + return "" + finally: + await session.close() + + +def judge_fp_process(response, return_err_step=False): + try: + boxed_result = last_boxed_only_string(response) + result = remove_boxed(boxed_result) + reward_score = int(eval(result)) != -1 + if return_err_step: + return reward_score, int(result) + return reward_score + except Exception: + if return_err_step: + return None, None + return None + + +async def compute_score_fapo(data_source, solution_str, ground_truth, extra_info, keep_genrm_critics=False, **kwargs): + question, split = extra_info["question"], extra_info["split"] + result = compute_score_baseline(solution_str, ground_truth) + result["flawed_positive"] = False + + if split == "test" or result["acc"] == 0: + if keep_genrm_critics: + result["genrm_critics"] = "" + return result + else: + prompt = FAPO_GENRM_TEMPLATE.format(problem=question, ground_truth=ground_truth, solution=solution_str) + messages = [{"role": "user", "content": prompt}] + response = await chat_completions_aiohttp( + ADDRESS, + messages=messages, + model=MODEL_NAME, + max_tokens=16384, + ) + if response is not None and judge_fp_process(response): # flawed positive + result["score"] = 0.0 + result["flawed_positive"] = True + + if keep_genrm_critics and response is not None: + result["genrm_critics"] = response + + return result diff --git a/ICL/DAPO/verl-recipe/fapo/run_baseline_32b.sh b/ICL/DAPO/verl-recipe/fapo/run_baseline_32b.sh new file mode 100644 index 0000000000000000000000000000000000000000..f788066b5c5f6949bbf1a1fcd44b7f9764298780 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/run_baseline_32b.sh @@ -0,0 +1,135 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FAPO-Reproduce' +exp_name='Baseline-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 +fsdp_size=32 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \ + custom_reward_function.name=compute_score_baseline \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=600 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/fapo/run_baseline_7b.sh b/ICL/DAPO/verl-recipe/fapo/run_baseline_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..77605b1bbacad09829617947c5ce4232383ca614 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/run_baseline_7b.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FAPO-Reproduce' +exp_name='Baseline-7B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=1 +fsdp_size=8 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \ + custom_reward_function.name=compute_score_baseline \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/fapo/run_fapo_32b.sh b/ICL/DAPO/verl-recipe/fapo/run_fapo_32b.sh new file mode 100644 index 0000000000000000000000000000000000000000..f458070c4a7e0e7dbde685394e0a1e6a4a351e50 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/run_fapo_32b.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FAPO-Reproduce' +exp_name='FAPO-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +RM_NODES=${RM_NODES:-2} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +GRM_PATH=${GRM_PATH:-"${RAY_DATA_HOME}/models/FAPO-GenRM-4B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 +fsdp_size=32 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.enable=True \ + reward_model.enable_resource_pool=True \ + reward_model.n_gpus_per_node=8 \ + reward_model.nnodes="${RM_NODES}" \ + reward_model.model.path=${GRM_PATH} \ + reward_model.rollout.name=sglang \ + reward_model.rollout.gpu_memory_utilization=0.95 \ + reward_model.rollout.tensor_model_parallel_size=1 \ + reward_model.rollout.free_cache_engine=False \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \ + custom_reward_function.name=compute_score_fapo \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=600 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/fapo/run_fapo_32b_remote.sh b/ICL/DAPO/verl-recipe/fapo/run_fapo_32b_remote.sh new file mode 100644 index 0000000000000000000000000000000000000000..8833f109138a47bbfa10f2d762913b175b6e9edb --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/run_fapo_32b_remote.sh @@ -0,0 +1,135 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FAPO-Reproduce' +exp_name='FAPO-32B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-4} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=8 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=4 +fsdp_size=32 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + custom_reward_function.path=recipe/fapo/reward_fn_reasoning_remote.py \ + custom_reward_function.name=compute_score_fapo \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=600 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/fapo/run_fapo_7b.sh b/ICL/DAPO/verl-recipe/fapo/run_fapo_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..96884d94e9be4711f89a0e1734e1606d1c921d83 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/run_fapo_7b.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FAPO-Reproduce' +exp_name='FAPO-7B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +RM_NODES=${RM_NODES:-2} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +GRM_PATH=${GRM_PATH:-"${RAY_DATA_HOME}/models/FAPO-GenRM-4B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=1 +fsdp_size=8 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.enable=True \ + reward_model.enable_resource_pool=True \ + reward_model.n_gpus_per_node=8 \ + reward_model.nnodes="${RM_NODES}" \ + reward_model.model.path=${GRM_PATH} \ + reward_model.rollout.name=sglang \ + reward_model.rollout.gpu_memory_utilization=0.95 \ + reward_model.rollout.tensor_model_parallel_size=1 \ + reward_model.rollout.free_cache_engine=False \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + custom_reward_function.path=recipe/fapo/reward_fn_reasoning.py \ + custom_reward_function.name=compute_score_fapo \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/fapo/run_fapo_7b_remote.sh b/ICL/DAPO/verl-recipe/fapo/run_fapo_7b_remote.sh new file mode 100644 index 0000000000000000000000000000000000000000..663e10c385bcb893bca5251750c4a3e819009eb9 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fapo/run_fapo_7b_remote.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='FAPO-Reproduce' +exp_name='FAPO-7B' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=8 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-2} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/fapo-train-boxed.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/fapo-test-full-boxed.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=1 +fsdp_size=8 + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --address "${RAY_ADDRESS}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=True \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + custom_reward_function.path=recipe/fapo/reward_fn_reasoning_remote.py \ + custom_reward_function.name=compute_score_fapo \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=-1 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 diff --git a/ICL/DAPO/verl-recipe/fault_recover/README.md b/ICL/DAPO/verl-recipe/fault_recover/README.md new file mode 100644 index 0000000000000000000000000000000000000000..73646c77c2e101ab476cdf363db4682cb3a8ca66 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/README.md @@ -0,0 +1,69 @@ +# Recipe: Tokens saving and Auto Fault Recover for Rollout + +## Design + +[RFC](https://github.com/volcengine/verl/discussions/4355) + +## Solution + +![Req Resume Performance](https://github.com/user-attachments/assets/58127f8f-c3f8-43f2-9e54-198cfd22d705) + +## Support + +algorithm + +- [x] grpo + +rollout + +- [x] vllm +- [ ] sglang + +train + +- [x] megatron +- [ ] fsdp + +## Version + +```bash +# dev version +pip install verl@git+https://github.com/volcengine/verl.git@b97ebfd5062223337ae065c2250f8ab5c0e08e5e +``` + +## Quickstart + +```bash +# refer to this example: recipe/fault_recover/run_qwen2_5_0.5b_megatron.sh +python3 -m recipe.fault_recover.main_ppo --config-path=config \ + --config-name='fault_recover_ppo_megatron_trainer.yaml' \ + fault_manager.enable=True \ + actor_rollout_ref.rollout.agent.default_agent_loop=fault_recover_single_turn_agent \ + # refer to other detail config in the fault_manager part of + # recipe/fault_recover/config/fault_recover_ppo_megatron_trainer.yaml +``` + +## Configuration + +```yaml +fault_manager: + enable: False + # max retry times for other training phases except rollout (restart ray) + max_reschedule_times: 1 + # max retry times for rollout phase (rebuild worker group) + max_rebuild_times: 1 + # timeout of waiting cluster to be ready + timeout_rebuild: 300 + # check chips usage interval during rollout, set -1 to disable timeout check + timeout_task_check_interval: 10 + # timeout of chips usage being free, set -1 to disable chip check and + # 'timeout_task_check_interval' will be the whole time limit of rollout + # which means you should increase it + timeout_chip_free: 30 + # file path for token saving + tokens_save_file: ./tokens_ckpt/tokens.pt + # interval of saving tokens to disk, remember to clear if training config is changed + tokens_save_interval: 10 +``` + +## FAQ diff --git a/ICL/DAPO/verl-recipe/fault_recover/fault_manager.py b/ICL/DAPO/verl-recipe/fault_recover/fault_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..02efcc21dc585017162672070651146d0529792b --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/fault_manager.py @@ -0,0 +1,682 @@ +import datetime +import os +import re +import shutil +import signal +import subprocess +import threading +import time +from collections import defaultdict +from functools import wraps + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf +from ray.exceptions import RayActorError, RayTaskError +from ray.util.queue import Queue +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +from verl.trainer.ppo.utils import Role + +QUEUE_NAME = "fault_manager_queue" + + +@ray.remote +class TokensDict: + def __init__(self, auto_save_path=None, save_interval=5, batch_size=0): + self.iteration = 0 + self.index_prompt_tokens = {} + self._lock = threading.Lock() + self.auto_save_path = auto_save_path + self.save_interval = save_interval + self.saving_thread = None + self.saved_step_ckpt = {} + self.batch_size = batch_size + self.is_rollout_finished_step = False + + def start_save(self): + if self.auto_save_path: + if self.saving_thread is None: + self.saving_thread = threading.Thread(target=self._auto_save, daemon=True) + self.saving_thread.start() + + def update_data(self, global_id, new_token_info): + with self._lock: + if global_id not in self.index_prompt_tokens: + self.index_prompt_tokens[global_id] = {} + for k, v in new_token_info.items(): + if k not in self.index_prompt_tokens[global_id]: + self.index_prompt_tokens[global_id][k] = type(v)() + if k == "new_token_ids": + self.index_prompt_tokens[global_id][k].extend(v) + else: + self.index_prompt_tokens[global_id][k] = v + + def update_datas(self, global_id_map, req_info): + with self._lock: + for req_id, new_token_info in req_info.items(): + if req_id in global_id_map: + global_id = global_id_map[req_id] + if global_id not in self.index_prompt_tokens: + self.index_prompt_tokens[global_id] = {} + for k, v in new_token_info.items(): + if k not in self.index_prompt_tokens[global_id]: + self.index_prompt_tokens[global_id][k] = type(v)() + if k == "new_token_ids": + self.index_prompt_tokens[global_id][k].extend(v) + else: + self.index_prompt_tokens[global_id][k] = v + + def set_data(self, global_id, key, value): + with self._lock: + if global_id not in self.index_prompt_tokens: + self.index_prompt_tokens[global_id] = {} + self.index_prompt_tokens[global_id][key] = value + + def extend(self, global_id, key, value): + with self._lock: + if global_id not in self.index_prompt_tokens: + self.index_prompt_tokens[global_id] = {} + if key not in self.index_prompt_tokens[global_id]: + self.index_prompt_tokens[global_id][key] = [] + self.index_prompt_tokens[global_id][key].extend(value) + + def get(self): + with self._lock: + return self.index_prompt_tokens + + def clear(self, latest_model_ckpt_step): + save_dir, _ = os.path.split(self.auto_save_path) + global_step_path = os.path.join(save_dir, f"global_step_{self.iteration}.pt") + while True: + with self._lock: + finished = [req_info.get("finished", False) for _, req_info in self.index_prompt_tokens.items()] + if not finished: + break + if all(finished) and os.path.exists(global_step_path): + break + print(f"[fault_manager][{datetime.datetime.now()}] waiting all reqs to be finished and saved") + time.sleep(1) + + self.index_prompt_tokens.clear() + # clear expired tokens ckpt + for iteration in list(self.saved_step_ckpt.keys()): + if iteration <= latest_model_ckpt_step: + if os.path.exists(self.saved_step_ckpt[iteration]): + os.remove(self.saved_step_ckpt[iteration]) + self.saved_step_ckpt.pop(iteration) + + def try_load(self): + with self._lock: + save_dir = os.path.dirname(self.auto_save_path) + finished_save_path = os.path.join(save_dir, f"global_step_{self.iteration}.pt") + if os.path.exists(finished_save_path): + load_data = torch.load(finished_save_path) + self.index_prompt_tokens = load_data["tokens"] + self.is_rollout_finished_step = True + return True + self.is_rollout_finished_step = False + if os.path.exists(self.auto_save_path): + load_data = torch.load(self.auto_save_path) + if load_data["iter"] == self.iteration: + self.index_prompt_tokens = load_data["tokens"] + return True + return False + + def update_iter(self, iteration): + self.iteration = iteration + + def _auto_save(self): + save_dir, save_file = os.path.split(self.auto_save_path) + save_dir_tmp = os.path.join(os.path.dirname(self.auto_save_path), "tmp") + tmp_path = os.path.join(save_dir_tmp, save_file) + os.makedirs(save_dir_tmp, exist_ok=True) + os.makedirs(save_dir, exist_ok=True) + while True: + if not self.is_rollout_finished_step: + with self._lock: + torch.save({"iter": self.iteration, "tokens": self.index_prompt_tokens}, tmp_path) + os.replace(tmp_path, self.auto_save_path) + finished = sum( + [1 for _, req_info in self.index_prompt_tokens.items() if req_info.get("finished", False)] + ) + print(f"[fault_manager][{datetime.datetime.now()}] finished requests num: {finished}") + if ( + all([req_info.get("finished", False) for _, req_info in self.index_prompt_tokens.items()]) + and finished == self.batch_size + ): + global_step_path = os.path.join(save_dir, f"global_step_{self.iteration}.pt") + shutil.copy(self.auto_save_path, global_step_path) + self.saved_step_ckpt[self.iteration] = global_step_path + time.sleep(self.save_interval) + + +@ray.remote +class NodeWorker: + def __init__(self, actor_pids, device_name): + self.actor_pids = actor_pids + self.get_usage_fn = self._get_npu_usage if device_name == "npu" else self._get_gpu_usage + self.get_chip_info_cmd = ["npu-smi", "info"] if device_name == "npu" else ["nvidia-smi"] + + def is_chip_free(self): + devices_info = set() + chip_info = self._exec_shell(self.get_chip_info_cmd) + for pid in self.actor_pids: + device_info = self._get_middle_str("\n", chip_info, str(pid)) + if device_info: + devices_info.add(tuple(device_info.split())) + + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=len(devices_info)) as executor: + usages = list(executor.map(self.get_usage_fn, devices_info)) + print(f"[fault_manager][{datetime.datetime.now()}] chips core utilization: {usages}") + return all([usage == 0 for usage in usages]) + + def _get_npu_usage(self, device_info): + try: + _, npu_id, chip_id, _ = device_info + chip_info = self._exec_shell(["npu-smi", "info", "-i", npu_id, "-c", chip_id, "-t", "usages"]) + if not chip_info: + return 0 + *_, usage = self._get_middle_str("Aicore", chip_info, "\n").split() + return int(usage) + except Exception as e: + print(f"[fault_manager][{datetime.datetime.now()}] get npu usage error: {str(e)}") + return 0 + + def _get_gpu_usage(self, device_info): + try: + gpu_id, _, _ = device_info + chip_info = self._exec_shell( + ["nvidia-smi", "dmon", "-c", "1", "-i", gpu_id, "-s", "u", "--format", "noheader,nounit"] + ) + if not chip_info: + return 0 + _, usage, *_ = chip_info.split() + return int(usage) + except Exception as e: + print(f"[fault_manager][{datetime.datetime.now()}] get gpu usage error: {str(e)}") + return 0 + + @staticmethod + def _exec_shell(cmd: list): + try: + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, check=True) + return result.stdout + except subprocess.CalledProcessError: + return False + + @staticmethod + def _get_middle_str(left, text, right): + middle = "(.*?)" if left and right else "(.*)" + match = re.search(rf"{left}{middle}{right}", text) + if match: + return match.group(1) + return "" + + +class FaultMgr: + trainer = None + tokens_queue = None + tokens_dict = None + request_global_id_map = {} + node_workers = [] + timeout_chip_check = False + device_type = "cpu" + node_pids = defaultdict(list) + + @classmethod + def init_tokens_queue(cls): + cls.tokens_queue = Queue( + actor_options={ + "name": QUEUE_NAME, + "scheduling_strategy": cls._get_head_node_strategy(), + # "max_concurrency": 4, # better be the num of vllm servers + } + ) + + @classmethod + def bind_trainer(cls, trainer): + from recipe.fault_recover.agent_loop.fault_recover_agent_loop import ( + FaultRecoverAgentLoopManager as AgentLoopManager, + ) + + print(f"[fault_manager][{datetime.datetime.now()}] start bind trainer") + cls.trainer = trainer + cls.tokens_dict = TokensDict.options(scheduling_strategy=cls._get_head_node_strategy()).remote( + auto_save_path=cls.trainer.config.fault_manager.tokens_save_file, + save_interval=cls.trainer.config.fault_manager.tokens_save_interval, + batch_size=cls.trainer.config.data.train_batch_size * cls.trainer.config.actor_rollout_ref.rollout.n, + ) + cls.catch_rollout_tokens() + cls.device_type = cls.trainer.actor_rollout_wg.get_device_name()[0] + cls.timeout_chip_check = (cls.trainer.config.fault_manager.timeout_chip_free > 0) and cls.device_type != "cpu" + + AgentLoopManager.generate_sequences = cls.catch_rollout_fault( + cls.timeout(AgentLoopManager.generate_sequences), roles=[Role.ActorRollout, Role.RefPolicy] + ) + + if cls.timeout_chip_check: + cls._init_node_workers() + + @classmethod + def reschedule(cls, func): + @wraps(func) + def wrapper(config, task_runner_class=None): + try: + func(config, task_runner_class) + except Exception as reschedule_error: + print(f"[fault_manager][{datetime.datetime.now()}] catch reschedule fault: {reschedule_error}") + if config.fault_manager.enable: + max_reschedule_times = config.fault_manager.max_reschedule_times + reschedule_times = 0 + while (max_reschedule_times > 0) and (reschedule_times < max_reschedule_times): + try: + ray.shutdown() + func(config, task_runner_class, is_rescheduling=True) + except Exception as e: + print( + f"[fault_manager][{datetime.datetime.now()}] catch reschedule fault: " + f"{e} during recover, reschedule_times: {reschedule_times}/{max_reschedule_times}" + ) + reschedule_error = e + reschedule_times += 1 + else: + break + else: + raise reschedule_error + else: + raise reschedule_error + + return wrapper + + @classmethod + def rebuild_wg(cls, roles: list): + if not cls.trainer: + raise ValueError("[fault_manager] Have not bound trainer!") + print(f"[fault_manager][{datetime.datetime.now()}] start rebuild wg") + from verl.single_controller.ray import RayClassWithInitArgs + from verl.single_controller.ray.base import create_colocated_worker_cls + + actor_rollout_resource_pool = None + for role in roles: + resource_pool = cls.trainer.resource_pool_manager.get_resource_pool(role) + if role == Role.ActorRollout: + actor_rollout_resource_pool = resource_pool + role_cls = RayClassWithInitArgs( + cls=cls.trainer.role_worker_mapping[role], + config=cls._get_role_config(role), + role=str(role), + ) + cls.trainer.resource_pool_to_cls[resource_pool][str(role)] = role_cls + + wg_kwargs = cls._get_wg_kwargs() + for resource_pool, class_dict in cls.trainer.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = cls.trainer.ray_worker_group_cls( + resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + + for role in class_dict.keys(): + role_wg = spawn_wg[role] + setattr(cls.trainer, cls._get_wg_name(role), role_wg) + getattr(cls.trainer, cls._get_wg_name(role)).init_model() + if cls.timeout_chip_check: + cls._init_node_workers() + return actor_rollout_resource_pool + + @classmethod + def catch_rollout_tokens(cls): + print(f"[fault_manager][{datetime.datetime.now()}] start catch rollout tokens") + + @ray.remote(num_cpus=1) + def run(q, td): + while True: + req_info = q.get() + # print(f"[fault manager] qsize {q.qsize()}") + if isinstance(req_info, tuple): + request_id, global_id = req_info + cls.request_global_id_map[request_id] = global_id + elif isinstance(req_info, dict): + ray.get(td.update_datas.remote(cls.request_global_id_map, req_info)) + + run.remote(cls.tokens_queue, cls.tokens_dict) + + @classmethod + def catch_rollout_fault(cls, func, roles): + @wraps(func) + def wrapper(_self, gen_batch_output): + if gen_batch_output.meta_info.get("validate"): + gen_batch_output = func(_self, gen_batch_output) + return gen_batch_output + try: + gen_batch_output = cls._update_gen_batch(gen_batch_output, ray.get(cls.tokens_dict.get.remote())) + gen_batch_output = func(_self, gen_batch_output) + return gen_batch_output + except Exception as rebuild_error: + print(f"[fault_manager][{datetime.datetime.now()}] catch rollout fault: {rebuild_error}") + max_rebuild_times = cls.trainer.config.fault_manager.max_rebuild_times + rebuild_times = 0 + while (max_rebuild_times < 0) or (rebuild_times < max_rebuild_times): + try: + pre_rebuild_result = cls._pre_rebuild() + if pre_rebuild_result is not True: + rebuild_error = pre_rebuild_result + break + + print(f"[fault_manager][{datetime.datetime.now()}] start rebuild") + actor_rollout_resource_pool = cls.rebuild_wg(roles=roles) + cls.rebuild_manager(actor_rollout_resource_pool) + gen_batch_output = cls._update_gen_batch( + gen_batch_output, ray.get(cls.tokens_dict.get.remote()) + ) + print(f"[fault_manager][{datetime.datetime.now()}] retry rollout") + gen_batch_output = func(cls.trainer.async_rollout_manager, gen_batch_output) + return gen_batch_output + except Exception as e: + print( + f"[fault_manager][{datetime.datetime.now()}] catch rebuild fault: " + f"{e} during recover retry, rebuild_times: {rebuild_times}/{max_rebuild_times}" + ) + rebuild_error = e + rebuild_times += 1 + raise rebuild_error + + return wrapper + + @classmethod + def timeout(cls, func): + @wraps(func) + def wrapper(_self, prompts): + timeout_task_check_interval = _self.config.fault_manager.timeout_task_check_interval + timeout_chip_free = _self.config.fault_manager.timeout_chip_free + if ( + not _self.config.fault_manager.enable + or timeout_task_check_interval < 0 + or prompts.meta_info.get("validate") + ): + return func(_self, prompts) + if cls.timeout_chip_check: + free_flag = threading.Event() + stop_flag = threading.Event() + + def monitor(): + start_time = time.time() + while not stop_flag.is_set(): + chips_free = all(ray.get([w.is_chip_free.remote() for w in cls.node_workers])) + if chips_free: + if not start_time: + start_time = time.time() + elif time.time() - start_time > timeout_chip_free and not free_flag.is_set(): + free_flag.set() + else: + start_time = None + if free_flag.is_set(): + free_flag.clear() + time.sleep(1) + + t = threading.Thread(target=monitor, daemon=True) + t.start() + + def _handle_timeout(signum, frame): + if cls.timeout_chip_check: + if free_flag.is_set(): + [ray.kill(w) for w in cls.trainer.async_rollout_manager.agent_loop_workers] + [ + ray.get(rr.server_handle.clear_engine.remote()) + for rr in cls.trainer.async_rollout_manager.rollout_replicas + ] + [ray.kill(rr.server_handle) for rr in cls.trainer.async_rollout_manager.rollout_replicas] + else: + signal.alarm(timeout_task_check_interval) + else: + raise TimeoutError(f"[fault_manager][{datetime.datetime.now()}] {func} timeout") + + signal.signal(signal.SIGALRM, _handle_timeout) + try: + signal.alarm(timeout_task_check_interval) + return func(_self, prompts) + except (RayTaskError, RayActorError) as e: + raise TimeoutError(f"[fault_manager][{datetime.datetime.now()}] {func} timeout") from e + finally: + if cls.timeout_chip_check: + stop_flag.set() + signal.alarm(0) + t.join(timeout=2) + + return wrapper + + @classmethod + def init_index_prompt_tokens(cls, gen_batch_output): + ray.get(cls.tokens_dict.clear.remote(latest_model_ckpt_step=cls._get_latest_global_steps())) + ray.get(cls.tokens_dict.update_iter.remote(cls.trainer.global_steps)) + ray.get(cls.tokens_dict.try_load.remote()) + gen_batch_output.non_tensor_batch["global_id"] = np.array( + [str(i) for i in range(len(gen_batch_output.non_tensor_batch["prompt"]))], dtype=object + ) + ray.get(cls.tokens_dict.start_save.remote()) + cls.request_global_id_map.clear() + + @classmethod + def _get_wg_name(cls, role): + return { + str(Role.ActorRollout): "actor_rollout_wg", + str(Role.RefPolicy): "ref_policy_wg", + }.get(role) + + @classmethod + def _get_wg_kwargs(cls): + wg_kwargs = {} + if OmegaConf.select(cls.trainer.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = cls.trainer.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(cls.trainer.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(cls.trainer.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(cls.trainer.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select( + cls.trainer.config.global_profiler.global_tool_config.nsys, "worker_nsight_options" + ) + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select( + cls.trainer.config.global_profiler.global_tool_config.nsys, "worker_nsight_options" + ) + ) + wg_kwargs["device_name"] = cls.trainer.device_name + return wg_kwargs + + @classmethod + def _get_role_config(cls, role): + return { + Role.ActorRollout: cls.trainer.config.actor_rollout_ref, + Role.RefPolicy: cls.trainer.config.actor_rollout_ref, + }.get(role) + + @classmethod + def _parse_req_tokens(cls, req_info, td): + for req_id, new_token_info in req_info.items(): + if req_id in cls.request_global_id_map: + # fused for better performance + ray.get(td.update_data.remote(cls.request_global_id_map[req_id], new_token_info)) + + @classmethod + def _update_gen_batch(cls, gen_batch_output, tokens_dict): + all_tokens = tokens_dict + global_ids = gen_batch_output.non_tensor_batch["global_id"] + all_new_token_ids = [] + all_new_token_length = [] + all_token_finished = [] + all_log_probs = [] + all_routed_experts = [] + all_num_preempted = [] + + for global_id in global_ids: + token_info = all_tokens.get(global_id, {"new_token_ids": [], "finished": False}) + new_token_ids = token_info.get("new_token_ids", []) + finished = token_info.get("finished", False) + log_probs = token_info.get("log_probs", None) + routed_experts = token_info.get("routed_experts", None) + num_preempted = token_info.get("num_preempted", -1) + all_new_token_ids.append(new_token_ids) + all_new_token_length.append(len(new_token_ids)) + all_token_finished.append(finished) + all_log_probs.append(log_probs) + all_routed_experts.append(routed_experts) + all_num_preempted.append(num_preempted) + + if all([length == 0 for length in all_new_token_length]): + return gen_batch_output + + gen_batch_output.non_tensor_batch["new_token_ids"] = np.array(all_new_token_ids, dtype=object) + gen_batch_output.non_tensor_batch["finished"] = np.array(all_token_finished, dtype=bool) + gen_batch_output.non_tensor_batch["log_probs"] = np.array(all_log_probs, dtype=object) + gen_batch_output.non_tensor_batch["routed_experts"] = np.array(all_routed_experts, dtype=object) + gen_batch_output.non_tensor_batch["num_preempted"] = np.array(all_num_preempted, dtype=object) + return gen_batch_output + + @classmethod + def _get_latest_global_steps(cls): + from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path + + checkpoint_folder = cls.trainer.config.trainer.default_local_dir + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) + if global_step_folder: + return int(global_step_folder.split("global_step_")[-1]) + return 0 + + @classmethod + def rebuild_manager(cls, actor_rollout_resource_pool): + from recipe.fault_recover.agent_loop.fault_recover_agent_loop import ( + FaultRecoverAgentLoopManager as AgentLoopManager, + ) + + from verl.checkpoint_engine import CheckpointEngineManager + + if cls.trainer.use_reward_loop and cls.trainer.use_rm: + raise NotImplementedError("[fault_manager] fault_recover does not support use_rm yet") + + [ray.kill(w) for w in cls.trainer.async_rollout_manager.agent_loop_workers] + [ray.get(rr.server_handle.clear_engine.remote()) for rr in cls.trainer.async_rollout_manager.rollout_replicas] + [ray.kill(rr.server_handle) for rr in cls.trainer.async_rollout_manager.rollout_replicas] + + cls.trainer.async_rollout_manager = AgentLoopManager( + config=cls.trainer.config, + worker_group=cls.trainer.actor_rollout_wg, + rollout_resource_pool=actor_rollout_resource_pool, + reward_loop_worker_handles=None, + ) + + cls.trainer.checkpoint_manager = CheckpointEngineManager( + backend=cls.trainer.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=cls.trainer.actor_rollout_wg, + replicas=cls.trainer.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + cls.trainer.checkpoint_manager.sleep_replicas() + + if cls.trainer._load_checkpoint() != 0: + cls.trainer.global_steps += 1 + cls.trainer.checkpoint_manager.update_weights() + + @classmethod + def _init_node_workers(cls): + [ray.kill(w) for w in cls.node_workers] + cls.node_pids.clear() + cls.node_workers.clear() + for node_id, actor_pid in cls.trainer.actor_rollout_wg.get_node_pids(): + cls.node_pids[node_id].append(actor_pid) + for node_id, actor_pids in cls.node_pids.items(): + node_worker = NodeWorker.options( + scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_id, soft=False) + ).remote(actor_pids, device_name=cls.device_type) + cls.node_workers.append(node_worker) + + @classmethod + def _pre_rebuild(cls): + if cls.trainer.global_steps != cls._get_latest_global_steps() + 1: + return Exception( + f"[fault_manager][{datetime.datetime.now()}] ckpt of fault step {cls.trainer.global_steps - 1} lost" + ) + + for pool in cls.trainer.resource_pool_to_cls.keys(): + for pg in pool.pgs: + ray.util.remove_placement_group(pg) + pool.pgs = None + + while not cls.tokens_queue.empty(): + print(f"[fault_manager][{datetime.datetime.now()}] waiting for tokens queue to be empty...") + time.sleep(1) + + rebuild_time = time.time() + timeout_rebuild = cls.trainer.config.fault_manager.timeout_rebuild + while time.time() - rebuild_time < timeout_rebuild: + try: + check_resource_available(cls.trainer.resource_pool_manager.resource_pool_spec) + return True + except ValueError as e: + print(f"[fault_manager][{datetime.datetime.now()}] {str(e)}\nwaiting for resource to be ready...") + time.sleep(5) + return Exception( + f"[fault_manager][{datetime.datetime.now()}] " + f"timeout waiting for resource to be ready for {timeout_rebuild}s" + ) + + @classmethod + def _get_head_node_strategy(cls): + return NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False, + ) + + +def get_tokens_queue(): + try: + tokens_queue = ray.get_actor(QUEUE_NAME) + except ValueError: + tokens_queue = None + return tokens_queue + + +def check_resource_available(resource_pool_spec): + """Check if the resource pool can be satisfied in this ray cluster.""" + node_available_resources = ray._private.state.available_resources_per_node() + node_available_gpus = { + node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) + for node, node_info in node_available_resources.items() + } + + # check total required gpus can be satisfied + total_available_gpus = sum(node_available_gpus.values()) + total_required_gpus = sum( + [n_gpus for process_on_nodes in resource_pool_spec.values() for n_gpus in process_on_nodes] + ) + if total_available_gpus < total_required_gpus: + raise ValueError( + f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" + ) + + +def get_resource_pool_spec(config): + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + return resource_pool_spec diff --git a/ICL/DAPO/verl-recipe/fault_recover/main_ppo.py b/ICL/DAPO/verl-recipe/fault_recover/main_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..eb93aa669a3cd6b21444430802cef7397e0f185e --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/main_ppo.py @@ -0,0 +1,311 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain. +""" + +import datetime +import os +import socket +import time + +import hydra +import ray +from omegaconf import OmegaConf, open_dict +from recipe.fault_recover.fault_manager import FaultMgr, check_resource_available, get_resource_pool_spec +from recipe.fault_recover.ray_trainer import FaultRecoverRayPPOTrainer + +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.main_ppo import TaskRunner, create_rl_dataset, create_rl_sampler +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import auto_set_device, is_cuda_available + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config: Hydra configuration dictionary containing training parameters. + """ + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + + run_ppo(config) + + +# Define a function to run the PPO-like training process +@FaultMgr.reschedule +def run_ppo(config, task_runner_class=None, is_rescheduling=False) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + if config.transfer_queue.enable: + # Add runtime environment variables for transfer queue + runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) + runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" + runtime_env_kwargs["env_vars"] = runtime_env_vars + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if is_rescheduling: + reschedule_time = time.time() + timeout_reschedule = config.fault_manager.timeout_reschedule + resource_pool_spec = get_resource_pool_spec(config) + while True: + if time.time() - reschedule_time >= timeout_reschedule: + raise Exception( + f"[fault_manager][{datetime.datetime.now()}] " + f"timeout waiting for resource to be ready for {timeout_reschedule}s" + ) + try: + check_resource_available(resource_pool_spec) + break + except ValueError as e: + print(f"[fault_manager][{datetime.datetime.now()}] {str(e)}\nwaiting for resource to be ready...") + time.sleep(5) + + print(f"[fault_manager][{datetime.datetime.now()}] start reschedule") + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)( + FaultRecoverTaskRunner + ) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class FaultRecoverTaskRunner(TaskRunner): + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role + + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + + # use new model engine implementation + if use_legacy_worker_impl == "disable": + from verl.workers.engine_workers import ActorRolloutRefWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0) + if lora_rank <= 0: + lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0) + ref_in_actor = lora_rank > 0 or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker, + # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker. + if need_reference_policy(config) and not ref_in_actor: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + # Note: sync mode validation is now handled in RolloutConfig.__post_init__ + # Always use async worker since sync mode is deprecated and rejected + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker + + actor_rollout_cls = AsyncActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from recipe.fault_recover.megatron_workers import AsyncFaultRecoverActorRolloutRefWorker + + actor_rollout_cls = AsyncFaultRecoverActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + self.mapping[Role.ActorRollout] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(config), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + + if config.fault_manager.enable: + with open_dict(config.actor_rollout_ref): + config.actor_rollout_ref.fault_manager = config.fault_manager + FaultMgr.init_tokens_queue() + + # Initialize the PPO trainer. + trainer = FaultRecoverRayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + try: + # Start the training process. + trainer.fit() + except Exception as e: + [ray.get(rr.server_handle.clear_engine.remote()) for rr in trainer.async_rollout_manager.rollout_replicas] + raise e + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/fault_recover/megatron_workers.py b/ICL/DAPO/verl-recipe/fault_recover/megatron_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..996ce3c6aa99994f9d8b7b6e2c5eb8d36ce43c76 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/megatron_workers.py @@ -0,0 +1,21 @@ +import os + +import ray + +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils.device import get_device_name +from verl.workers.megatron_workers import AsyncActorRolloutRefWorker + + +class AsyncFaultRecoverActorRolloutRefWorker(AsyncActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_device_name(self): + return get_device_name() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_pid(self): + return os.getpid() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_node_pids(self): + return ray.get_runtime_context().get_node_id(), os.getpid() diff --git a/ICL/DAPO/verl-recipe/fault_recover/ray_trainer.py b/ICL/DAPO/verl-recipe/fault_recover/ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4dca451ae2d585fc55c991e90a81b9b2b6d092a9 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/ray_trainer.py @@ -0,0 +1,673 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import uuid +from copy import deepcopy +from pprint import pprint + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl import DataProto +from verl.checkpoint_engine import CheckpointEngineManager +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.single_controller.ray import RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + compute_variance_proxy_metrics, +) +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask +from verl.trainer.ppo.reward import compute_reward_async +from verl.trainer.ppo.utils import Role +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.metric import reduce_metrics +from verl.utils.rollout_skip import RolloutSkip +from verl.workers.config import FSDPEngineConfig + + +class FaultRecoverRayPPOTrainer(RayPPOTrainer): + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + from recipe.fault_recover.fault_manager import FaultMgr + + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout + if self.hybrid_engine: + actor_rollout_resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[actor_role], + config=self.config.actor_rollout_ref, + role=str(actor_role), + ) + self.resource_pool_to_cls[actor_rollout_resource_pool][str(actor_role)] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + + from verl.workers.config import CriticConfig + + critic_cfg: CriticConfig = omega_conf_to_dataclass(self.config.critic) + + if self.use_legacy_worker_impl == "disable": + # convert critic_cfg into TrainingWorkerConfig + from verl.workers.engine_workers import TrainingWorkerConfig + + orig_critic_cfg = critic_cfg + if orig_critic_cfg.strategy == "fsdp": + engine_config: FSDPEngineConfig = orig_critic_cfg.model.fsdp_config + engine_config.infer_max_token_len_per_gpu = critic_cfg.ppo_infer_max_token_len_per_gpu + engine_config.max_token_len_per_gpu = critic_cfg.ppo_max_token_len_per_gpu + else: + raise NotImplementedError(f"Unknown strategy {orig_critic_cfg.strategy=}") + + critic_cfg = TrainingWorkerConfig( + model_type="value_model", + model_config=orig_critic_cfg.model_config, + engine_config=engine_config, + optimizer_config=orig_critic_cfg.optim, + checkpoint_config=orig_critic_cfg.checkpoint, + ) + + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + # create reference policy if needed + if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # create a reward model if reward_fn is None + # for legacy discriminative reward model, we create a reward model worker here + # for reward loop discriminative reward model, we create a reward loop manager here + if self.use_rm and not self.use_reward_loop: + raise RuntimeError("Reward model worker group is not supported, please set use_reward_loop=True") + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg[str(Role.Critic)] + if self.use_legacy_worker_impl == "disable": + self.critic_wg.reset() + # assign critic loss + from functools import partial + + from verl.workers.utils.losses import value_loss + + value_loss_ = partial(value_loss, config=orig_critic_cfg) + self.critic_wg.set_loss_fn(value_loss_) + else: + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + + self.rm_wg = None + # initalization of rm_wg will be deprecated in the future + if self.use_rm and not self.use_reward_loop: + self.rm_wg = all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg[str(actor_role)] + self.actor_rollout_wg.init_model() + + if self.ref_in_actor: + self.ref_policy_wg = self.actor_rollout_wg + + # create reward loop manager + if self.use_reward_loop: + from verl.experimental.reward_loop import RewardLoopManager + + # initalize reward loop manager + # reward model (colocate or standalone): get resource_pool + # no reward model: resource_pool = None + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None + self.reward_loop_manager = RewardLoopManager( + config=self.config, + rm_resource_pool=resource_pool, + ) + + # create async rollout manager and request scheduler + # Note: mode is always "async" since sync mode is deprecated + self.async_rollout_mode = True + + # Support custom AgentLoopManager via config + manager_class_fqn = self.config.actor_rollout_ref.rollout.get("agent", {}).get("agent_loop_manager_class") + if manager_class_fqn: + AgentLoopManager = load_class_from_fqn(manager_class_fqn, "AgentLoopManager") + else: + from recipe.fault_recover.agent_loop.fault_recover_agent_loop import ( + FaultRecoverAgentLoopManager as AgentLoopManager, + ) + + # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design + # agent_reward_loop: streaming reward computation with actor rollout + # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool + enable_agent_reward_loop = self.use_reward_loop and ( + not self.use_rm or self.config.reward_model.enable_resource_pool + ) + # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager + # to stream reward computation with actor rollout + + reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None + self.async_rollout_manager = AgentLoopManager( + config=self.config, + worker_group=self.actor_rollout_wg, + rollout_resource_pool=actor_rollout_resource_pool, + reward_loop_worker_handles=reward_loop_worker_handles, + ) + + self.checkpoint_manager = CheckpointEngineManager( + backend=self.config.actor_rollout_ref.rollout.checkpoint_engine.backend, + trainer=self.actor_rollout_wg, + replicas=self.async_rollout_manager.rollout_replicas, + ) + + # sleep all replicas to load checkpoint + self.checkpoint_manager.sleep_replicas() + + if self.config.fault_manager.enable: + FaultMgr.bind_trainer(self) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + from recipe.fault_recover.fault_manager import FaultMgr + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint and update weights before doing anything + self._load_checkpoint() + self.checkpoint_manager.update_weights() + + current_epoch = self.global_steps // len(self.train_dataloader) + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(current_epoch, self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + if self.config.fault_manager.enable: + FaultMgr.init_index_prompt_tokens(gen_batch_output=gen_batch_output) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + if curr_step_profile: + self.async_rollout_manager.start_profile(global_step=self.global_steps) + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + if curr_step_profile: + self.async_rollout_manager.start_profile() + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + self.checkpoint_manager.sleep_replicas() + if curr_step_profile: + self.async_rollout_manager.stop_profile() + batch = batch.union(gen_baseline_output) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + if not self.use_reward_loop: + rm_scores = self.rm_wg.compute_rm_score(batch) + else: + assert self.reward_loop_manager is not None, "RewardLoopManager is None" + rm_scores = self.reward_loop_manager.compute_rm_score(batch) + batch = batch.union(rm_scores) + + # Compute or extract reward for REMAX baseline + reward_baseline_tensor = self._compute_or_extract_reward( + batch, reward_fn=self.reward_fn, sum_reward=True + ) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + # get images_seqlens + images_seqlens_all = [] + for multi_modal_input in batch.non_tensor_batch["multi_modal_inputs"]: + if "image_grid_thw" not in multi_modal_input.keys(): + continue + images_seqlens_all.extend(multi_modal_input["images_seqlens"].tolist()) + batch.meta_info["images_seqlens"] = images_seqlens_all + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + if not self.use_reward_loop: + reward_tensor = self.rm_wg.compute_rm_score(batch) + else: + assert self.reward_loop_manager is not None, "RewardLoopManager is None" + reward_tensor = self.reward_loop_manager.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # Compute or extract reward for training + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote( + data=batch, config=self.config, tokenizer=self.tokenizer + ) + else: + reward_tensor, reward_extra_infos_dict = self._compute_or_extract_reward( + batch, reward_fn=self.reward_fn, reward_for_val=False + ) + + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_bypass_mode + + apply_bypass_mode( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, + ) + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = { + "actor/entropy": entropy_agg.detach().item(), + "perf/mfu/actor_infer": old_log_prob_mfu, + } + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + if "routed_experts" in batch.batch and "routed_experts" in old_log_prob.batch: + router_mode = getattr( + self.config.actor_rollout_ref.actor.router_replay, "mode", "disabled" + ) + if router_mode == "R2": + batch.batch.pop("routed_experts") + else: + old_log_prob.batch.pop("routed_experts") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + ref_log_prob = self._compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self._compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self._update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + actor_output = self._update_actor(batch) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step + or self.global_steps % self.config.trainer.save_freq == 0 + or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + # update weights from trainer to rollout + with marked_timer("update_weights", timing_raw, color="red"): + self.checkpoint_manager.update_weights() + + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # compute variance proxy metrics + gradient_norm = metrics.get("actor/grad_norm", None) + metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm)) + # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/ICL/DAPO/verl-recipe/fault_recover/run_qwen2_5_0.5b_megatron.sh b/ICL/DAPO/verl-recipe/fault_recover/run_qwen2_5_0.5b_megatron.sh new file mode 100644 index 0000000000000000000000000000000000000000..4ba99e4440f98c326145a09ba69546a200878457 --- /dev/null +++ b/ICL/DAPO/verl-recipe/fault_recover/run_qwen2_5_0.5b_megatron.sh @@ -0,0 +1,83 @@ +ulimit -n 65535 +set -x + +# avoid delayed console log, may impact performance +export PYTHONUNBUFFERED=1 + +export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export RAY_DEBUG=1 +export VLLM_ASCEND_ENABLE_NZ=0 +# export VLLM_USE_V1=1 + +project_name='GRPO' +exp_name='Qwen2.5-0.5B-Instruct-megatron-vllm-fault-recover' + +MODEL_PATH="${HOME}/model/Qwen2.5-0.5B-Instruct" +CKPTS_DIR="${HOME}/model/Qwen2.5-0.5B-Instruct-save" +TRAIN_FILE="${HOME}/data/gsm8k/train.parquet" +TEST_FILE="${HOME}/data/gsm8k/test.parquet" + +offload=False +train_tp=2 +train_pp=1 +gen_tp=2 + +python3 -m recipe.fault_recover.main_ppo --config-path=config \ + --config-name='fault_recover_ppo_megatron_trainer.yaml'\ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.train_batch_size=8 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=${MODEL_PATH} \ + actor_rollout_ref.actor.checkpoint.async_save=False \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.actor.megatron.use_mbridge=False \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.agent.default_agent_loop=fault_recover_single_turn_agent \ + actor_rollout_ref.ref.use_torch_compile=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + ++actor_rollout_ref.ref.megatron.override_transformer_config.use_flash_attn=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_ring_attention_update=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_distributed_optimizer=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console"]' \ + trainer.project_name=${project_name} \ + trainer.experiment_name=${exp_name} \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.default_local_dir=${CKPTS_DIR} \ + trainer.resume_mode=auto \ + trainer.val_before_train=False \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + trainer.total_epochs=15 \ diff --git a/ICL/DAPO/verl-recipe/flash_rl_ascend/run.sh b/ICL/DAPO/verl-recipe/flash_rl_ascend/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..b55872bc5c31d2e3965ed8997d4380b6a344a147 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flash_rl_ascend/run.sh @@ -0,0 +1,82 @@ +pkill -9 python +ray stop --force +rm -rf /tmp/ray/* + +# HCCL 相关配置 +export HCCL_EXEC_TIMEOUT=7200 +export HCCL_EVENT_TIMEOUT=7200 +export HCCL_CONNECT_TIMEOUT=7200 +export ACL_DEVICE_SYNC_TIMEOUT=7200 +export HCCL_ASYNC_ERROR_HANDLING=0 +export P2P_HCCL_BUFFSIZE=30 +export HCCL_BUFFSIZE=300 +export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050 +export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050 + +export RAY_DEDUP_LOGS=1 +export HYDRA_FULL_ERROR=1 + +export TASK_QUEUE_ENABLE=2 # 下发优化,图模式设置为1,非图模式设置为2 +export PYTORCH_NPU_ALLOC_CONF="max_split_size_mb:2048" +export PYTHONUNBUFFERED=1 + +# 修改为当前需要跑的用例路径 +DEFAULT_SH="examples/flash_rl/test_qwen3-30b_int8_npu.sh" +echo "Use $DEFAULT_SH" + +ulimit -n 32768 +mkdir -p logs + +NNODES=1 +NPUS_PER_NODE=16 +# 修改为对应主节点IP +MASTER_ADDR="MASTER_ADDR" +# 修改为当前节点的通信网卡 +SOCKET_IFNAME="SOCKET_IFNAME" +export HCCL_SOCKET_IFNAME=$SOCKET_IFNAME +export TP_SOCKET_IFNAME=$SOCKET_IFNAME +export GLOO_SOCKET_IFNAME=$SOCKET_IFNAME +export GLOO_SOCKET_TIMEOUT=7200 + +# 获取当前节点IP +CURRENT_IP=$(ifconfig $SOCKET_IFNAME | grep -Eo 'inet (addr:)?([0-9]{1,3}\.){3}[0-9]{1,3}' | awk '{print $NF}') +if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then + # 主节点启动 + ray start --head --port 6766 --dashboard-host=$MASTER_ADDR --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}' + + while true; do + ray_status_output=$(ray status) + npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) + npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') + device_count=$((npu_count_int / $NPUS_PER_NODE)) + + # 判断 device_count 是否与 NNODES 相等 + if [ "$device_count" -eq "$NNODES" ]; then + echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script." + ray status + bash $DEFAULT_SH 2>&1 | tee logs/qwen3-30b_int8_npu.log + break + else + echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" + sleep 5 + fi + done +else + # 子节点尝试往主节点注册ray直到成功 + while true; do + # 尝试连接 Ray 集群 + ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP + + # 检查连接是否成功 + ray status + if [ $? -eq 0 ]; then + echo "Successfully connected to the Ray cluster!" + break + else + echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." + sleep 5 + fi + done +fi + +sleep 600 \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/flash_rl_ascend/test_qwen3-30b_int8_npu.sh b/ICL/DAPO/verl-recipe/flash_rl_ascend/test_qwen3-30b_int8_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..d07c734fe413873513d5fa719942f4aac3ec8f8b --- /dev/null +++ b/ICL/DAPO/verl-recipe/flash_rl_ascend/test_qwen3-30b_int8_npu.sh @@ -0,0 +1,195 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +RAY_DATA_PATH=$(dirname "$(dirname "$(dirname "$(realpath "$0")")")") +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_PATH}/models/Qwen3-30B-A3B"} +QUANT_PATH=${QUANT_PATH:-"${RAY_DATA_PATH}/models/Qwen3-30B-A3B-w8a8"} +PROFILE_PATH=${PROFILE_PATH:-"${RAY_DATA_PATH}/profile.30b.pt"} +CONFIG_PATH=${CONFIG_PATH:-"${RAY_DATA_PATH}/.flashrl_config.30b.yaml"} + +if ! command -v flashrl &> /dev/null +then + pip install flash-llm-rl # need to be installed in all nodes in multi-node training +fi + +# manually add 'import flash_rl' in 'verl/verl/__init__.py' +if ! grep -q "import flash_rl" "${RAY_DATA_PATH}/verl/__init__.py"; then + echo "Adding 'import flash_rl' to verl/verl/__init__.py" + sed -i '1i import flash_rl' "${RAY_DATA_PATH}/verl/__init__.py" +fi + +if [ ! -f "${CONFIG_PATH}" ]; then + echo "Profile file not found at ${PROFILE_PATH}. Running profiling and setup..." + flashrl profile -m ${MODEL_PATH} -q ${QUANT_PATH} -o ${PROFILE_PATH} --fn int8 + flashrl setup -m ${QUANT_PATH} -p ${PROFILE_PATH} --fn int8 -o ${CONFIG_PATH} +else + echo "Profile file found at ${PROFILE_PATH}. Skipping profiling and setup." +fi +# (Optional) conduct rollout generation in 16bits and 8bits in a hybrid manner across DP workers +# flashrl setup -a --fn bf16 -o ${CONFIG_PATH} + +flashrl cleanup + +export VERL_LOGGING_LEVEL=DEBUG +export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_CONFIGURE_LOGGING=1 +export FLASHRL_LOGGING_LEVEL=DEBUG +export FLASHRL_CONFIG=${CONFIG_PATH} +export FLASHRL_LMHEAD_FP32=1 + +project_name='GRPO' +exp_name='Qwen3-30B-INT8-ROLLOUT' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 30)) +train_prompt_bsz=32 +gen_prompt_bsz=$((train_prompt_bsz * 3)) +n_resp_per_prompt=8 +max_num_seqs=1024 +train_prompt_mini_bsz=32 +loss_agg_mode="token-mean" + +# Ray +NNODES=1 +# Paths +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_PATH}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_PATH}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_PATH}/data/gsm8k/test.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=16 +use_dynamic_bsz=True +log_prob_micro_batch_size_per_gpu=8 +ppo_micro_batch_size_per_gpu=8 +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) / sp_size)) +offload=True +gen_tp=4 +enable_chunked_prefill=True + +# Importance Sampling (IS) weights configuration +rollout_is="sequence" # Self-normalized sequence-level IS +rollout_is_threshold=2.0 # Upper threshold for IS weights +rollout_is_batch_normalize="true" # Self-normalization (mean=1.0) + +# Rejection Sampling (RS) configuration +rollout_rs="null" # No rejection sampling for basic RLOO +rollout_rs_threshold="null" # RS upper threshold +rollout_rs_threshold_lower="null" # RS lower threshold + +# Veto mechanism (optional, independent of IS/RS) +rollout_token_veto_threshold="null" # Per-token veto threshold (null to disable) + +# Policy Gradient loss mode (bypass mode with policy gradient loss, no PPO clipping) +bypass_mode="true" # Required for policy gradient mode +use_policy_gradient="true" # Use policy gradient loss (works with IS/RS/both) + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + +data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.max_num_seqs=${max_num_seqs} \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.60 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=${enable_chunked_prefill} \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_is_batch_normalize=${rollout_is_batch_normalize} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + +algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + +algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ + algorithm.rollout_correction.bypass_mode=${bypass_mode} \ + +algorithm.rollout_correction.use_policy_gradient=${use_policy_gradient} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.strategy=fsdp \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${log_prob_micro_batch_size_per_gpu} \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${log_prob_micro_batch_size_per_gpu} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ppo_micro_batch_size_per_gpu} \ + ++actor_rollout_ref.nccl_timeout=7200 \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.ref.use_torch_compile=False \ + reward_model.reward_manager=naive \ + trainer.logger='["console"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=16 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=False \ + trainer.test_freq=-1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=1 \ + trainer.total_training_steps=100 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.device='npu' $@ + diff --git a/ICL/DAPO/verl-recipe/flowrl/FLOWRL_SIMPLE_GUIDE.md b/ICL/DAPO/verl-recipe/flowrl/FLOWRL_SIMPLE_GUIDE.md new file mode 100644 index 0000000000000000000000000000000000000000..0dbffb319b8b838501225338319663396cbe3444 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/FLOWRL_SIMPLE_GUIDE.md @@ -0,0 +1,156 @@ +# FlowRL Implementation + +## 4 Simple Steps to Add FlowRL + +### Step 1: Add Partition Function Z + +**File**: `verl/workers/fsdp_workers.py` + +[Add this class at line 100](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/fsdp_workers.py#L100): + +```python +class ProjZModule(torch.nn.Module): + def __init__(self, hidden_size: int, num_layers: int = 3, dropout: float = 0.1): + super().__init__() + layers = [] + + for i in range(num_layers - 1): + layers.extend([ + torch.nn.Linear(hidden_size, hidden_size), + torch.nn.GELU(), + torch.nn.LayerNorm(hidden_size), + torch.nn.Dropout(dropout) + ]) + + layers.append(torch.nn.Linear(hidden_size, 1)) + self.net = torch.nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) +``` + +[Add this to model building at line 267](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/fsdp_workers.py#L265): + +```python +n_dim = actor_module.config.hidden_size +actor_module.proj_z = ProjZModule(n_dim, num_layers=self.config.actor.porj_layer) +``` + +### Step 2: Modify Forward Pass + +**File**: `verl/workers/actor/dp_actor.py` + +[Change method signature at line 75](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/actor/dp_actor.py#L75): + +```python +def _forward_micro_batch(self, micro_batch, temperature, calculate_entropy=False,return_log_z=False) -> Tuple[torch.Tensor, torch.Tensor]: +``` + +[Add before return at line 232](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/actor/dp_actor.py#L232): + +```python +if return_log_z: + last_hidden = output.hidden_states[-1].squeeze(0) # (total_nnz, hidden size) + if self.use_ulysses_sp: + last_hidden = gather_outputs_and_unpad( + last_hidden, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + full_last_hidden = pad_input(hidden_states=last_hidden, + indices=indices, + batch=batch_size, + seqlen=seqlen) + # extract pormpt hiddenstate for log z + prompts_last_hidden = full_last_hidden[:, : -response_length - 1] + prompt_attention_mask = attention_mask[:, : -response_length - 1] + avg_hidden = verl_F.masked_mean(prompts_last_hidden, prompt_attention_mask.unsqueeze(-1), axis=1) + + # avg_hidden = avg_hidden.detach() # use detach() to stop gradient of proj_z to policy + log_z = self.actor_module.proj_z(avg_hidden) + + return entropy, log_probs, log_z + +else: + return entropy, log_probs +``` + +### Step 3: Replace PPO Loss with FlowRL Loss + +**File**: `verl/workers/actor/dp_actor.py` + +[Replace PPO loss computation around line 412](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/actor/dp_actor.py#L412): + +```python +# OLD PPO CODE - REMOVE: +# entropy, log_prob = self._forward_micro_batch(...) +# pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(...) + +# NEW FLOWRL CODE: +entropy, log_prob, log_z = self._forward_micro_batch(micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy, return_log_z=True) + +policy_loss, data = self.compute_flowrl_objective(logpf=log_prob, + logf_ref=data['ref_log_prob'], + logpf_old=old_log_prob, + log_z=log_z, + reward=advantages, + response_mask=response_mask, + clip_ratio=self.config.clip_ratio) +``` + +[Add FlowRL objective function at line 555](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/actor/dp_actor.py#L555): + +```python +def compute_flowrl_objective(self, logpf=None, logf_ref=None, logpf_old=None, log_z=None, reward=None, response_mask=None, clip_ratio=None): + # we set 𝛽 and 𝛾 to 0.1 and 1.0, + # squeeze log_z to (B,) + log_z = log_z.squeeze(-1) + B = log_z.shape[0] + + # mean of log p_f / log p_ref over valid tokens + avg_logpf = verl_F.masked_mean(logpf, response_mask, axis=1) + avg_logp_ref = verl_F.masked_mean(logf_ref, response_mask, axis=1) + + # mean of token-level reward → log + # we set R = exp(advantage); then log_reward = advantage + seq_log_reward = verl_F.masked_mean(reward, response_mask, axis=1) + + # TB loss residual + delta = log_z + avg_logpf - 15 * seq_log_reward - avg_logp_ref + + # important sampling + log_w = verl_F.masked_sum(logpf - logpf_old, response_mask, axis=1) # sum over valid tokens per trajectory + importance_weight = torch.exp(log_w).detach() + clip_importance_weight = torch.clamp(importance_weight, 1 - clip_ratio, 1 + clip_ratio) + + weighted_losses = importance_weight * (delta ** 2) + avg_loss = torch.mean(weighted_losses) + + # Loss statistics + negative_approx_kl = logpf - logf_ref + ratio = torch.exp(negative_approx_kl) + loss_term_dict = { + "actor/logpf": verl_F.masked_mean(logpf, response_mask).detach().item(), + "actor/logp_ref": verl_F.masked_mean(logf_ref, response_mask).detach().item(), + "actor/log_z": log_z.mean().detach().item(), + "actor/log_reward": verl_F.masked_mean(reward, response_mask).detach().item(), + "actor/tb_loss": avg_loss.detach().item(), + } + + return avg_loss, loss_term_dict +``` + +### Step 4: Fix Model Loading + +**File**: `verl/workers/sharding_manager/fsdp_vllm.py` + +[Change line 290-293](https://github.com/Xuekai-Zhu/FlowRL/blob/4b0b3bee0e85258b7be46481f9a46ffe9e6b5508/verl_FlowRL/verl/workers/sharding_manager/fsdp_vllm.py#L290): + +```python +# Skip proj_z parameters when loading to vLLM +loaded_params = model.load_weights(((name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in updated_params.items() + if not name.startswith("proj_z")) + ) +``` diff --git a/ICL/DAPO/verl-recipe/flowrl/flowrl_actor.py b/ICL/DAPO/verl-recipe/flowrl/flowrl_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..62ad4ee626ef47e44b641f01290493253112a452 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/flowrl_actor.py @@ -0,0 +1,486 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input +from verl.utils.device import get_device_id +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import prepare_dynamic_batch +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.workers.actor.dp_actor import DataParallelPPOActor + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class ProjZModule(torch.nn.Module): + """Projection network for estimating log partition function Z in FlowRL.""" + + def __init__(self, hidden_size: int, num_layers: int = 3, dropout: float = 0.1): + super().__init__() + layers = [] + + for i in range(num_layers - 1): + layers.extend( + [ + torch.nn.Linear(hidden_size, hidden_size), + torch.nn.GELU(), + torch.nn.LayerNorm(hidden_size), + torch.nn.Dropout(dropout), + ] + ) + + layers.append(torch.nn.Linear(hidden_size, 1)) + self.net = torch.nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + + +class FlowRLActor(DataParallelPPOActor): + """FlowRL Actor that extends DataParallelPPOActor with partition function estimation.""" + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + # FlowRL hyperparameters (hardcoded as per paper) + self.flowrl_beta_coef = 15.0 # β coefficient for reward scaling in flowrl loss + + def _forward_micro_batch( + self, micro_batch, temperature, calculate_entropy=False, return_log_z=False + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + entropy: # (bs, response_len) + log_probs: # (bs, response_len) + """ + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, "vision_config" + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + output_hidden_states=True if return_log_z else False, # FlowRL: for log_z estimation + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True if return_log_z else False, # FlowRL: for log_z estimation + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) + + # ==== FlowRL: use proj_z to estimate log Z ==== + if return_log_z: + last_hidden = output.hidden_states[-1].squeeze(0) # (total_nnz, hidden size) + if self.use_ulysses_sp: + last_hidden = gather_outputs_and_unpad( + last_hidden, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + full_last_hidden = pad_input( + hidden_states=last_hidden, indices=indices, batch=batch_size, seqlen=seqlen + ) + # extract pormpt hiddenstate for log z + prompts_last_hidden = full_last_hidden[:, : -response_length - 1] + prompt_attention_mask = attention_mask[:, : -response_length - 1] + avg_hidden = verl_F.masked_mean(prompts_last_hidden, prompt_attention_mask.unsqueeze(-1), axis=1) + + log_z = self.actor_module.proj_z(avg_hidden) + + return entropy, log_probs, log_z + else: + return entropy, log_probs + + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + + select_keys = [ + "responses", + "response_mask", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + ] + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + if getattr(self.config, "tis_imp_ratio_cap", 0) > 0: + assert "rollout_log_probs" in data.batch.keys(), ( + "Truncated Importance Sampling (TIS) requires to configure " + "`actor_rollout_ref.rollout.calculate_log_probs=True` " + "and is not currently supported in Server mode (agent loop)." + ) + select_keys.append("rollout_log_probs") + + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + mini_batches = data.split(self.config.ppo_mini_batch_size) + + on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1 + + metrics = {} + for _ in range(self.config.ppo_epochs): + for batch_idx, mini_batch in enumerate(mini_batches): + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) + else: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + + self.actor_optimizer.zero_grad() + + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + micro_batch_metrics = {} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + response_mask = model_inputs["response_mask"] + old_log_prob = model_inputs["old_log_probs"] + # Get rollout log probs if TIS is enabled + tis_enabled = getattr(self.config, "tis_imp_ratio_cap", 0) > 0 + rollout_log_probs = model_inputs["rollout_log_probs"] if tis_enabled else None + advantages = model_inputs["advantages"] + ref_log_prob = model_inputs["ref_log_prob"] + + if self.config.use_dynamic_bsz: + loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size + else: + loss_scale_factor = 1 / self.gradient_accumulation + + # FlowRL: compute log probs and log Z + entropy, log_prob, log_z = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=False, return_log_z=True + ) + + if on_policy: + old_log_prob = log_prob.detach() + else: + old_log_prob = model_inputs["old_log_probs"] + + # loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla + # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg + # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov + # policy_loss_fn = get_policy_loss_fn(loss_mode) + # pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( + # old_log_prob=old_log_prob, + # log_prob=log_prob, + # advantages=advantages, + # response_mask=response_mask, + # loss_agg_mode=loss_agg_mode, + # config=self.config, + # rollout_log_probs=rollout_log_probs, + # ) + # Compute FlowRL trajectory balance loss + policy_loss, flowrl_metrics = self.compute_flowrl( + log_prob=log_prob, + ref_log_prob=ref_log_prob, + old_log_prob=old_log_prob, + log_z=log_z, + reward=advantages, + response_mask=response_mask, + clip_ratio=self.config.clip_ratio, + rollout_log_probs=rollout_log_probs, + ) + + # if entropy_coeff != 0: + # entropy_loss = agg_loss( + # loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode + # ) + # # compute policy loss + # policy_loss = pg_loss - entropy_loss * entropy_coeff + # else: + # policy_loss = pg_loss + + # if self.config.use_kl_loss: + # ref_log_prob = model_inputs["ref_log_prob"] + # # compute kl loss + # kld = kl_penalty( + # logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + # ) + # kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + # policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + # micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor + # micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = policy_loss * loss_scale_factor + else: + loss = policy_loss * loss_scale_factor + + # Use gradient scaler for FP16 training + if self.scaler is not None: + self.scaler.scale(loss).backward() + else: + loss.backward() + + micro_batch_metrics.update(flowrl_metrics) + # micro_batch_metrics.update( + # { + # "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor, + # "actor/pg_clipfrac": pg_clipfrac.detach().item(), + # "actor/ppo_kl": ppo_kl.detach().item(), + # "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + # } + # ) + append_to_dict(metrics, micro_batch_metrics) + + grad_norm = self._optimizer_step() + mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, mini_batch_metrics) + self.actor_optimizer.zero_grad() + return metrics + + def compute_flowrl( + self, + log_prob=None, + ref_log_prob=None, + old_log_prob=None, + log_z=None, + reward=None, + response_mask=None, + clip_ratio=None, + rollout_log_probs=None, + ): + # squeeze log_z to (B,) + log_z = log_z.squeeze(-1) + + # Average token log-probs & rewards over valid positions + avg_log_prob = verl_F.masked_mean(log_prob, response_mask, axis=1) + avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, response_mask, axis=1) + seq_log_reward = verl_F.masked_mean(reward, response_mask, axis=1) + + # FlowRL residual: logZ + logpf - β*R - logpref + delta = log_z + avg_log_prob - self.flowrl_beta_coef * seq_log_reward - avg_ref_log_prob + + # Importance ratio from current vs old policy (product of token ratios) + log_w = verl_F.masked_sum(log_prob - old_log_prob, response_mask, axis=1) + imp_w_raw = torch.exp(log_w).detach() + imp_w = torch.clamp(imp_w_raw, max=10) + + # Loss: weighted squared residual with importance weights + weighted_losses = imp_w * (delta**2) + avg_loss = torch.mean(weighted_losses) + + # PPO KL: negative_approx_kl = log_prob - old_log_prob + negative_approx_kl = log_prob - old_log_prob + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # Reference KL: approx_kl_ref = log_prob - ref_log_prob + approx_kl_ref = log_prob - ref_log_prob + ref_kl = verl_F.masked_mean(-approx_kl_ref, response_mask) + + # Metrics + loss_term_dict = { + "actor/log_prob": verl_F.masked_mean(log_prob, response_mask).detach().item(), + "actor/old_log_prob": verl_F.masked_mean(old_log_prob, response_mask).detach().item(), + "actor/ref_log_prob": verl_F.masked_mean(ref_log_prob, response_mask).detach().item(), + "actor/log_z": log_z.mean().detach().item(), + "actor/log_reward": verl_F.masked_mean(reward, response_mask).detach().item(), + "actor/final_loss": avg_loss.detach().item(), + "actor/importance_weight": imp_w.mean().detach().item(), + "actor/ppo_kl": ppo_kl.detach().item(), # PPO-style KL (current vs old policy) + "actor/ref_kl": ref_kl.detach().item(), # KL with reference policy + } + + return avg_loss, loss_term_dict diff --git a/ICL/DAPO/verl-recipe/flowrl/flowrl_ray_trainer.py b/ICL/DAPO/verl-recipe/flowrl/flowrl_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e30d1e86f8859bbbc0112a07408ad74959d90cc7 --- /dev/null +++ b/ICL/DAPO/verl-recipe/flowrl/flowrl_ray_trainer.py @@ -0,0 +1,29 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FlowRL Ray Trainer that extends RayPPOTrainer with FlowRL-specific components. +""" + +from verl.trainer.ppo.ray_trainer import RayPPOTrainer + + +class RayFlowRLTrainer(RayPPOTrainer): + """ + FlowRL trainer that uses the FlowRL advantage estimator. + The main difference is in the advantage estimation which is registered + as 'flowrl' in flowrl_adv_estimator.py + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/ICL/DAPO/verl-recipe/genrm_remote/README.md b/ICL/DAPO/verl-recipe/genrm_remote/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1a800fd882c60d20d1211828362d9f2acccec579 --- /dev/null +++ b/ICL/DAPO/verl-recipe/genrm_remote/README.md @@ -0,0 +1,39 @@ +# Generative Reward Model + +## Scripts + +### Step 1: Launch a vLLM Server (Optional) + +Deploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external api service. + +```bash +vllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo +``` + +### Step 2: Perform RL using GenRM + +```bash +bash recipe/api-genrm/run_genrm_remote.sh +``` + +The implementation works by passing a customized reward function (see `reward_function.py`) + +For convenience, we run both the RL training and server on the same machine. To use an external server, configure the `BASE_URL` and `API_KEY` in `reward_function.py` first. + +## Advanced: Customizing Your GenRM + +You can use sglang server with data parallel for faster inference: + +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4 +``` + +Note that you should modify the `BASE_URL` in `reward_function.py` to match your SGLang Server address. + +You can also create your own customized GenRM by implementing a custom reward function. Here are some tips for customizing your own GenRM based on `reward_function.py`: + +- Design appropriate prompts for your GenRM +- Convert GenRM responses into RL rewards +- ... + +Since these aspects are highly flexible, we only provide a demo implementation. The actual design and implementation of GenRM is left to the user's discretion. diff --git a/ICL/DAPO/verl-recipe/genrm_remote/reward_function.py b/ICL/DAPO/verl-recipe/genrm_remote/reward_function.py new file mode 100644 index 0000000000000000000000000000000000000000..09fe0881781880677e03f225c47138c1169bd5aa --- /dev/null +++ b/ICL/DAPO/verl-recipe/genrm_remote/reward_function.py @@ -0,0 +1,109 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from time import sleep + +import aiohttp + +from verl.utils.reward_score.math_reward import last_boxed_only_string, remove_boxed + +BASE_URL = "http://localhost:30000" +API_KEY = "EMPTY" +MAX_RETRIES = 3 +BASE_DELAY = 2 +MAX_WORKERS = 32 +MODEL_NAME = "genrm-demo" +GENRM_PROMPT_TEMPLATE = """ +The following is a math problem and an AI solution: + +[Math Problem] + +{problem} + +[AI Solution] + +{solution} + +Your task is to review and critique the solution step by step, and output whether the AI solution is correct. + +Please put your final answer (i.e., 'True' or 'False') in \\boxed{{}}. +""".strip() + + +async def post_request(payload, endpoint): + url = f"{BASE_URL}/{endpoint}" + try: + timeout = aiohttp.ClientTimeout(total=None) + session = aiohttp.ClientSession(timeout=timeout) + async with session.post(url, json=payload) as resp: + output = await resp.text() + output = json.loads(output) + return output + except Exception as e: + raise e + finally: + await session.close() + + +async def get_response(problem, solution_str, ground_truth): + prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str) + messages = [{"role": "user", "content": prompt}] + for attempt in range(MAX_RETRIES): + try: + data = {"model": MODEL_NAME, "messages": messages} + output = await post_request(data, "v1/chat/completions") + response = output["choices"][0]["message"]["content"] + return response + except Exception as e: + if attempt < MAX_RETRIES - 1: + print("Exception: ", repr(e)) + delay = BASE_DELAY * (2**attempt) + print(f"Retrying in {delay} seconds...") + sleep(delay) + else: + print(f"Failed after {MAX_RETRIES} attempts. Error: {e}") + + raise ConnectionRefusedError(f"Failed to run the model for {prompt}!") + + +def compute_reward(response): + reward_score = 0.0 + try: + boxed_result = last_boxed_only_string(response) + if boxed_result is not None: + result = remove_boxed(boxed_result) + reward_score = float(result == "True") + except Exception as e: + print(e) + return reward_score + + +async def compute_score(data_source, solution_str, ground_truth, extra_info): + split = extra_info["split"] + from verl.utils.reward_score import default_compute_score + + func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info) + + if split == "test": + return func_rm_score + else: + problem = extra_info["question"] + response = await get_response(problem, solution_str, ground_truth) + if response is not None: + reward_score = compute_reward(response) + else: + reward_score = 0.0 + + return reward_score diff --git a/ICL/DAPO/verl-recipe/genrm_remote/run_genrm_remote.sh b/ICL/DAPO/verl-recipe/genrm_remote/run_genrm_remote.sh new file mode 100644 index 0000000000000000000000000000000000000000..59a49a5990af63fc42a990033377506ec8d7ecb2 --- /dev/null +++ b/ICL/DAPO/verl-recipe/genrm_remote/run_genrm_remote.sh @@ -0,0 +1,45 @@ +# vllm server +# CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve verl-team/GenRM-CI-Test-1.5B --served_model_name genrm-demo + +# sglang server +# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4 + +set -x + +CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=${HOME}/data/gsm8k/train.parquet \ + data.val_files=${HOME}/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=8 \ + algorithm.use_kl_in_reward=False \ + reward_model.reward_manager=naive \ + custom_reward_function.path=recipe/genrm_remote/reward_function.py \ + custom_reward_function.name=compute_score \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='verl_func_rm_example_gsm8k' \ + trainer.experiment_name='qwen2_5_3b_gen_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.val_before_train=True \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=5 \ + trainer.total_epochs=10 \ + trainer.resume_mode='disable' diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/README.md b/ICL/DAPO/verl-recipe/gkd/megatron/README.md new file mode 100644 index 0000000000000000000000000000000000000000..55b8d392206c94968d6ade5a29ce82eb8d267c8f --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/README.md @@ -0,0 +1,242 @@ +# Recipe: Async On-Policy Knowledge Distillation Trainer + +**Authors:** Brilliant Hanabi, furunding + +**Last updated:** 2025-11-08 + +## 1. Background + +On-policy knowledge distillation (KD) trains a student policy to imitate a stronger teacher using samples drawn from the student's current policy. For each on-policy rollout the teacher returns soft, top-k token distributions and the student is optimized with a token-wise sparse KL objective that focuses learning on the teacher's high-probability modes. Because training examples come from the student's own state distribution, KD reduces distributional mismatch relative to off-policy distillation or supervised fine-tuning (SFT), improving stability and sample efficiency. Compared with reinforcement learning, KD avoids high-variance reward-based optimization and complex reward design by providing dense, informative per-token targets, which typically yields faster convergence and simpler scaling. Recent empirical and implementation-focused writeups (e.g., [ThinkingMachines' blog on on-policy distillation](https://thinkingmachines.ai/blog/on-policy-distillation/)) also demonstrate that on-policy distillation can deliver high-quality behavior with substantially lower compute and data requirements than many alternative approaches. + +Built on verl’s Ray-based single-controller components, we initially assembled a strictly on-policy KD pipeline where rollout generation, teacher knowledge acquisition, and policy optimization ran in lockstep. In practice, this synchronous design proved highly inefficient: the three stages had to wait for one another, creating pipeline bubbles and underutilized GPUs. To address this, we extend the asynchronous schedulers introduced by the One-Step-Off Policy pipeline to overlap these phases. This overlap preserves the same distillation objective while trading some strict on-policy guarantees for substantial gains in end-to-end throughput and hardware utilization. + +## 2. Distillation Overview and Objective + +This recipe centers on on-policy knowledge distillation: the student policy learns from a stronger teacher on samples generated by the current policy (on-policy). For each input prompt, the student (actor) generates responses; the teacher provides top-k token distributions, and the student is trained to match them token-wise. + +Core components: + +1. Teacher signal: top-k log-probabilities and token indices per valid token position. +2. Student objective: sparse, token-level KL divergence between student logits and teacher top-k distribution. + +Objective: encourage student probabilities $Q$ to cover teacher modes $P$ using token-wise $\mathrm{KL}(P\,\|\,Q)$ computed on the teacher's top-k support. + +## 3. Efficient System Design + +### 3.1 Schedulers (One-Step / Two-Step Off-Policy) + +The native (serial) on-policy distillation process is shown in the figure below. + +![Zero-Step-Off Scheduler](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/zero-step-off-distill.png) + +This recipe supports optional schedulers that overlap generation, teacher querying, and updates to improve throughput without changing the distillation objective. + +#### 3.1.1 One-Step-Off-Policy + +![One-Step-Off Scheduler](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/one-step-off-distill.png) + +- Warm-up: 2 steps. +- Overlap pattern: rollout while actor update; weight sync while teacher retrieving. +- Timing keys: `sync_rollout_weights`, `wait_prev_gen`, `wait_prev_teacher`. + +#### 3.1.2 Two-Step-Off-Policy + +![Two-Step-Off Scheduler](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/docs/two-step-off-distill.png) + +- Warm-up: 3 steps. +- Overlap pattern: rollout, actor update while teacher retrieving; interleave weight sync. +- Timing keys: `sync_rollout_weights`, `max(wait_prev_gen, wait_prev_prev_teacher)`. + +Tip: Use `two_step_off` when teacher takes much more time than sync; `one_step_off` for simpler overlapping. + +Practical details: + +- Inputs per batch: `teacher_topk_logps`, `teacher_topk_indices`, `attention_mask` (to select valid token positions). +- Loss injection: last pipeline stage computes KL via a logits processor; earlier stages remain unchanged. +- Optional dynamic micro-batching groups sequences by density to reduce padding overhead. + +The pipeline: + +1. Actor parameters are synchronized to a rollout worker group (nccl broadcast) with a little bit latency. +2. Rollout workers (vLLM-backed) generate sequences asynchronously (`async_generate_sequences`). +3. Teacher client service (ZeroMQ based) returns top-k log-probabilities + token indices for each sequence (batched micro-requests), enabling KL-based guidance. +4. Megatron actor performs a KL divergence computation between student logits and teacher top-k distributions (custom TP-aware kernel in `megatron_kl_loss.py`). +5. Scheduling strategies (`one_step_off_scheduler`, `two_step_off_scheduler`) can overlap phases (optional for throughput): + +### 3.2 Weights sync between actor and rollout + +We initially followed the weight synchronization path from the One-Step-Off-Policy recipe (Ray collective broadcast across all actor and rollout ranks, plus Megatron-side allgather of parameter shards). In practice this became the dominant bottleneck, so we made three changes: + +1. Batch-and-bulk load on the rollout side: instead of streaming tensors one-by-one (in one-step-off-policy recipe), we stage a bundle of parameter tensors and issue a single batched load into the rollout engine. In our setup this reduced the weight-loading time by roughly 3×. +2. Batch-and-bulk broadcast between the actor and rollout: instead of streaming tensors one-by-one (in one-step-off-policy recipe), we stage a bundle of parameter tensors and issue a single batched broadcast between the actor and rollout workers. +3. Replace allgather with gather-to-root in Megatron: parameter shards are gathered to actor rank 0 (rather than allgathered to everyone), and that root then serves as the single source for broadcasting to rollout ranks. On top of the previous change, 2 and 3 changes delivered an additional ~4× speedup in the synchronization phase. + +## 4. High-Level Data & Control Flow + +``` +Driver (TaskRunner) + ├─ Initialize Ray, tokenizer, datasets, worker groups + ├─ Build ResourcePoolManager (actor vs rollout GPU layouts) + ├─ Trainer.fit() + ├─ init_workers(): build actor + rollout groups, broadcast weight metadata, create nccl collective group + ├─ continuous_iterator(): epochs → batches + ├─ scheduler (see Section 6) + • _async_gen_next_batch(): optional weight sync + non-blocking rollout + • _async_get_teacher_knowledge(): submit teacher requests, store future + ├─ For each step: + • Sync rollout weights + • Retrieve (batch, gen_output, teacher_output) from futures + • Merge gen + teacher outputs → DataProto + • Compute metrics (response length stats, timing, throughput) + • Update actor (forward_backward_batch + KL loss + optimizer step) + • (Optional) save checkpoint +``` + +> Note: Schedulers are optional and explained later; the distillation objective is independent of how phases are overlapped. + +## 5. Key Components + +### 5.1 `OnPolicyDistillTrainer` (`ray_trainer.py`) +- Creates `GenerationBatchFuture` objects holding rollout and (later) teacher futures. +- Adds scheduling + teacher integration + modified metric emission (KL, timing, MFU). + +### 5.2 Actor Worker (Megatron) +- `OnPolicyDistillActor.update_policy()` orchestrates micro-batch forward/backward. +- KL Loss injection via `logits_processor` during forward on pipeline last stage. + +### 5.3 Rollout Worker (vLLM / SGLang) +- Pure inference mode (`init_model` builds model; no optimizer). +- `async_generate_sequences` returns a Ray future for overlapping. + +### 5.4 Teacher Service (`teacher/`) +- Proxy + worker architecture (ZMQ REQ/REP) for batched top-k retrieval. +- `TeacherClient.submit()` returns a `Future`; aggregator composes micro-batches. +- Configurable temperature, max tokens, only-response mode. + +### 5.5 KL Loss (`megatron_kl_loss.py`) +- Performs normalization & stable per-token probability construction across TP shards. +- Gradient is (student_probs - teacher_sparse_probs) scaled by upstream grad. + +## 6. Configuration Highlights (`on_policy_distill_trainer.yaml`) + +| Section | Purpose | Notable Keys | +|---------|---------|-------------| +| actor_rollout_ref.teacher | Teacher server | server_ip, server_port, n_server_workers | +| trainer | Global training control | total_epochs, save_freq, scheduler (one_step_off | two_step_off), n_gpus_per_node, nnodes | +| rollout | Resource split for rollout | n_gpus_per_node, nnodes | + +**Remember to set `trainer.n_gpus_per_node`, `trainer.nnodes`, `rollout.n_gpus_per_node` and `rollout.nnodes` to allocate GPU resources.** + +### Dynamic Batch Size + +Enable by: + +``` +actor_rollout_ref.actor.use_dynamic_bsz=True +actor_rollout_ref.actor.max_token_len=6000 # cap post-group token length +``` + +Improves utilization under variable sequence lengths. + +### Resource Guidelines + +- Actor pool: `trainer.nnodes * trainer.n_gpus_per_node` GPUs. +- Rollout pool: `rollout.nnodes * rollout.n_gpus_per_node` GPUs. +- Ensure teacher server capacity ≈ `n_server_workers` to avoid stalls (monitor `wait_prev_teacher`). + +## 7. Usage Examples + +### 7.1 Launch Teacher Server + +Before training process, you should have a teacher server to provide logp information. + +We provide a toy teacher server example with vLLM. It needs `telnet` to check proxy status, and `python` command to run. So if you have not installed `telnet`, you can just delete these code in `start_server.sh`. And some OS use `python3` rather than `python`, so you also need to modify it. Also you can change the port of teacher if you meet port conflict. + +There are 3 arguments can be set for vllm backend `--tp-size`, `--n-logprobs` and `--ckpt-path` in `start_server.sh` / `worker.py`. You should set before you start server. + +We also provide a toy multi-node teacher server. You can start the main node using `start_server.sh` and start the slave nodes using `join_server.sh`. Still remember to set args in `join_server.sh`, especially the `$PROXY_IP` and `$PROXY_BACKEND_PORT` of main node. + +When training, student will automatically use the teacher's topk (n-logprobs) to set its own topk argument at line 83 of `recipe/gkd/megatron_kl_loss.py`, so you don't need to set student's topk argument. + +```bash +cd recipe/gkd/teacher +bash start_server.sh +# Exports ports and launches proxy + worker (default vLLM backend) +``` + +Verify with: + +```bash +telnet localhost 15555 +``` + +### 7.2 Minimal Local (Megatron + vLLM) Run + +```bash +python3 -m recipe.gkd.main_gkd \ + --config-path=recipe/gkd/config \ + --config-name=on_policy_distill_trainer \ + actor_rollout_ref.model.path=/path/to/MODEL \ + data.train_files=/path/to/train.parquet \ + trainer.total_epochs=2 \ + trainer.n_gpus_per_node=4 rollout.n_gpus_per_node=2 \ + actor_rollout_ref.teacher.server_ip=127.0.0.1 \ + actor_rollout_ref.teacher.server_port=15555 \ + trainer.scheduler=one_step_off +``` + +(Requires a running teacher server). + +### 7.3 Ray Job Submission (Distilled 16B Example) + +See `run_moonlight_dsv3_training.sh` for a full script including: + +- Dist ckpt path setup (`dist_checkpointing_path`) +- Expert parallel sizing (EP / ETP) +- Dynamic batch sizing +- Two-step-off scheduling for deeper overlap. + +Submit (after adjusting paths): + +```bash +bash recipe/gkd/run_moonlight_dsv3_training.sh +``` + +## 8. Metrics & Monitoring + +Emitted metrics include (prefixes may vary): + +- Timing: `timing/wait_prev_gen`, `timing/sync_rollout_weights`, `timing/get_teacher_knowledge`, `timing/update_actor`. +- Sequence stats: `response_seq_len/*` (avg, max, min, counts). +- Performance: `perf/mfu/actor`, `perf/max_memory_allocated_gb`, `perf/cpu_memory_used_gb`. +- Distillation: `actor/kl_loss`, `actor/grad_norm`, `actor/lr`. + +Interpretation Tips: + +- High `wait_prev_teacher` → scale `n_server_workers` and allocate more teacher GPUs or reduce per-request batch size, or just use `two_step_off`. +- High `wait_prev_gen` with uniform lengths → allocate more rollout GPUs. +- High `sync_rollout_weights` → check NCCL env / network congestion and try to modify `actor_rollout_ref.rollout.update_weights_bucket_megabytes`. + +## 9. Extensibility Notes + +- Add new schedulers by following interface returning `(epoch, batch, gen_output, teacher_output, timing_dict)`. +- Integrate different distillation signals (e.g., hidden states, intermediate reasoning tokens) by extending `teacher_utils.get_teacher_knowledge` and modifying `logits_processor`. + +## 10. Functional Support Summary + +| Category | Supported | +|----------|-----------| +| Train engine | Megatron | +| Rollout engine | vLLM | +| Distillation signal | Teacher top-k logprobs & indices | +| Scheduling | one_step_off, two_step_off | + +## 11. Quick Checklist Before Running + +- Teacher server reachable (`telnet `). +- `actor_rollout_ref.model.path` contains the correct Megatron/HF config artifacts. +- `train_files` points to a parquet dataset compatible with this recipe's dataset loader. +- NCCL environment vars set (see `config/runtime_env.yaml`). + +--- +Feel free to open issues or PRs to extend scheduler variants, add new distillation objectives, or broaden engine support, and more improvement. diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/config/on_policy_distill_trainer.yaml b/ICL/DAPO/verl-recipe/gkd/megatron/config/on_policy_distill_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5227785b1e47622dfb66c3a7aa5303604c98cbb7 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/config/on_policy_distill_trainer.yaml @@ -0,0 +1,397 @@ +# specify the default per-component configs +# defaults: + +# # @.: +# # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml +# - actor@actor_rollout_ref.actor: megatron_actor +# # load the reference default config, then apply the fields in the current yaml +# - _self_ + +data: + tokenizer: null + train_files: /path/to/train.parquet + val_files: null + prompt_key: question + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + return_full_prompt: False + shuffle: True + filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. + filter_overlong_prompts_workers: 1 + truncation: error + trust_remote_code: False # main_ppo will check this config to determine whether to use remote code for tokenizer + custom_cls: + path: null + name: null + sampler: + class_path: null + class_name: null + dataloader_num_workers: 8 + return_multi_modal_inputs: True + +actor_rollout_ref: + hybrid_engine: False + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron + model: + path: /path/to/MODEL + custom_chat_template: null + external_lib: null + override_config: + model_config: {"num_nextn_predict_layers": 0} + moe_config: + freeze_moe_router: False + enable_gradient_checkpointing: False + use_remove_padding: False + # gradient_checkpointing_kwargs: + # ## Activation Checkpointing + # activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective' + # # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk + # # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + # activations_checkpoint_granularity: null # 'selective' or 'full' + # # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention + # activations_checkpoint_num_layers: null # not used with 'selective' + trust_remote_code: False + actor: + # Whether to automatically adjust batch size at runtime + strategy: megatron + micro_batch_size: 2 + megatron: + param_offload: False + grad_offload: False + optimizer_offload: False + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: True + use_distributed_optimizer: True + use_dist_checkpointing: False + dist_checkpointing_path: null + seed: 42 + # additional transformer config like: num_layers_in_first(/last)_pipeline_stage + override_transformer_config: {} + use_mbridge: False + optim: + # Learning rate + lr: 1e-6 + # Warmup steps ratio (used if lr_warmup_steps is negative) + lr_warmup_steps_ratio: 0.0 + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + # Weight decay + weight_decay: 0.01 + optimizer: adam + clip_grad: 1.0 + # initial learning rate for warmup, default to 0.0 + lr_warmup_init: 0.0 + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: null + lr_decay_steps: null + # select from constant/linear/cosine/inverse_square_root + lr_decay_style: constant + # minimum learning rate, default to 0.0 + min_lr: 0.0 + # select from constant/linear/cosine + weight_decay_incr_style: constant + # select from constant/exponential/cosine + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + # use checkpoint optimizer parameter scheduler + use_checkpoint_opt_param_scheduler: False + data_loader_seed: null + load_weight: True + checkpoint: + async_save: False # save checkpoint asynchronously + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} + use_dynamic_bsz: False + max_token_len: 1024 + use_torch_compile: False + shuffle: False + # profile the actor model in `update_policy` + profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Actor + enable: False + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level0" + + # Whether to automatically parse the data. + analysis: True + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # start profile mini-batch in training + # NOTICE: different with global steps config which refers to iteration + # This field only related with mini-batch + step_start: 0 + + # stop profile mini-batch in training + step_end: null + + + rollout: + _target_: verl.workers.config.RolloutConfig + name: vllm + mode: sync # sync: LLM, async: AsyncLLM + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + # max_batch_size: 8 + prompt_length: ${data.max_prompt_length} # for xperf_gpt + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_megatron + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + disable_log_stats: True + enable_chunked_prefill: False # could get higher throughput + # for hf rollout + do_sample: True + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + # number of responses (i.e. num sample times) + engine_kwargs: # inference engine parameters + vllm: + swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB + disable_mm_preprocessor_cache: False # whether to disable the preprocessor cache for multimodel models. + sglang: + attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla + val_kwargs: + # sampling parameters for validation + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1.0 + temperature: 0 + n: 1 + do_sample: False # default eager for validation + + # Multi-turn interaction config for tools or chat. + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well + enable: False + # null for no limit (default max_length // 3) + max_assistant_turns: null + # null for no tool + tool_config_path: null + # null for no limit (default max_length // 3) + max_user_turns: null + # max parallel call for tools in single turn + max_parallel_calls: 1 + # max length of tool response + max_tool_response_length: 256 + # truncate side of tool response: left, middle, right + tool_response_truncate_side: middle + # null for no interaction + interaction_config_path: null + # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. + # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, + # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. + use_inference_chat_template: False + # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. + # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. + # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. + # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: + # Qwen/QwQ-32B, Qwen/Qwen3-xxB + # - disable: disable tokenization sanity check + # - strict: enable strict tokenization sanity check (default) + # - ignore_strippable: ignore strippable tokens when checking tokenization sanity + tokenization_sanity_check_mode: strict + # Format of the multi-turn interaction. Options: hermes, llama3_json, ... + format: hermes + + # [Experimental] agent loop based rollout configs + agent: + # Number of agent loop workers + num_workers: 8 + custom_async_server: + path: null + name: null + update_weights_bucket_megabytes: 512 + # support logging rollout prob for debugging purpose + calculate_log_probs: False + # # Nsight system profiler configs + # profiler: + # # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + # _target_: verl.utils.profiler.ProfilerConfig + # tool: ${oc.select:global_profiler.tool,null} + # all_ranks: False + # ranks: [] + + teacher: + server_ip: localhost + server_port: 15555 + overlap_rollout: False + n_server_workers: 1 + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + profile_steps: null # [1,2,5] or [] or null + project_name: verl_examples + experiment_name: gsm8k + logger: ['console'] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 4 + save_freq: -1 + esi_redundant_time: 0 + + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + del_local_ckpt_after_load: False + val_before_train: True + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + device: cuda + scheduler: one_step_off + +rollout: + nnodes: 1 + n_gpus_per_node: 4 + +ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null + +# profiler configs +global_profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # Profiling tool: choose between nsys, npu, torch + tool: null + + # profile steps + steps: null + + # Whether to combine continuous steps into one database. + ## If True, worker.profiler.discrete must be False, [1,2] in one, [5] in another. + ## If False, [1] in one, [2] in another, [5] in another. + profile_continuous_steps: False + + # Path to save profiling contents + save_path: "outputs/profile" + + # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config + global_tool_config: + + # nsys config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # Whether to profile all ranks. + all_ranks: False + ranks: [] + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/config/runtime_env.yaml b/ICL/DAPO/verl-recipe/gkd/megatron/config/runtime_env.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3311466fab79c6e527eeb14d6d4de3bfdb105b5 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/config/runtime_env.yaml @@ -0,0 +1,23 @@ +working_dir: ./ +excludes: ["/.git/"] +env_vars: + TORCH_NCCL_AVOID_RECORD_STREAMS: "1" + CUDA_LAUNCH_BLOCKING: "0" + NVTE_DEBUG: "1" + NVTE_DEBUG_LEVEL: "2" + NVTE_FLASH_ATTN: "1" + NVTE_FUSED_ATTN: "0" + NVTE_UNFUSED_ATTN: "0" + RAY_DEBUG: "legacy" + NCCL_DEBUG: "WARN" + # NCCL_IB_DISABLE: "1" + # NCCL_P2P_DISABLE: "1" + NCCL_DEBUG_FILE: "/workspace/nccl_debug.log" + # NCCL_SOCKET_IFNAME: "xgbe0" + VLLM_USE_V1: "1" + VERL_VLLM_DISTRIBUTED_BACKEND: "ray" + + + # NVTE_UNFUSED_ATTN: "1" + # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: + # VLLM_ATTENTION_BACKEND: "XFORMERS" \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/main_gkd.py b/ICL/DAPO/verl-recipe/gkd/megatron/main_gkd.py new file mode 100644 index 0000000000000000000000000000000000000000..b94e13e08b31414d35a5f98b7750c36963f9125f --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/main_gkd.py @@ -0,0 +1,227 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Individual Contributor: Brilliant Hanabi, furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf +from recipe.gkd.ray_trainer import OnPolicyDistillTrainer + +RAY_RUNTIME_ENV = { + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "false", + "CUDA_DEVICE_MAX_CONNECTIONS": "1", + # To prevent hanging or crash during synchronization of weights between actor and rollout + # in disaggregated mode. See: + # https://docs.vllm.ai/en/latest/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues + # https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445 + "NCCL_CUMEM_ENABLE": "0", + }, +} + + +@hydra.main(config_path="config", config_name="on_policy_distill_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_on_policy_distill(config) + + +# Define a function to run the PPO-like training process +def run_on_policy_distill(config) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + """ + # Check if Ray is not initialized + + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + # PPO_RAY_RUNTIME_ENV["env_vars"]["NCCL_DEBUG"] = "INFO" + ray.init( + runtime_env=RAY_RUNTIME_ENV, + num_cpus=config.ray_init.num_cpus, + ) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + config.global_profiler.tool == "nsys" + and OmegaConf.select(config.global_profiler, "steps") is not None + and len(OmegaConf.select(config.global_profiler, "steps")) > 0 + ): + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_init.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + """ + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + + pprint(OmegaConf.to_container(config, resolve=True)) + + OmegaConf.resolve(config) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + # Version validation for vllm. + if config.actor_rollout_ref.rollout.name in ["vllm"]: + from verl.utils.vllm import is_version_ge + + if config.actor_rollout_ref.model.get("lora_rank", 0) > 0: + if not is_version_ge(pkg="vllm", minver="0.7.3"): + raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3") + + # Megatron-only workers, split into rollout and actor + if config.actor_rollout_ref.actor.strategy == "megatron": + from verl.single_controller.ray import RayWorkerGroup + + from .megatron_workers import ( + MegatronOnPolicyDistillActorWorker, + MegatronOnPolicyDistillRolloutWorker, + ) + + rollout_cls = MegatronOnPolicyDistillRolloutWorker + actor_cls = MegatronOnPolicyDistillActorWorker + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + # Worker mapping and resource pools + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + # Map roles to their corresponding remote worker classes. + role_worker_mapping = { + Role.Rollout: ray.remote(rollout_cls), + Role.Actor: ray.remote(actor_cls), + } + + # Define the resource pool specification. + # Map roles to the resource pool. + assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" + assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" + assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" + assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" + + actor_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes + rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes + + resource_pool_spec = { + "rollout_pool": rollout_pool, + "actor_pool": actor_pool, + } + mapping = { + Role.Rollout: "rollout_pool", + Role.Actor: "actor_pool", + } + print(f"resource_pool_spec: {resource_pool_spec}") + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + from verl.trainer.main_ppo import create_rl_sampler + from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + + # Create training and validation datasets. + train_dataset = RLHFDataset(config.data.train_files, tokenizer, config.data, None) + + if config.data.val_files: + val_dataset = RLHFDataset(config.data.val_files, tokenizer, config.data, None) + else: + val_dataset = None + + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = OnPolicyDistillTrainer( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + device_name=config.trainer.device, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + # Start the training process. + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/megatron_distill_losses.py b/ICL/DAPO/verl-recipe/gkd/megatron/megatron_distill_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..d361922047c4eb7911355b3342f69f2abcec5e28 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/megatron_distill_losses.py @@ -0,0 +1,681 @@ +# megatron_distill_losses.py +# A unified file that provides 4 selectable vocab-parallel distillation losses: +# 1) KL : KL(P_topk || Q_full) (teacher top-k truncated forward KL) +# 2) RKL : KL(Q_hat_topk || P_hat_topk) (pure reverse KL on renormalized top-k) +# 3) KL_RKL : (1-r)*KL + r*RKL +# 4) JSD : JSD_beta(P_topk, Q_full) with analytic rest term for Q||M outside top-k +# +# Usage: +# op = build_vocab_parallel_distill_loss(cfg).cuda() +# loss_per_token = op(vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices) + +import math +from typing import Any, Optional + +import torch +from megatron.core.fusions.fused_cross_entropy import calculate_logits_max +from megatron.core.parallel_state import ( + get_data_parallel_rank, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel.utils import VocabUtility + + +# ----------------------------- +# utils +# ----------------------------- +def _clamp01_open(x: float, eps: float = 1e-6) -> float: + # clamp into (0, 1) open interval for logs + if x < eps: + return eps + if x > 1.0 - eps: + return 1.0 - eps + return x + + +def mylog(message: str, filename: str = "distill_loss.log"): + # optional debug + with open(filename, "a") as f: + f.write(f"({get_data_parallel_rank()}, {get_tensor_model_parallel_rank()}): {message}\n") + + +# ============================================================ +# 1) Forward KL (teacher top-k truncated): KL(P_topk || Q_full) +# ============================================================ +class _VocabParallelKLDivergence(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices): + eps = 1e-20 + + # Student Q_full = softmax(logits) (vocab-parallel) + vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits) + partition_vocab_size = vocab_parallel_logits.size(-1) + + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + + vocab_parallel_logits -= logits_max.unsqueeze(dim=-1) + vocab_parallel_logits.exp_() + exp_logits = vocab_parallel_logits + sum_exp_logits = exp_logits.sum(dim=-1) + + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + Q_full = exp_logits + Q_full.div_(sum_exp_logits.unsqueeze(-1)) # [*, V_part] + + # Local vocab range and map global top-k -> local indices + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, rank, world_size + ) + + topk_in_vocab = (target_topk_indices >= vocab_start) & (target_topk_indices < vocab_end) + + topk_idx_local = (target_topk_indices - vocab_start).clone() + topk_idx_local[~topk_in_vocab] = 0 # placeholder + + # Teacher P_topk (local pieces only) + P_topk_part = torch.exp(target_topk_logps).clone() + P_topk_part[~topk_in_vocab] = 0.0 + + logP_topk_part = target_topk_logps.clone() + logP_topk_part[~topk_in_vocab] = 0.0 + + # Gather student's Q on teacher top-k indices + origin_shape = target_topk_indices.shape + topk = target_topk_indices.size(-1) + + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view(origin_shape).clone() + Q_topk[~topk_in_vocab] = 0.0 + + logQ_topk = torch.log(Q_topk + eps) + logQ_topk[~topk_in_vocab] = 0.0 + + # KL(P_topk || Q_full) ≈ sum_k P_k (logP_k - logQ_k) + per_token_kl_local = torch.sum(P_topk_part * (logP_topk_part - logQ_topk), dim=-1) # [*] + per_token_kl = per_token_kl_local.clone() + torch.distributed.all_reduce( + per_token_kl, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + ctx.save_for_backward(Q_full, P_topk_part, topk_idx_local) + return per_token_kl + + @staticmethod + def backward(ctx, grad_output): + Q_full, P_topk_part, topk_idx_local = ctx.saved_tensors + partition_vocab_size = Q_full.size(-1) + topk = topk_idx_local.size(-1) + + # d/dz KL(P||Q) = Q - P_sparse(topk) + grad_input = Q_full.clone() + grad_2d = grad_input.view(-1, partition_vocab_size) + row = torch.arange(grad_2d.size(0), device=grad_2d.device).unsqueeze(-1) + idx_2d = topk_idx_local.view(-1, topk) + grad_2d[row, idx_2d] -= P_topk_part.view(-1, topk) + + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + return grad_input, None, None + + +def vocab_parallel_kl_divergence(vocab_parallel_logits, target_topk_logps, target_topk_indices): + return _VocabParallelKLDivergence.apply(vocab_parallel_logits, target_topk_logps, target_topk_indices) + + +# ============================================================ +# 2) Pure Reverse KL on top-k (renormalized): KL(Q_hat || P_hat) +# ============================================================ +class _VocabParallelRKLDivergence(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices): + eps = 1e-20 + + # Student Q_full + vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits) + partition_vocab_size = vocab_parallel_logits.size(-1) + + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + + vocab_parallel_logits -= logits_max.unsqueeze(dim=-1) + vocab_parallel_logits.exp_() + exp_logits = vocab_parallel_logits + sum_exp_logits = exp_logits.sum(dim=-1) + + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + Q_full = exp_logits + Q_full.div_(sum_exp_logits.unsqueeze(-1)) # [*, V_part] + + # Local vocab range + local indices + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, rank, world_size + ) + + topk_in_vocab = (target_topk_indices >= vocab_start) & (target_topk_indices < vocab_end) + + topk_idx_local = (target_topk_indices - vocab_start).clone() + topk_idx_local[~topk_in_vocab] = 0 # placeholder + + # Teacher P_topk (local pieces only) + P_topk_part = torch.exp(target_topk_logps).clone() + P_topk_part[~topk_in_vocab] = 0.0 + + # Gather Q_topk + origin_shape = target_topk_indices.shape + topk = target_topk_indices.size(-1) + + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view(origin_shape).clone() + Q_topk[~topk_in_vocab] = 0.0 + + # Global sums for renorm + P_sum_local = P_topk_part.sum(dim=-1) + P_sum = P_sum_local.clone() + torch.distributed.all_reduce(P_sum, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + Q_sum_local = Q_topk.sum(dim=-1) + Q_sum = Q_sum_local.clone() + torch.distributed.all_reduce(Q_sum, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + Q_hat = Q_topk / (Q_sum.unsqueeze(-1) + eps) + P_hat = P_topk_part / (P_sum.unsqueeze(-1) + eps) + + logQ_hat = torch.log(Q_hat + eps) + logP_hat = torch.log(P_hat + eps) + + per_token_rkl_local = torch.sum(Q_hat * (logQ_hat - logP_hat), dim=-1) + per_token_rkl = per_token_rkl_local.clone() + torch.distributed.all_reduce( + per_token_rkl, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + ctx.save_for_backward(Q_full, P_topk_part, topk_idx_local, Q_sum, P_sum) + return per_token_rkl + + @staticmethod + def backward(ctx, grad_output): + eps = 1e-20 + Q_full, P_topk_part, topk_idx_local, Q_sum, P_sum = ctx.saved_tensors + + partition_vocab_size = Q_full.size(-1) + topk = topk_idx_local.size(-1) + + # Re-gather Q_topk + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row1 = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row1.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view_as(P_topk_part) + + # Only real local entries + topk_mask = P_topk_part > 0 + Q_topk = torch.where(topk_mask, Q_topk, torch.zeros_like(Q_topk)) + + Z = Q_sum.unsqueeze(-1) + eps + T = P_sum.unsqueeze(-1) + eps + + Q_hat = Q_topk / Z + P_hat = P_topk_part / T + logQ_hat = torch.log(Q_hat + eps) + logP_hat = torch.log(P_hat + eps) + + a = logQ_hat + 1.0 - logP_hat + + # mean_a = sum(Q_hat * a) over global topk + mean_a_local = torch.sum(Q_hat * a, dim=-1) + mean_a = mean_a_local.clone() + torch.distributed.all_reduce(mean_a, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + # grad on topk then scatter into vocab partition + grad_topk = (Q_topk / Z) * (a - mean_a.unsqueeze(-1)) # [*, K] + + grad_input = torch.zeros_like(Q_full) + grad_2d = grad_input.view(-1, partition_vocab_size) + + grad_topk_2d = grad_topk.view(-1, topk) + grad_topk_2d = torch.where(topk_mask.view(-1, topk), grad_topk_2d, torch.zeros_like(grad_topk_2d)) + + idx_2d = topk_idx_local.view(-1, topk) + row2 = torch.arange(grad_2d.size(0), device=grad_2d.device).unsqueeze(-1) + grad_2d[row2, idx_2d] += grad_topk_2d + + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + return grad_input, None, None + + +def vocab_parallel_rkl_divergence(vocab_parallel_logits, target_topk_logps, target_topk_indices): + return _VocabParallelRKLDivergence.apply(vocab_parallel_logits, target_topk_logps, target_topk_indices) + + +# ============================================================ +# 3) KL + RKL weighted: (1-r)*KL + r*RKL +# ============================================================ +class _VocabParallelWeightedKLRKLDivergence(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1): + eps = 1e-20 + rkl_ratio = float(rkl_ratio) + if rkl_ratio < 0.0: + rkl_ratio = 0.0 + if rkl_ratio > 1.0: + rkl_ratio = 1.0 + kl_ratio = 1.0 - rkl_ratio + + # Student Q_full + vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits) + partition_vocab_size = vocab_parallel_logits.size(-1) + + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + + vocab_parallel_logits -= logits_max.unsqueeze(dim=-1) + vocab_parallel_logits.exp_() + exp_logits = vocab_parallel_logits + sum_exp_logits = exp_logits.sum(dim=-1) + + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + Q_full = exp_logits + Q_full.div_(sum_exp_logits.unsqueeze(-1)) # [*, V_part] + + # Local vocab range + local indices + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, rank, world_size + ) + + topk_in_vocab = (target_topk_indices >= vocab_start) & (target_topk_indices < vocab_end) + + topk_idx_local = (target_topk_indices - vocab_start).clone() + topk_idx_local[~topk_in_vocab] = 0 + + # Teacher P_topk (local) + P_topk_part = torch.exp(target_topk_logps).clone() + P_topk_part[~topk_in_vocab] = 0.0 + + logP_topk_part = target_topk_logps.clone() + logP_topk_part[~topk_in_vocab] = 0.0 + + # Gather Q_topk + origin_shape = target_topk_indices.shape + topk = target_topk_indices.size(-1) + + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view(origin_shape).clone() + Q_topk[~topk_in_vocab] = 0.0 + + logQ_topk = torch.log(Q_topk + eps) + logQ_topk[~topk_in_vocab] = 0.0 + + # Forward KL (not renorm): sum P_k (logP_k - logQ_k) + per_token_kl_local = torch.sum(P_topk_part * (logP_topk_part - logQ_topk), dim=-1) + per_token_kl = per_token_kl_local.clone() + torch.distributed.all_reduce( + per_token_kl, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + # Reverse KL on topk with renorm + P_sum_local = P_topk_part.sum(dim=-1) + P_sum = P_sum_local.clone() + torch.distributed.all_reduce(P_sum, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + Q_sum_local = Q_topk.sum(dim=-1) + Q_sum = Q_sum_local.clone() + torch.distributed.all_reduce(Q_sum, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + Q_hat = Q_topk / (Q_sum.unsqueeze(-1) + eps) + P_hat = P_topk_part / (P_sum.unsqueeze(-1) + eps) + logQ_hat = torch.log(Q_hat + eps) + logP_hat = torch.log(P_hat + eps) + + per_token_rkl_local = torch.sum(Q_hat * (logQ_hat - logP_hat), dim=-1) + per_token_rkl = per_token_rkl_local.clone() + torch.distributed.all_reduce( + per_token_rkl, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + per_token_loss = kl_ratio * per_token_kl + rkl_ratio * per_token_rkl + + ctx.save_for_backward(Q_full, P_topk_part, topk_idx_local, Q_sum, P_sum) + ctx.rkl_ratio = rkl_ratio + ctx.kl_ratio = kl_ratio + return per_token_loss + + @staticmethod + def backward(ctx, grad_output): + eps = 1e-20 + Q_full, P_topk_part, topk_idx_local, Q_sum, P_sum = ctx.saved_tensors + rkl_ratio = ctx.rkl_ratio + kl_ratio = ctx.kl_ratio + + partition_vocab_size = Q_full.size(-1) + topk = topk_idx_local.size(-1) + + # A) Forward KL grad = Q_full - P_sparse(topk) + grad_kl = Q_full.clone() + grad_kl_2d = grad_kl.view(-1, partition_vocab_size) + row = torch.arange(grad_kl_2d.size(0), device=grad_kl_2d.device).unsqueeze(-1) + idx_2d = topk_idx_local.view(-1, topk) + grad_kl_2d[row, idx_2d] -= P_topk_part.view(-1, topk) + + # B) Reverse KL grad (scatter topk only) + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row1 = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row1.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view_as(P_topk_part) + + topk_mask = P_topk_part > 0 + Q_topk = torch.where(topk_mask, Q_topk, torch.zeros_like(Q_topk)) + + Z = Q_sum.unsqueeze(-1) + eps + T = P_sum.unsqueeze(-1) + eps + + Q_hat = Q_topk / Z + P_hat = P_topk_part / T + logQ_hat = torch.log(Q_hat + eps) + logP_hat = torch.log(P_hat + eps) + + a = logQ_hat + 1.0 - logP_hat + mean_a_local = torch.sum(Q_hat * a, dim=-1) + mean_a = mean_a_local.clone() + torch.distributed.all_reduce(mean_a, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + grad_topk = (Q_topk / Z) * (a - mean_a.unsqueeze(-1)) # [*, K] + + grad_rkl = torch.zeros_like(Q_full) + grad_rkl_2d = grad_rkl.view(-1, partition_vocab_size) + + grad_topk_2d = grad_topk.view(-1, topk) + grad_topk_2d = torch.where(topk_mask.view(-1, topk), grad_topk_2d, torch.zeros_like(grad_topk_2d)) + + row2 = torch.arange(grad_rkl_2d.size(0), device=grad_rkl_2d.device).unsqueeze(-1) + idx_2d = topk_idx_local.view(-1, topk) + grad_rkl_2d[row2, idx_2d] += grad_topk_2d + + grad_input = kl_ratio * grad_kl + rkl_ratio * grad_rkl + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + return grad_input, None, None, None + + +def vocab_parallel_kl_rkl_divergence( + vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio: float = 0.1 +): + return _VocabParallelWeightedKLRKLDivergence.apply( + vocab_parallel_logits, target_topk_logps, target_topk_indices, rkl_ratio + ) + + +# ============================================================ +# 4) JSD(beta) with analytic rest term for Q||M outside top-k +# ============================================================ +class _VocabParallelJSDivergence(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float): + beta = min(max(float(beta), 1e-6), 1.0 - 1e-6) + one_minus_beta = 1.0 - beta + eps = 1e-20 + + # Student Q_full + vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits) + partition_vocab_size = vocab_parallel_logits.size(-1) + + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + + vocab_parallel_logits -= logits_max.unsqueeze(dim=-1) + vocab_parallel_logits.exp_() + exp_logits = vocab_parallel_logits + sum_exp_logits = exp_logits.sum(dim=-1) + + torch.distributed.all_reduce( + sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + Q_full = exp_logits + Q_full.div_(sum_exp_logits.unsqueeze(-1)) # Q + + # Local vocab range + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start, vocab_end = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, rank, world_size + ) + + topk_in_vocab = (target_topk_indices >= vocab_start) & (target_topk_indices < vocab_end) + + topk_idx_local = (target_topk_indices - vocab_start).clone() + topk_idx_local[~topk_in_vocab] = 0 + + # Teacher P_topk + P_topk = torch.exp(target_topk_logps).clone() + P_topk[~topk_in_vocab] = 0.0 + + logP_topk = target_topk_logps.clone() + logP_topk[~topk_in_vocab] = 0.0 + + # Gather Q_topk + origin_shape = target_topk_indices.shape + topk = target_topk_indices.size(-1) + + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view(origin_shape).clone() + Q_topk[~topk_in_vocab] = 0.0 + + logQ_topk = torch.log(Q_topk + eps) + + # Mix on topk: M_k = beta P_k + (1-beta) Q_k + M_topk = beta * P_topk + one_minus_beta * Q_topk + logM_topk = torch.log(M_topk + eps) + + # KL(P||M) topk + kl_P_M_local = torch.sum(P_topk * (logP_topk - logM_topk), dim=-1) + + # KL(Q||M) topk + kl_Q_M_topk_local = torch.sum(Q_topk * (logQ_topk - logM_topk), dim=-1) + + # KL(Q||M) rest analytic: for non-topk, M_j=(1-beta)Q_j => Q_j log(1/(1-beta)) + Q_topk_sum_local = Q_topk.sum(dim=-1) + Q_topk_sum = Q_topk_sum_local.clone() + torch.distributed.all_reduce( + Q_topk_sum, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + log_one_minus_beta = math.log(one_minus_beta) + Q_rest_sum = 1.0 - Q_topk_sum + kl_Q_M_rest = Q_rest_sum * (-log_one_minus_beta) + + tmp = beta * kl_P_M_local + one_minus_beta * kl_Q_M_topk_local + torch.distributed.all_reduce(tmp, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + per_token_jsd = tmp + one_minus_beta * kl_Q_M_rest + + ctx.save_for_backward(Q_full, P_topk, topk_idx_local) + ctx.beta = beta + return per_token_jsd + + @staticmethod + def backward(ctx, grad_output): + Q_full, P_topk, topk_idx_local = ctx.saved_tensors + beta = ctx.beta + one_minus_beta = 1.0 - beta + eps = 1e-20 + + partition_vocab_size = Q_full.size(-1) + topk = topk_idx_local.size(-1) + + # Re-gather Q_topk + Q_full_2d = Q_full.view(-1, partition_vocab_size) + row = torch.arange(Q_full_2d.size(0), device=Q_full_2d.device) + + Q_topk_2d = Q_full_2d[row.unsqueeze(-1), topk_idx_local.view(-1, topk)] + Q_topk = Q_topk_2d.view_as(P_topk) + + topk_mask = P_topk > 0 + Q_topk = torch.where(topk_mask, Q_topk, torch.zeros_like(Q_topk)) + + M_topk = beta * P_topk + one_minus_beta * Q_topk + logQ_topk = torch.log(Q_topk + eps) + logM_topk = torch.log(M_topk + eps) + + KL_Q_M_topk_local = torch.sum(Q_topk * (logQ_topk - logM_topk), dim=-1) + + Q_topk_sum_local = Q_topk.sum(dim=-1) + Q_topk_sum = Q_topk_sum_local.clone() + torch.distributed.all_reduce( + Q_topk_sum, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + log_one_minus_beta = math.log(one_minus_beta) + Q_rest_sum = 1.0 - Q_topk_sum + KL_Q_M_rest = Q_rest_sum * (-log_one_minus_beta) + + KL_Q_M_topk_global = KL_Q_M_topk_local.clone() + torch.distributed.all_reduce( + KL_Q_M_topk_global, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() + ) + + KL_Q_M = KL_Q_M_topk_global + KL_Q_M_rest # [*] + + # A_j = log(Q_j / M_j); non-topk: -log(1-beta) + A = torch.full_like(Q_full, -log_one_minus_beta) + + A_topk = logQ_topk - logM_topk + A_2d = A.view(-1, partition_vocab_size) + A_topk_2d = A_topk.view(-1, topk) + + idx_2d = topk_idx_local.view(-1, topk) + row2 = torch.arange(A_2d.size(0), device=A_2d.device).unsqueeze(-1) + A_2d[row2, idx_2d] = A_topk_2d + + # d/dz JSD = (1-beta) * Q * (A - KL(Q||M)) + grad_input = one_minus_beta * Q_full * (A - KL_Q_M.unsqueeze(-1)) + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + return grad_input, None, None, None + + +def vocab_parallel_jsd_divergence(vocab_parallel_logits, target_topk_logps, target_topk_indices, beta: float = 0.5): + return _VocabParallelJSDivergence.apply(vocab_parallel_logits, target_topk_logps, target_topk_indices, beta) + + +# ============================================================ +# Unified operator wrapper + factory +# ============================================================ +class VocabParallelDistillLoss(torch.nn.Module): + """ + Unified operator: + forward(vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices) -> per_token_loss + + Supported names (case-insensitive): + - "kl" + - "rkl" + - "kl_rkl" + - "jsd" + + Params: + - rkl_ratio: only used when name == "kl_rkl" + - beta: only used when name == "jsd" + """ + + def __init__(self, name: str = "kl", rkl_ratio: float = 0.1, beta: float = 0.5): + super().__init__() + self.name = str(name).lower() + self.rkl_ratio = float(rkl_ratio) + self.beta = float(beta) + + def forward(self, vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices): + n = self.name + + if n in ["kl", "forward_kl", "forward-kl"]: + return vocab_parallel_kl_divergence(vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices) + + if n in ["rkl", "reverse_kl", "reverse-kl"]: + return vocab_parallel_rkl_divergence(vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices) + + if n in ["kl_rkl", "kl+rkl", "kl_rkl_weighted", "weighted_kl_rkl", "klrkl"]: + return vocab_parallel_kl_rkl_divergence( + vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices, rkl_ratio=self.rkl_ratio + ) + + if n in ["jsd", "jensen_shannon", "jensen-shannon", "jensen_shannon_divergence"]: + return vocab_parallel_jsd_divergence( + vocab_parallel_logits, teacher_topk_logps, teacher_topk_indices, beta=self.beta + ) + + raise ValueError(f"Unknown distill loss name: {self.name}") + + +def build_vocab_parallel_distill_loss(loss_cfg: Optional[Any]) -> VocabParallelDistillLoss: + """ + loss_cfg can be: + - None + - dict + - OmegaConf DictConfig + + Expected fields: + - name: "kl" | "rkl" | "kl_rkl" | "jsd" + - rkl_ratio: float (only for kl_rkl) + - beta: float (only for jsd) + """ + cfg: dict[str, Any] = {} + if loss_cfg is None: + cfg = {} + else: + try: + from omegaconf import DictConfig, OmegaConf # type: ignore + + if isinstance(loss_cfg, DictConfig): + cfg = OmegaConf.to_container(loss_cfg, resolve=True) # type: ignore + elif isinstance(loss_cfg, dict): + cfg = dict(loss_cfg) + else: + cfg = {} + except Exception: + cfg = dict(loss_cfg) if isinstance(loss_cfg, dict) else {} + + name = str(cfg.get("name", "kl")).lower() + rkl_ratio = float(cfg.get("rkl_ratio", 0.1)) + beta = float(cfg.get("beta", 0.5)) + + return VocabParallelDistillLoss(name=name, rkl_ratio=rkl_ratio, beta=beta) + + +__all__ = [ + "vocab_parallel_kl_divergence", + "vocab_parallel_rkl_divergence", + "vocab_parallel_kl_rkl_divergence", + "vocab_parallel_jsd_divergence", + "VocabParallelDistillLoss", + "build_vocab_parallel_distill_loss", +] diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/megatron_kl_loss.py b/ICL/DAPO/verl-recipe/gkd/megatron/megatron_kl_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a6da80654dd2821d882ab4b600148bc94ff0ab8e --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/megatron_kl_loss.py @@ -0,0 +1,161 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.fusions.fused_cross_entropy import calculate_logits_max +from megatron.core.parallel_state import ( + get_data_parallel_rank, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel.utils import VocabUtility + + +def normalize(logps): + probs = torch.exp(logps) + probs = probs / probs.sum(dim=-1, keepdim=True) + normalized_logps = torch.log(probs) + return normalized_logps + + +def mylog(message): + with open("kl_loss.log", "a") as f: + f.write(f"({get_data_parallel_rank()}, {get_tensor_model_parallel_rank()}): {message}\n") + + +class _VocabParallelKLDivergence(torch.autograd.Function): + @staticmethod + def forward(ctx, vocab_parallel_logits, target_topk_logps, target_topk_indices): + # seq_len, batch_size, top_k = target_topk_logps.size() + # target_topk_logps = normalize(target_topk_logps) + vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits) + partition_vocab_size = vocab_parallel_logits.size(-1) + + torch.distributed.all_reduce( + logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() + ) + + vocab_parallel_logits -= logits_max.unsqueeze(dim=-1) + vocab_parallel_logits.exp_() + exp_logits = vocab_parallel_logits + sum_exp_logits = exp_logits.sum(dim=-1) + + torch.distributed.all_reduce( + sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group(), + ) + + # Get the partition's vocab indices + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size( + partition_vocab_size, rank, world_size + ) + + topk_indices_in_vocab_mask = (target_topk_indices >= vocab_start_index) & ( + target_topk_indices < vocab_end_index + ) + + vocab_parallel_target_topk_indices = target_topk_indices - vocab_start_index + vocab_parallel_target_topk_indices[~topk_indices_in_vocab_mask] = 0 + vocab_parallel_target_topk_probs = torch.exp(target_topk_logps) + vocab_parallel_target_topk_probs[~topk_indices_in_vocab_mask] = 0 + vocab_parallel_target_topk_logps = torch.empty_like(target_topk_logps) + vocab_parallel_target_topk_logps[...] = target_topk_logps[...] + vocab_parallel_target_topk_logps[~topk_indices_in_vocab_mask] = 0 + # assert ((0 <= target_topk_indices) & (target_topk_indices < partition_vocab_size)).all() + + # bs, sl, topk = target_topk_indices.shape + target_topk_logps_origin_shape = target_topk_indices.shape + topk = target_topk_indices.size(-1) + + vocab_parallel_source_probs = exp_logits + vocab_parallel_source_probs.div_(sum_exp_logits.unsqueeze(-1)) + vocab_parallel_source_probs_2d = vocab_parallel_source_probs.view(-1, partition_vocab_size) # (b*s, h/tp) + + arange_1d = torch.arange( + start=0, end=vocab_parallel_source_probs_2d.size(0), device=vocab_parallel_source_probs_2d.device + ) # (b*s, ) + vocab_parallel_source_topk_probs_2d = vocab_parallel_source_probs_2d[ + arange_1d.unsqueeze(-1), vocab_parallel_target_topk_indices.view(-1, topk) + ] # (b*s, topk) + vocab_parallel_source_topk_probs = vocab_parallel_source_topk_probs_2d.view( + target_topk_logps_origin_shape + ) # (b, s, topk) + vocab_parallel_source_topk_logps = torch.log(1e-20 + vocab_parallel_source_topk_probs) + vocab_parallel_source_topk_logps[~topk_indices_in_vocab_mask] = 0 + + # KL(P||Q)会强制 Q 覆盖 P 的所有模式(避免漏峰) + # KL(Q||P)会鼓励 Q 聚焦于 P 的一个模式(避免多峰) + # 这里使用 KL(P||Q),其中P为target,Q为source,鼓励source学习target的所有模式 + + per_token_kl_loss = torch.sum( + vocab_parallel_target_topk_probs * (vocab_parallel_target_topk_logps - vocab_parallel_source_topk_logps), + dim=-1, + ) # (b, s) + + # if torch.isinf(per_token_kl_loss).any() or torch.isnan(per_token_kl_loss).any(): + # breakpoint() + + torch.distributed.all_reduce( + per_token_kl_loss, + op=torch.distributed.ReduceOp.SUM, + group=get_tensor_model_parallel_group(), + ) + + ctx.save_for_backward( + vocab_parallel_source_probs, vocab_parallel_target_topk_probs, vocab_parallel_target_topk_indices + ) + # if get_data_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 1: + # import ipdb; ipdb.set_trace() + # torch.distributed.barrier() + return per_token_kl_loss + + @staticmethod + def backward(ctx, grad_output): + vocab_parallel_source_probs, vocab_parallel_target_topk_probs, vocab_parallel_target_topk_indices = ( + ctx.saved_tensors + ) + # source_probs, target_probs = ctx.saved_tensors + # KL 散度对 vocab_parallel_logits 的梯度为: (student_softmax_logits - valid_target_logits) + grad_input = vocab_parallel_source_probs # shape: [seq_len, batch_size, vocab_parition_size] + + topk = vocab_parallel_target_topk_indices.size(-1) + grad_input_2d = grad_input.view(-1, grad_input.size(-1)) + arange_1d = torch.arange(start=0, end=grad_input_2d.size(0), device=grad_input_2d.device) # (b*s, ) + grad_input_2d[arange_1d.unsqueeze(-1), vocab_parallel_target_topk_indices.view(-1, topk)] -= ( + vocab_parallel_target_topk_probs.view(-1, topk) + ) + + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None # 返回给第一个输入 vocab_parallel_logits 的梯度 + + +def vocab_parallel_kl_divergence(vocab_parallel_logits, target_topk_logps, target_topk_indices): + """ + Performs cross entropy loss when logits are split across tensor parallel ranks. + + Args: + vocab_parallel_logits: logits split across tensor parallel ranks + dimension is [sequence_length, batch_size, vocab_size_per_partition] + target_topk_logits: logits split across tensor parallel ranks + dimension is [sequence_length, batch_size, top_k] + + Returns: + loss: scalar tensor + """ + return _VocabParallelKLDivergence.apply(vocab_parallel_logits, target_topk_logps, target_topk_indices) diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/megatron_workers.py b/ICL/DAPO/verl-recipe/gkd/megatron/megatron_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..e3339ffb270bbf1c9cb548887deb6f4097d63a9d --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/megatron_workers.py @@ -0,0 +1,826 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# Copyright 2025 Individual Contributor: Brilliant Hanabi, funrunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import os +import time + +import numpy as np +import psutil +import torch +from codetiming import Timer +from megatron.core import parallel_state as mpu +from megatron.core.distributed import finalize_model_grads +from megatron.core.optimizer import DistributedOptimizer +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron_distill_losses import build_vocab_parallel_distill_loss +from omegaconf import DictConfig, OmegaConf +from torch import nn + +from verl import DataProto +from verl.single_controller.base.decorator import ( + Dispatch, + make_nd_compute_dataproto_dispatch_fn, + register, +) +from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager +from verl.utils.device import get_device_id, get_device_name, get_torch_device +from verl.utils.flops_counter import FlopsCounter +from verl.utils.megatron.pipeline_parallel import make_batch_generator +from verl.utils.megatron_utils import get_model_config +from verl.utils.profiler import ( + DistProfiler, + GPUMemoryLogger, + log_gpu_memory_usage, + simple_timer, +) +from verl.utils.profiler.performance import gather_timing +from verl.utils.profiler.profile import Profiler +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import rearrange_micro_batches +from verl.workers.megatron_workers import ActorRolloutRefWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class TensorBuffer: + def __init__(self, memory_alloc, dtype): + self.device = get_device_id() + dtype_size = torch.tensor([], dtype=dtype).element_size() + self.capacity = memory_alloc // dtype_size + self.dtype = dtype + self.tensor = torch.empty(self.capacity, dtype=self.dtype, device=self.device) + self.keys = [] + self.shapes = [] + + @property + def size(self): + return sum(shape.numel() for shape in self.shapes) + + def clear(self): + self.keys.clear() + self.shapes.clear() + self.tensor = torch.empty(self.capacity, dtype=self.dtype, device=self.device) + + def append(self, key, shape, weight=None): + if weight is not None: + self.tensor[self.size : self.size + shape.numel()] = weight.view(-1) + self.keys.append(key) + self.shapes.append(shape) + + def to_tensors(self): + tensors = [] + start = 0 + for key_, shape_ in zip(self.keys, self.shapes, strict=False): + tensors.append((key_, self.tensor[start : start + shape_.numel()].view(shape_))) + start += shape_.numel() + return tensors + + +def record_time(func): + def wrapper(*args, **kwargs): + tik = time.time() + func(*args, **kwargs) + tok = time.time() + return tok - tik + + return wrapper + + +class OnPolicyDistillActor: + """ + Responsible purely for the training step (forward-backward + optimizer). + """ + + def __init__( + self, + config, + model_config, + hf_config, + tf_config, + actor_module: nn.ModuleList, + actor_optimizer: DistributedOptimizer, + ): + """MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron. + + Args: + config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain + + ``shuffle``: whether to shuffle the data after each ppo epoch. + + ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347. + + ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347. + model_config (OmegaConf): model configuration. It must contains ``model_config.vocab_size`` and + ``model_config.hidden_size`` + hf_config (PretrainedConfig): huggingface config + tf_config (TransformerConfig): mcore transformer config + actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this + pp stage. + each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for + more details. + The actor module has some constraints to follow in order to use the updating logics implemented here + + 1. It must implement unpad_input before any computation and pad_input after all the computation. + Remove padding is an + optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn + (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py). + + 2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size], + where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size + of the hidden state is [total_nnz // tp, 1, hidden_size]. + actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. + It implements + zero1 optimizer that shards the optimizer state across dp ranks. + + >>> from megatron.training import get_model + >>> from megatron.optimizer import get_megatron_optimizer + >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True) + >>> actor_module = nn.ModuleList(actor_module) + >>> actor_optimizer = get_megatron_optimizer(actor_module) + >>> actor = MegatronPPOActor(config=config, + >>> model_config=actor_model_config, + >>> hf_config=hf_config, + >>> tf_config=tf_config, + >>> actor_module=actor_module, + >>> actor_optimizer=actor_optimizer) + """ + self.config = config + self._validate_config(config) + self.model_config = model_config + self.hf_config = hf_config + self.tf_config = tf_config + self.actor_module = actor_module + self.actor_optimizer: DistributedOptimizer = actor_optimizer + self.prof = Profiler(self.config.profiler) + self.optimizer_step_args = OmegaConf.create( + { + "skip_grad": None, + "overlap_dp_param_comm": False, + "overlap_dp_grad_comm": False, + "gradient_accumulation_steps": 1, + "sequence_parallel": self.tf_config.sequence_parallel, + "DDP_impl": "local", + "layernorm_allreduce_bucket_threshold": 0, + "pipeline_model_parallel_split_rank": None, + "reduce_grads_use_alltoall": False, + } + ) + + config = get_model_config(self.actor_module[0]) + print(config) + config.finalize_model_grads_func = finalize_model_grads + + # Build distill loss operator (selectable by config) + self.distill_loss_op = build_vocab_parallel_distill_loss(self.config.get("distill_loss", None)).cuda() + + def _validate_config(self, config) -> None: + """Validate config options not implemented for Megatron backend""" + assert config.get("ulysses_sequence_parallel_size", 1) == 1 + if config.get("shuffle", False): + assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" + if config.megatron.tensor_model_parallel_size == 1: + print("[Warining] Because actor tp size == 1, set sp to False") + config.megatron.sequence_parallel = False + self.config = config + + def forward_backward_batch( + self, + data: DataProto, + use_dynamic_bsz=False, + micro_batch_size=None, + max_token_len=None, + ): + """ + We assume: + - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input + - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled + """ + # broadcast from last pp rank to all other pp ranks + # TODO: actually, we just need to control the sampling order. + # broadcast_dict_tensor( + # data.batch, + # src=mpu.get_pipeline_model_parallel_last_rank(), + # group=mpu.get_pipeline_model_parallel_group(), + # ) + # split into micro-batches + data.batch["attention_mask"] = data.batch["attention_mask"].to(bool) + + indices = None + if use_dynamic_bsz: + assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" + vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() + if vpp_size is not None and vpp_size > 1: + microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage + micro_batches, indices = rearrange_micro_batches( + batch=data.batch, + num_batches_divided_by=microbatch_group_size_per_vp_stage, + max_token_len=max_token_len, + ) + assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( + f"micro_batches {len(micro_batches)} must be divisible by microbatch_group_size_per_vp_stage " + f"{microbatch_group_size_per_vp_stage} for megatron backend" + ) + else: + micro_batches, indices = rearrange_micro_batches(batch=data.batch, max_token_len=max_token_len) + # total_seqlen = max_token_len + if mpu.is_pipeline_last_stage(): + teacher_topk_logps_tensor = torch.tensor(data.non_tensor_batch["teacher_topk_logps"]) + teacher_topk_indices_tensor = torch.tensor(data.non_tensor_batch["teacher_topk_indices"]) + teacher_topk_logps, teacher_topk_indices = [], [] + for partition in indices: + curr_logp_micro_batch, curr_idx_micro_batch = [], [] + for idx in partition: + curr_logp_micro_batch.append(teacher_topk_logps_tensor[idx : idx + 1]) + curr_idx_micro_batch.append(teacher_topk_indices_tensor[idx : idx + 1]) + curr_logp_micro_batch = torch.cat(curr_logp_micro_batch) + curr_idx_micro_batch = torch.cat(curr_idx_micro_batch) + + teacher_topk_logps.append(curr_logp_micro_batch) + teacher_topk_indices.append(curr_idx_micro_batch) + + for i, mb in enumerate(micro_batches): + responses = mb["responses"] + response_length = responses.size(1) + calc_kl_mask = mb["attention_mask"].clone() + calc_kl_mask[:, : (-response_length - 1)] = False + mb["calc_kl_mask"] = calc_kl_mask + mb["kl_losses"] = torch.zeros_like(calc_kl_mask, dtype=torch.float32) + mb["teacher_topk_logps"] = teacher_topk_logps[i].pin_memory() + mb["teacher_topk_indices"] = teacher_topk_indices[i].pin_memory() + else: + assert micro_batch_size is not None, ( + "micro_batch_size is needed to be passed in when not using dynamic batch size" + ) + micro_batches = data.batch.split(micro_batch_size) + # seq_len = micro_batches[0]["input_ids"].shape[1] + # total_seqlen = micro_batch_size * seq_len + if mpu.is_pipeline_last_stage(): + teacher_topk_logps = np.array_split(data.non_tensor_batch["teacher_topk_logps"], len(micro_batches)) + teacher_topk_indices = np.array_split(data.non_tensor_batch["teacher_topk_indices"], len(micro_batches)) + for i, mb in enumerate(micro_batches): + responses = mb["responses"] + response_length = responses.size(1) + calc_kl_mask = mb["attention_mask"].clone() + calc_kl_mask[:, : (-response_length - 1)] = False + mb["calc_kl_mask"] = calc_kl_mask + mb["kl_losses"] = torch.zeros_like(calc_kl_mask, dtype=torch.float32) + mb["teacher_topk_logps"] = torch.tensor(teacher_topk_logps[i]).pin_memory() + mb["teacher_topk_indices"] = torch.tensor(teacher_topk_indices[i]).pin_memory() + + # compute input shapes for pp stages + n_micro_batch = len(micro_batches) + + forward_backward_func = get_forward_backward_func() + + def loss_func(output): + # For memory efficiency + # We move calculation of entropy to compute_log_probs, forward_only == True + metrics = {} + + ret_entropy = None + stats = {} + kl_losses = output["kl_losses"] + calc_kl_mask = output["calc_kl_mask"] + # inf_cnt = masked_kl_lossed.isinf().sum().item() + # nan_cnt = masked_kl_lossed.isnan().sum().item() + # total_cnt = masked_kl_lossed.nelement() + # print(f"rank: {rank}, kl_loss inf_cnt/nan_cnt/total_cnt: {inf_cnt} / {nan_cnt} /{total_cnt}") + masked_kl_lossed = kl_losses[calc_kl_mask] + mean_kl_loss = masked_kl_lossed.mean() + stats.update({"actor/kl_loss": mean_kl_loss.detach().item()}) + + append_to_dict(metrics, stats) + return mean_kl_loss, [metrics, ret_entropy] + + def forward_step(batch_iter, model): + batch = next(batch_iter) + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + position_ids = batch["position_ids"] + + def logits_processor(logits, teacher_topk_logps, teacher_topk_indices, calc_kl_mask, kl_losses): + assert logits.shape[:2] == calc_kl_mask.shape[:2] + assert logits.shape[:2] == teacher_topk_indices.shape[:2] + assert logits.shape[:2] == teacher_topk_logps.shape[:2] + + masked_logits = logits[calc_kl_mask] + masked_teacher_topk_logps = teacher_topk_logps[calc_kl_mask] + masked_teacher_topk_indices = teacher_topk_indices[calc_kl_mask] + + kl_losses[calc_kl_mask] = self.distill_loss_op( + masked_logits, masked_teacher_topk_logps, masked_teacher_topk_indices + ) + return {"kl_losses": kl_losses, "calc_kl_mask": calc_kl_mask} + + if mpu.is_pipeline_last_stage(): + device = get_device_id() + teacher_topk_logps = batch["teacher_topk_logps"].to(device, non_blocking=True) + teacher_topk_indices = batch["teacher_topk_indices"].to(device, non_blocking=True) + logits_processor_args = { + "calc_kl_mask": batch["calc_kl_mask"], + "kl_losses": batch["kl_losses"], + "teacher_topk_logps": teacher_topk_logps, + "teacher_topk_indices": teacher_topk_indices, + } + else: + logits_processor_args = None + + multi_modal_inputs = {} + if "multi_modal_inputs" in batch: + from verl.utils.model import extract_multi_modal_inputs + + indices = batch.get("multi_modal_inputs_idx", None) + multi_modal_inputs = extract_multi_modal_inputs(batch["multi_modal_inputs"], indices) + + from verl.models.mcore import get_mcore_forward_fn + + forward_fn = get_mcore_forward_fn(self.hf_config) + + output = forward_fn( + model, + input_ids, + attention_mask, + position_ids, + multi_modal_inputs, + logits_processor=logits_processor, + logits_processor_args=logits_processor_args, + ) + + return output, loss_func + + # batch should be a list of batches inside micro-batches + batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module)) + + # TODO: we may use the new schedule instead + # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) + losses_reduced = forward_backward_func( + forward_step_func=forward_step, + data_iterator=batch_generator, + model=self.actor_module, + num_microbatches=n_micro_batch, + seq_length=-1, # no use when variable_seq_lengths was set + micro_batch_size=-1, # no use when variable_seq_lengths was set + forward_only=False, + ) + + # loss_reduces contains the stats returned from loss_func + + losses_reduced = {"output": losses_reduced} + if use_dynamic_bsz: + losses_reduced["indices"] = indices + return losses_reduced + + @GPUMemoryLogger(role="megatron actor", logger=logger) + def update_policy(self, data: DataProto) -> dict: + """Update the policy with an iterator of DataProto + + Args: + dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by ``make_minibatch_iterator`` + The keys of each data batch is described in the make_minibatch_iterator. + + Returns: + Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage + and users have to combine the output in each dp rank manually. + + """ + metrics = {} + # self.prof.start() + data.to(get_device_id()) + + self.actor_optimizer.zero_grad() + # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm + for chunk in self.actor_module: + # if use distributed optimizer, zero grad buffer will be handled by optimizer + chunk.zero_grad_buffer() + + micro_batch_size = self.config.micro_batch_size + max_token_len = None + if self.config.use_dynamic_bsz: + max_token_len = self.config.max_token_len * self.config.megatron.context_parallel_size + + metric_micro_batch = self.forward_backward_batch( + data, + use_dynamic_bsz=self.config.use_dynamic_bsz, + micro_batch_size=micro_batch_size, + max_token_len=max_token_len, + ) + + metric_micro_batch = metric_micro_batch["output"] + for metric in metric_micro_batch: + # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask + append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. + + update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() + data = {"actor/grad_norm": grad_norm} + append_to_dict(metrics, data) + + if update_successful: + # allgather already execute in optimizer.step in new megatron + pass + else: + raise NotImplementedError + # self.prof.step() + # add empty cache after each compute + # self.prof.stop_and_save() + # self.prof.stop_trace() + get_torch_device().empty_cache() + return metrics + + +class MegatronOnPolicyDistillActorWorker(ActorRolloutRefWorker): + """ + Actor-only worker: owns the trainable Megatron model and optimizer, performs update_actor. + """ + + def __init__(self, config: DictConfig, role: str): + # Ensure we run as actor-only worker + is_struct = OmegaConf.is_struct(config) or False + OmegaConf.set_struct(config, False) + OmegaConf.set_struct(config, is_struct) + + super().__init__(config, role) + assert self._is_actor and not self._is_rollout, "Actor worker must be actor-only." + + def _get_actor_params_generator(self): + assert self._is_actor + if self.bridge is not None: + generator = self.bridge.export_weights(self.actor.actor_module) + else: + # from verl.utils.megatron_utils import per_tensor_generator + from megatron_utils import per_tensor_generator + + from verl.models.mcore import get_mcore_weight_converter + + layer_name_mapping = { + "qkv_layer_name": "self_attention.linear_qkv.", + "gate_proj_layer_name": "linear_fc1.", + } + weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) + generator = per_tensor_generator( + self.actor.actor_module, + self.actor_model_config, + weight_converter, + self.tf_config, + layer_name_mapping, + ) + return generator + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + from verl.utils.torch_dtypes import PrecisionType + + override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) + override_transformer_config = OmegaConf.to_container( + self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True + ) + + self.param_dtype = torch.bfloat16 + log_gpu_memory_usage("Before init actor model and optimizer", logger=logger) + self.dtype = PrecisionType.to_dtype(self.param_dtype) + # we need the model for actor and rollout + optim_config = self.config.actor.optim + ( + self.actor_module, + self.actor_optimizer, + self.actor_optimizer_scheduler, + self.actor_model_config, + self.actor_optim_config, + ) = self._build_model_optimizer( + model_path=self.config.model.path, + optim_config=optim_config, + override_model_config=override_model_config, + override_transformer_config=override_transformer_config, + ) + + self.actor = OnPolicyDistillActor( + config=self.config.actor, + model_config=self.actor_model_config, + hf_config=self.hf_config, + tf_config=self.tf_config, + actor_module=self.actor_module, + actor_optimizer=self.actor_optimizer, + ) + log_gpu_memory_usage("After OnPolicyDistillActor init", logger=logger) + + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_mananager = MegatronCheckpointManager( + config=self.config, + checkpoint_config=self.config.actor.checkpoint, + model_config=self.actor_model_config, + transformer_config=self.tf_config, + role="actor", + model=self.actor_module, + arch=self.architectures[0], + hf_config=self.hf_config, + param_dtype=self.param_dtype, + share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, + processing_class=self.processor if self.processor is not None else self.tokenizer, + optimizer=self.actor_optimizer, + optimizer_scheduler=self.actor_optimizer_scheduler, + use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, + use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler, + bridge=self.bridge, + use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing, + ) + get_torch_device().empty_cache() + log_gpu_memory_usage("Actor init_model finished", logger=logger) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @GPUMemoryLogger(role="update_actor", logger=logger) + @DistProfiler.annotate(color="red") + def update_actor(self, data: DataProto): + assert self._is_actor and not self._is_rollout + + with Timer(name="update_policy", logger=None) as timer: + metrics = self.actor.update_policy(data=data) + + delta_time = timer.last + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/actor"] = estimated_flops / promised_flops / self.world_size + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + from verl.utils.megatron.optimizer import get_megatron_last_lr + + metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer) + self.actor_optimizer_scheduler.step(1) + + # TODO: here, we should return all metrics + output = DataProto(meta_info={"metrics": metrics}) + output = output.to("cpu") + + get_torch_device().empty_cache() + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + assert self._is_actor and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + + params_generator = self._get_actor_params_generator() + + from ray.util.collective import collective + + update_weights_bucket_bytes = int(self.config.rollout.update_weights_bucket_megabytes) << 20 + tensor_buffer = TensorBuffer(update_weights_bucket_bytes, self.param_dtype) + + for key, shape, dtype in self._weights_info: + weight_key, weight = next(params_generator) + assert key == weight_key + assert shape == weight.size() + try: + assert dtype == weight.dtype + except AssertionError: + if not key.endswith("e_score_correction_bias"): + raise + # weight = weight.to(dtype) + + if shape.numel() > tensor_buffer.capacity: + collective.broadcast(weight, src_rank=0, group_name="actor_rollout") + else: + if tensor_buffer.size + shape.numel() > tensor_buffer.capacity: + collective.broadcast(tensor_buffer.tensor, src_rank=0, group_name="actor_rollout") + tensor_buffer.clear() + tensor_buffer.append(key, shape, weight) + if tensor_buffer.size > 0: + collective.broadcast(tensor_buffer.tensor, src_rank=0, group_name="actor_rollout") + tensor_buffer.clear() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + assert self._is_actor + if hasattr(self, "_weights_info"): + return self._weights_info + + params_generator = self._get_actor_params_generator() + ret = [] + for key, tensor in params_generator: + ret.append((key, tensor.size(), tensor.dtype)) + + self._weights_info = ret + return ret + + +class MegatronOnPolicyDistillRolloutWorker(ActorRolloutRefWorker): + """ + Rollout-only worker: owns the inference engine (vLLM/SGlang, or Megatron forward) and generates sequences. + """ + + def __init__(self, config: DictConfig, role: str): + # Ensure we run as rollout-only worker + # is_struct = OmegaConf.is_struct(config) or False + # OmegaConf.set_struct(config, False) + # # Set a safe minimal rollout micro-batch size if not provided by config + # if OmegaConf.select(config, "actor.ppo_mini_batch_size") is None: + # config.actor.ppo_mini_batch_size = 2 + # if OmegaConf.select(config, "rollout.n") is None: + # config.rollout.n = 1 + # OmegaConf.set_struct(config, is_struct) + import datetime + + from verl.utils.config import omega_conf_to_dataclass + from verl.utils.device import ( + get_nccl_backend, + get_torch_device, + ) + from verl.utils.distributed import set_numa_affinity + from verl.utils.fs import copy_to_local + from verl.utils.model import get_generation_config + from verl.utils.profiler import DistProfilerExtension, ProfilerConfig + from verl.workers.megatron_workers import MegatronWorker + + MegatronWorker.__init__(self) + self.config = config + self.local_path = copy_to_local(self.config.model.path) + + # NOTE(sgm): We utilize colocate WorkerGroup by default. + # As a result, Workers for different model share the same process. + # Therefore, we only require one distribute initialization. + # To utilize different parallel strategy in different models: + # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, + # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 + if not torch.distributed.is_initialized(): + set_numa_affinity() + rank = int(os.environ["LOCAL_RANK"]) + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + get_torch_device().set_device(rank) + + self.role = role + assert self.role == "rollout" + + self._is_actor = False + self._is_rollout = True + self._is_ref = False + + # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) + # This is for extendability in AsyncRL cases + omega_profiler_config = config.rollout.get("profiler", {}) + + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + + self._is_offload_param = False + self._is_offload_grad = False + self._is_offload_optimizer = False + + # self._build_rollout will use this variable + self.bridge = "none" + self.generation_config = get_generation_config(self.local_path) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + """ + Build the actor module only for inference + rollout engine; no optimizer/updates. + """ + from verl.utils.torch_dtypes import PrecisionType + + self.param_dtype = torch.bfloat16 + log_gpu_memory_usage("Before init rollout model", logger=logger) + self.dtype = PrecisionType.to_dtype(self.param_dtype) + + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + self.rollout_device_mesh = self.rollout.device_mesh + log_gpu_memory_usage("After rollout init", logger=logger) + get_torch_device().empty_cache() + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) + @GPUMemoryLogger(role="generate_sequences", logger=logger) + @DistProfiler.annotate(color="red") + def generate_sequences(self, prompts: DataProto): + """ + Asynchronous-friendly rollout. When called via Ray with blocking=False, + returns immediately with a future. The actual method execution generates + sequences and optionally fetches teacher knowledge, and returns DataProto. + """ + assert self._is_rollout and not self._is_actor + prompts.batch = prompts.batch.to(get_device_name()) + meta_info = { + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, + } + prompts.meta_info.update(meta_info) + + timing_generate = {} + # No context switching here; rollout-only worker always in rollout mode. + + with simple_timer("generate_sequences", timing_generate): + output = self.rollout.generate_sequences(prompts=prompts) + + # We calculate the average timing across all ranks + # to make sure meta_info["timing"] is the same + timing_generate = gather_timing(timing_generate) + output.meta_info["timing"] = timing_generate + output = output.to("cpu") + # clear kv cache + get_torch_device().empty_cache() + + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False) + def async_generate_sequences(self, *args, **kwargs): + return self.generate_sequences(*args, **kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + from ray.util.collective import collective + + assert self._is_rollout and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + rollout_name = self.config.rollout.name + if rollout_name == "vllm": + inference_model = ( + self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + ) + from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader + + patch_vllm_moe_model_weight_loader(inference_model) + elif rollout_name == "sglang": + from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights + + inference_model = self.rollout._engine + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def update_weights(inference_engine, params): + await sgl_update_weights( + engine=inference_engine, + params_batch=params, + device_mesh_key="infer_tp", + device_mesh=self.rollout_device_mesh, + ) + + if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0: + await inference_engine.flush_cache() + else: + raise NotImplementedError(f"Unknown rollout name: {rollout_name}") + + update_weights_bucket_bytes = int(self.config.rollout.update_weights_bucket_megabytes) << 20 + tensor_buffer = TensorBuffer(update_weights_bucket_bytes, self.param_dtype) + + def group_tensor_generator(): + for key, shape, dtype in self._weights_info: + assert dtype == self.param_dtype, key + if shape.numel() > tensor_buffer.capacity: + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + yield [(key, tensor)] + else: + if tensor_buffer.size + shape.numel() > tensor_buffer.capacity: + collective.broadcast(tensor_buffer.tensor, src_rank=0, group_name="actor_rollout") + yield tensor_buffer.to_tensors() + tensor_buffer.clear() + tensor_buffer.append(key, shape) + if tensor_buffer.size > 0: + collective.broadcast(tensor_buffer.tensor, src_rank=0, group_name="actor_rollout") + yield tensor_buffer.to_tensors() + tensor_buffer.clear() + + for tensors in group_tensor_generator(): + if rollout_name == "vllm": + inference_model.load_weights(tensors) + elif rollout_name == "sglang": + loop.run_until_complete(update_weights(inference_model, tensors)) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + assert self._is_rollout + self._weights_info = weights_info diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/ray_trainer.py b/ICL/DAPO/verl-recipe/gkd/megatron/ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..533979dcb506d3b1020a18d2b98ceb3b8a7037e3 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/ray_trainer.py @@ -0,0 +1,696 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import time +from typing import Optional + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf, open_dict +from recipe.gkd.teacher import TeacherClient +from recipe.gkd.teacher_utils import get_teacher_knowledge +from torch.utils.data import Dataset, Sampler +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.metric_utils import ( + compute_throughout_metrics, + compute_timing_metrics, +) +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role +from verl.utils.debug import marked_timer +from verl.utils.metric import ( + reduce_metrics, +) +from verl.utils.torch_dtypes import PrecisionType +from verl.utils.tracking import ValidationGenerationsLogger + +WorkerType = type[Worker] + + +class GenerationBatchFuture: + """ + Wrapper class for encapsulating batch generation results + """ + + def __init__(self, epoch, batch, gen_batch_output): + """ + :param epoch: current epoch + :param batch: Input batch data + :param gen_batch_output: Generated sequences from the main model (DataProtoFuture) + """ + self.epoch = epoch + self.batch = batch + self.gen_batch_output = gen_batch_output + self.teacher_batch_output = None + + def set_teacher_batch_output(self, teacher_batch_output): + """Set the teacher batch output for this generation batch. + + Args: + teacher_batch_output: The teacher model's output (DataProtoFuture or raw output) + to be associated with this generation batch. This will be used for + distillation or guidance during training. + """ + self.teacher_batch_output = teacher_batch_output + + def get(self): + """ + Get the actual results by calling get() method on gen_batch_output + + Returns: + tuple: (batch, gen_batch_result) + - batch: Original input batch data + - gen_batch_result: Result from gen_batch_output.get() or gen_batch_output itself + """ + # Call get() method on gen_batch_output if available + if hasattr(self.gen_batch_output, "get"): + gen_batch_result = self.gen_batch_output.get() + self.gen_batch_output = gen_batch_result + + if self.teacher_batch_output is None: + return self.epoch, self.batch, self.gen_batch_output + + if hasattr(self.teacher_batch_output, "get"): + try: + teacher_batch_result = self.teacher_batch_output.get() + except Exception as e: + # set result to empty + teacher_batch_result = None + print(f"{e}") + else: + teacher_batch_result = self.teacher_batch_output + + return self.epoch, self.batch, self.gen_batch_output, teacher_batch_result + + +class OnPolicyDistillTrainer(RayPPOTrainer): + """Distributed PPO trainer using Ray for scalable reinforcement learning. + + This trainer orchestrates distributed PPO training across multiple nodes and GPUs, + managing actor rollouts, critic training, and reward computation with Ray backend. + Supports various model architectures including FSDP, Megatron, and vLLM integration. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name="cuda", + ): + """ + Initialize distributed PPO trainer with Ray backend. + Note that this trainer runs on the driver process on a single CPU/GPU node. + + Args: + config: Configuration object containing training parameters. + tokenizer: Tokenizer used for encoding and decoding text. + role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes. + resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools. + ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup. + processor: Optional data processor, used for multimodal data + reward_fn: Function for computing rewards during training. + val_reward_fn: Function for computing rewards during validation. + train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None. + val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. + collate_fn: Function to collate data samples into batches. + train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda". + """ + + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.config = config + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert not self.hybrid_engine + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name + self.validation_generations_logger = ValidationGenerationsLogger() + self.use_critic = False + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + self.teacher_config = self.config.actor_rollout_ref.teacher + self.n_server_workers = self.teacher_config.n_server_workers + self.teacher_client = TeacherClient( + self.teacher_config.server_ip, self.teacher_config.server_port, n_server_workers=self.n_server_workers + ) + + self.params_dtype = PrecisionType.to_dtype("bfloat16") + + def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): + """ + Creates the train and validation dataloaders. + """ + # TODO: we have to make sure the batch size is divisible by the dp size + from verl.trainer.main_ppo import create_rl_sampler + + self.train_dataset, self.val_dataset = train_dataset, val_dataset + + if train_sampler is None: + train_sampler = create_rl_sampler(self.config.data, self.train_dataset) + if collate_fn is None: + from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + + collate_fn = default_collate_fn + + num_workers = self.config.data["dataloader_num_workers"] + + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), + num_workers=num_workers, + drop_last=True, + collate_fn=collate_fn, + sampler=train_sampler, + ) + + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + + if self.val_dataset: + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + + self.val_dataloader = StatefulDataLoader( + dataset=self.val_dataset, + batch_size=val_batch_size, + num_workers=num_workers, + shuffle=self.config.data.get("validation_shuffle", True), + drop_last=False, + collate_fn=collate_fn, + ) + + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" + + print( + f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: " + f"{len(self.val_dataloader)}" + ) + else: + print(f"Size of train dataloader: {len(self.train_dataloader)}") + + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = min(self.config.trainer.total_training_steps, total_training_steps) + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + # Build Ray classes per pool + resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # Rollout group + rollout_pool = self.resource_pool_manager.get_resource_pool(Role.Rollout) + rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.Rollout], + config=self.config.actor_rollout_ref, + role="rollout", + ) + resource_pool_to_cls[rollout_pool]["rollout"] = rollout_cls + + # Actor group + actor_pool = self.resource_pool_manager.get_resource_pool(Role.Actor) + actor_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.Actor], + config=self.config.actor_rollout_ref, + role="actor", + ) + resource_pool_to_cls[actor_pool]["actor"] = actor_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.trainer, "profile_steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") + assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( + "worker_nsight_options must be set when profile_steps is set" + ) + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.trainer, "worker_nsight_options") + ) + + for resource_pool, class_dict in resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + device_name=self.device_name, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + time.sleep(20) # avoid port conflict + + self.rollout_wg = all_wg["rollout"] + self.actor_wg = all_wg["actor"] + + # Initialize both groups + self.rollout_wg.init_model() + self.actor_wg.init_model() + self.actor_rollout_wg = self.actor_wg # to be compatible with the functions that not be modified + weights_info = self.actor_wg.get_actor_weights_info()[0] + self.rollout_wg.set_actor_weights_info(weights_info) + from ray.util.collective import collective + + actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers + collective.create_collective_group( + actor_rollout_workers, + len(actor_rollout_workers), + list(range(0, len(actor_rollout_workers))), + backend="nccl", + group_name="actor_rollout", + ) + + def sync_rollout_weights(self): + assert not self.hybrid_engine + self.actor_wg.sync_rollout_weights() + ray.get(self.rollout_wg.sync_rollout_weights()) + + def _create_continuous_iterator(self): + """ + Create a continuous data iterator across epoch + """ + for epoch in range(self.config.trainer.total_epochs): + iterator = iter(self.train_dataloader) + for batch_dict in iterator: + yield epoch, batch_dict + + def _async_gen_next_batch(self, epoch, batch_dict, sync_before_generation=True): + """ + Call parameter synchronization and asynchronous sequence generation. + """ + batch = DataProto.from_single_dict(batch_dict) + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + if "interaction_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("interaction_kwargs") + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + gen_batch.meta_info["global_steps"] = self.global_steps + # sync weights from actor to rollout + if sync_before_generation: + self.sync_rollout_weights() + # Call non-blocking rollout (worker method registered with blocking=False) + gen_batch_output = self.rollout_wg.async_generate_sequences(gen_batch) + return GenerationBatchFuture(epoch, batch, gen_batch_output) + + def _async_get_teacher_knowledge(self, future: GenerationBatchFuture): + """Asynchronously obtain teacher model knowledge for generated sequences. + + This method retrieves generated sequences from the future object, adds response length metadata, + and asynchronously queries the teacher model for knowledge distillation. The teacher model's output + is set in the future object for subsequent processing. + + Args: + future (GenerationBatchFuture): Future object containing generated sequences and metadata + + Returns: + GenerationBatchFuture: The same future object with teacher knowledge set + + Raises: + RuntimeError: If teacher client initialization fails or knowledge retrieval fails + """ + _, _, gen_batch_output = future.get() + gen_batch_output.meta_info["response_length"] = self.config.data.max_response_length + + future.set_teacher_batch_output( + get_teacher_knowledge(gen_batch_output, self.teacher_client, self.n_server_workers, is_async=True) + ) + return future + + def one_step_off_scheduler(self, continuous_iterator): + """One-step-off scheduler implementation (version 1) for GKD training with improved pipeline. + + This scheduler optimizes the training pipeline by: + 1. Overlapping rollout weight synchronization with teacher knowledge processing + 2. Maintaining consistent timing measurement across iterations + 3. Reducing idle time between generation and knowledge distillation phases + + The scheduler maintains the following timing metrics: + - sync_rollout_weights: Time taken to synchronize rollout weights + - wait_prev_gen: Time waiting for previous generation to complete + - wait_prev_teacher: Time waiting for teacher knowledge to be ready + + Args: + continuous_iterator: Iterator providing (epoch, batch_dict) tuples for training + + Yields: + tuple: Contains (batch, gen_batch_output, teacher_batch_output, timing_metrics) + - batch: Original input batch data + - gen_batch_output: Generated sequences from main model + - teacher_batch_output: Knowledge distillation from teacher model + - timing_metrics: Dictionary of timing measurements + """ + timing = {} + for i, (epoch, batch_dict) in enumerate(continuous_iterator): + if i == 0: + # sync weights and start first async rollout + with marked_timer("sync_rollout_weights", timing): + fut = self._async_gen_next_batch(epoch, batch_dict) + # wait for previous rollout finish and start async generate teacher knowledge + with marked_timer("wait_prev_gen", timing): + prev_fut = self._async_get_teacher_knowledge(fut) + # no yield here, so we will continue to the next loop and enter `else` block + if i == 1: + # we don't need to sync weights here because we have not trained the actor yet + # start second async rollout + fut = self._async_gen_next_batch(epoch, batch_dict, sync_before_generation=False) + # wait for generating teacher knowledge finish + # and get previous result including rollout and teacher knowledge + with marked_timer("wait_prev_teacher", timing): + prev_result = prev_fut.get() + yield *prev_result, timing + + # start next step from here + timing = {} + # wait for previous rollout finish and start async generate teacher knowledge + with marked_timer("wait_prev_gen", timing): + prev_fut = self._async_get_teacher_knowledge(fut) + else: + # sync weights and start next async rollout + with marked_timer("sync_rollout_weights", timing): + fut = self._async_gen_next_batch(epoch, batch_dict) + # wait for generating teacher knowledge finish + # and get previous result including rollout and teacher knowledge + with marked_timer("wait_prev_teacher", timing): + prev_result = prev_fut.get() + yield *prev_result, timing + + # start next step from here + timing = {} + # wait for previous rollout finish and start async generate teacher knowledge + with marked_timer("wait_prev_gen", timing): + prev_fut = self._async_get_teacher_knowledge(fut) + + # for last step + with marked_timer("wait_prev_teacher", timing): + prev_result = prev_fut.get() + yield *prev_result, timing + + def two_step_off_scheduler(self, continuous_iterator): + """Two-step-off scheduler implementation for GKD training with optimized pipeline. + + This scheduler implements a double-buffered pipeline that overlaps: + 1. Sequence generation with teacher knowledge distillation + 2. Weight synchronization with previous batch processing + + Key features: + - Maintains two parallel processing streams (current and previous batches) + - Overlaps computation and communication where possible + - Provides consistent timing metrics across iterations + + Pipeline stages: + 1. Initialization: Start first generation without teacher processing + 2. Steady state: Alternate between processing teacher knowledge and starting new generation + 3. Final state: Process last batch of teacher knowledge + + Timing metrics collected: + - sync_rollout_weights: Time for weight synchronization between actor and rollout workers + - wait_prev_prev_teacher: Time waiting for teacher knowledge from two batches ago + - wait_prev_gen: Time waiting for previous generation to complete + + Args: + continuous_iterator: Iterator providing (epoch, batch_dict) tuples for training + + Yields: + tuple: Contains (batch, gen_batch_output, teacher_batch_output, timing_metrics) + - batch: Original input batch data + - gen_batch_output: Generated sequences from main model + - teacher_batch_output: Knowledge distillation from teacher model + - timing_metrics: Dictionary of timing measurements + """ + timing = {} + for i, (epoch, batch_dict) in enumerate(continuous_iterator): + if i == 0: + with marked_timer("sync_rollout_weights", timing): + rollout_future = self._async_gen_next_batch(epoch, batch_dict) + continue + elif i == 1: + teacher_future = self._async_get_teacher_knowledge(rollout_future) + rollout_future = self._async_gen_next_batch(epoch, batch_dict, sync_before_generation=False) + continue + elif i == 2: + with marked_timer("wait_prev_prev_teacher", timing): + result = teacher_future.get() + with marked_timer("wait_prev_gen", timing): + teacher_future = self._async_get_teacher_knowledge(rollout_future) + rollout_future = self._async_gen_next_batch(epoch, batch_dict, sync_before_generation=False) + yield *result, timing + timing = {} + else: + with marked_timer("wait_prev_prev_teacher", timing): + result = teacher_future.get() + with marked_timer("wait_prev_gen", timing): + teacher_future = self._async_get_teacher_knowledge(rollout_future) + with marked_timer("sync_rollout_weights", timing): + rollout_future = self._async_gen_next_batch(epoch, batch_dict) + yield *result, timing + timing = {} + + # for second to last step + with marked_timer("wait_prev_prev_teacher", timing): + result = teacher_future.get() + with marked_timer("wait_prev_gen", timing): + teacher_future = self._async_get_teacher_knowledge(rollout_future) + yield *result, timing + + # for last step + with marked_timer("wait_prev_prev_teacher", timing): + result = teacher_future.get() + yield *result, timing + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + max_steps_duration = 0 + + # Pre-warm: submit the first rollout + continuous_iterator = self._create_continuous_iterator() + + scheduler_type = self.config.trainer.scheduler + + if scheduler_type == "one_step_off": + scheduler = self.one_step_off_scheduler(continuous_iterator) + elif scheduler_type == "two_step_off": + scheduler = self.two_step_off_scheduler(continuous_iterator) + else: + raise TypeError(f"unrecognized scheduler type: {scheduler_type}") + + # Main loop + while True: + do_profile = ( + self.global_steps in self.config.trainer.profile_steps + if self.config.trainer.profile_steps is not None + else False + ) + if do_profile: + self.rollout_wg.start_profile() + self.actor_wg.start_profile() + + metrics = {} + timing_raw = {} + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + _, batch, gen_batch_output, teacher_batch_output, schedule_timing = next(scheduler) + if teacher_batch_output is None: + # save model + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + self._save_checkpoint() + print("Error in getting teacher knowledge. Skip this batch.") + progress_bar.update(1) + self.global_steps += 1 + if is_last_step: + progress_bar.close() + return + continue + + timing_raw.update(schedule_timing) + + gen_timing = gen_batch_output.meta_info.pop("timing", {}) + for k, v in gen_timing.items(): + if isinstance(v, list): + array_v = np.array(v) + timing_raw[k + "_mean"] = array_v.mean().item() + timing_raw[k + "_min"] = array_v.min().item() + timing_raw[k + "_max"] = array_v.max().item() + timing_raw[k] = array_v.max().item() + else: + timing_raw[k] = v + + timing_raw.update(teacher_batch_output.meta_info.pop("timing")) + + # Compute statistics of generated response lengths distribution + response_lens = ( + (gen_batch_output.batch["responses"] != self.tokenizer.pad_token_id).sum(dim=-1).tolist() + ) + metrics.update( + { + "response_seq_len/average": sum(response_lens) / len(response_lens), + "response_seq_len/max": max(response_lens), + "response_seq_len/min": min(response_lens), + "response_seq_len/max_count": response_lens.count(max(response_lens)), + "response_seq_len/min_count": response_lens.count(min(response_lens)), + } + ) + + # Merge generated outputs back + batch = batch.union(gen_batch_output) + + # Debug print + one_attention_mask = batch.batch["attention_mask"][0].to(torch.bool) + one_sentence = batch.batch["input_ids"][0] + print("INFO:", "generate text done.") + print("DEBUG:", self.tokenizer.decode(one_sentence[one_attention_mask].tolist())) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + batch = batch.union(teacher_batch_output) + + # # update actor + # with marked_timer("send_teacher_knowledge", timing_raw, color="red"): + # self.actor_wg.send_teacher_knowledge(teacher_batch_output) + + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + actor_output = self.actor_wg.update_actor(batch) + + print("INFO:", "update actor done.") + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # save model + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + # Metrics and bookkeeping + steps_duration = timing_raw["step"] + max_steps_duration = max(max_steps_duration, steps_duration) + # training metrics + metrics["training/global_step"] = self.global_steps + # collect metrics + # metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if do_profile: + self.rollout_wg.stop_profile() + self.actor_wg.stop_profile() + + if is_last_step: + progress_bar.close() + return diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/run_moonlight_dsv3_training.sh b/ICL/DAPO/verl-recipe/gkd/megatron/run_moonlight_dsv3_training.sh new file mode 100644 index 0000000000000000000000000000000000000000..d45d08fea4b95ea0f52f54486986fb84b1c85b79 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/run_moonlight_dsv3_training.sh @@ -0,0 +1,123 @@ +#!/bin/bash +set -x + +# 0. download the config +# only need to download the `configuration_deepseek.py`, `config.json`, `tokenizer_config.json`, `tokenizer.json` and `generation_config.json` +# remove the `quantization_config` in the `config.json` +# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported + +# huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json + +# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main +# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path + +HF_MODEL_PATH=/path/to/Moonlight-16B-A3B-Instruct +DIST_CKPT_PATH=/path/to/Moonlight-16B-A3B-Instruct-MCORE + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +export NVTE_FLASH_ATTN=1 +export NVTE_DEBUG=1 +export NVTE_DEBUG_LEVEL=2 + +# 2. run the script +gsm8k_train_path=/path/to/train.parquet +train_files=$gsm8k_train_path + +# 512 H20(96GB) +NODES=1 +PP=3 +TP=1 +EP=2 +ETP=1 +INFER_TP=1 +SP=True +if [ $TP == 1 ]; then + SP=False +fi +# consider TP/ETP, and enable recompute if short of memory + +TEACHER_SERVER_HOST=127.0.0.1 +TEACHER_SERVER_PORT=15555 + +function check_server_ready() { + local server=$1 + local ip=$2 + local port=$3 + + echo "check $server server ready at $ip:$port..." + result=`echo -e "\n" | telnet $ip $port 2> /dev/null | grep Connected | wc -l` + if [ $result -ne 1 ]; then + echo "server $server is not ready at $ip:$port, exit..." + exit 1 + fi +} + +check_server_ready teacher $TEACHER_SERVER_HOST $TEACHER_SERVER_PORT + +function now() { + date '+%Y-%m-%d-%H-%M' +} + +# full recompute +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +WORKING_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/config/runtime_env.yaml"} + +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m main_gkd --config-name on_policy_distill_trainer \ + data.train_files="$train_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=2048 \ + data.max_response_length=2048 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.trust_remote_code=True \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.megatron.sequence_parallel=$SP \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_expert_capacity_factor=1.2 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.micro_batch_size=2 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.max_token_len=6000 \ + actor_rollout_ref.actor.use_torch_compile=True \ + actor_rollout_ref.actor.checkpoint.save_contents=['model'] \ + actor_rollout_ref.actor.checkpoint.load_contents=[] \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ + +actor_rollout_ref.actor.distill_loss.name=kl \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=0.99 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \ + actor_rollout_ref.rollout.load_format='dummy_megatron' \ + actor_rollout_ref.rollout.agent.num_workers=1 \ + actor_rollout_ref.teacher.server_ip=$TEACHER_SERVER_HOST \ + actor_rollout_ref.teacher.server_port=$TEACHER_SERVER_PORT \ + trainer.logger=['console'] \ + trainer.project_name='on-policy-distill' \ + trainer.experiment_name="moonlight-dsv3-$(now)" \ + trainer.nnodes=$NODES \ + trainer.n_gpus_per_node=6 \ + rollout.nnodes=$NODES \ + rollout.n_gpus_per_node=2 \ + trainer.scheduler="two_step_off" \ + trainer.save_freq=100000 \ + trainer.test_freq=-1 \ + trainer.val_before_train=False \ + trainer.total_epochs=10 $@ diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/__init__.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d4221f6257f30a38e5ed1b7b67b74dca01965778 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .client import TeacherClient + +__all__ = ["TeacherClient"] diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/client.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/client.py new file mode 100644 index 0000000000000000000000000000000000000000..3e521384ea0092eca404d89d065d4943b490f072 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/client.py @@ -0,0 +1,208 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import queue +import random +import threading +from concurrent.futures import Future +from contextlib import nullcontext +from datetime import datetime + +import torch +import zmq +from codetiming import Timer + +try: + from .utils import deserialize, serialize +except ImportError: + from utils import deserialize, serialize + +DEBUG = False + + +def check_if_invalid(topk_logps, inputs): + is_valid = True + reason = "" + for x in topk_logps: + if x.isnan().any(): + is_valid = False + reason = "nan" + break + elif x.isinf().any(): + is_valid = False + reason = "inf" + break + elif (x == 0).any(): + is_valid = False + reason = "zero" + break + if not is_valid: + if isinstance(inputs, torch.Tensor): + inputs = inputs.tolist() + with open("teacher_debug.log", "a") as f: + f.write("{}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))) + f.write(f"{reason}\n") + f.write(f"{str(inputs)}\n") + + +class TeacherClient: + def __init__( + self, + server_ip, + server_port, + num_microbatches=1, + max_tokens=1, + n_server_workers=1, + temperature=1, + only_response=False, + max_seq_len=None, + ) -> None: + self.server_ip = server_ip + self.server_port = server_port + self.num_microbatches = num_microbatches + self.n_server_workers = n_server_workers + self.max_tokens = max_tokens + self.task_queue = queue.Queue() + self.mutex = threading.Lock() if n_server_workers > 1 else nullcontext() + self.context = zmq.Context() + self.temperature = temperature + self.only_response = only_response + self.max_seq_len = max_seq_len + self._run() + + def bg_task(self): + socket = self.context.socket(zmq.REQ) + socket.connect(f"tcp://{self.server_ip}:{self.server_port}") + socket.setsockopt(zmq.LINGER, 0) + socket.setsockopt(zmq.RCVTIMEO, 600000) # 接收超时 30 分钟 + + while True: + futures = [] + inputs = [] + batch = [] + try: + with self.mutex: + for _ in range(self.num_microbatches): + future, data = self.task_queue.get() + if DEBUG: + inputs.append(data) + futures.append(future) + batch.extend(data.tolist() if isinstance(data, torch.Tensor) else data) + + if self.max_seq_len: + max_tokens = [min(self.max_tokens, self.max_seq_len - len(prompt)) for prompt in batch] + request = {"prompt_token_ids": batch, "max_tokens": max_tokens} + else: + request = {"prompt_token_ids": batch, "max_tokens": self.max_tokens} + if self.temperature: + request["temperature"] = self.temperature + if self.only_response: + request["only_response"] = True + + socket.send(serialize(request)) + raw = socket.recv() + response = deserialize(raw) + + if isinstance(response, dict) and response.get("status") == "error": + reason = response.get("reason", "unknown") + err = RuntimeError(f"Teacher error: {reason}") + for f in futures: + f.set_exception(err) + continue + + required = ("responses", "teacher_topk_logprobs", "teacher_topk_indices") + for k in required: + if k not in response: + raise RuntimeError(f"Invalid response: missing key '{k}'") + + total = len(response["teacher_topk_logprobs"]) + if self.num_microbatches <= 0 or total % self.num_microbatches != 0: + raise RuntimeError(f"Size mismatch: total={total}, num_microbatches={self.num_microbatches}") + + mbs = total // self.num_microbatches + for i, future in enumerate(futures): + s, e = i * mbs, (i + 1) * mbs + responses = response["responses"][s:e] + teacher_topk_logps = response["teacher_topk_logprobs"][s:e] + if DEBUG: + check_if_invalid(teacher_topk_logps, inputs[i]) + teacher_topk_indices = response["teacher_topk_indices"][s:e] + future.set_result((responses, teacher_topk_logps, teacher_topk_indices)) + + except zmq.Again: + err = TimeoutError(f"Timeout waiting for server {self.server_ip}:{self.server_port}") + for f in futures: + f.set_exception(err) + continue + except Exception as e: + for f in futures: + try: + f.set_exception(e) + except Exception: + pass + continue + + def _run(self): + for _ in range(self.n_server_workers): + threading.Thread(target=self.bg_task, daemon=True).start() + + def submit(self, data): + future = Future() + self.task_queue.put((future, data)) + return future + + def __del__(self): + self.context.destroy() + + +if __name__ == "__main__": + gbs = 128 + n_gps = 1 + mbs = 2 + seq_len = 4096 + + prompt_lens = (n_gps * gbs) * [seq_len] + + tc = TeacherClient( + server_ip="127.0.0.1", server_port=15555, num_microbatches=gbs // mbs, n_server_workers=1, only_response=False + ) + + prompt_token_ids = [] + + for pl in prompt_lens: + prompt_token_ids.append([random.randint(1, 99999) for j in range(pl)]) + + with Timer(name="get_topk_logprobs", initial_text=True): + futures = [] + for i in range(0, n_gps * gbs, mbs): + futures.append(tc.submit(prompt_token_ids[i : i + mbs])) + + for future in futures: + responses, teacher_topk_logprobs, teacher_topk_indices = future.result() + + print(len(teacher_topk_logprobs), len(teacher_topk_indices)) + + assert len(responses) == mbs + assert len(teacher_topk_logprobs) == mbs + assert len(teacher_topk_indices) == mbs + + assert all(x.shape == y.shape for x, y in zip(teacher_topk_logprobs, teacher_topk_indices, strict=False)) + out_lens = [x.shape[0] for x in teacher_topk_logprobs] + out_dims = [x.shape[1] for x in teacher_topk_logprobs] + assert all(out_len == seq_len for out_len in out_lens) + assert all(out_dim == 256 for out_dim in out_dims) + assert all(x.dtype == torch.float32 for x in teacher_topk_logprobs), [ + x.dtype for x in teacher_topk_logprobs + ] + assert all(x.dtype == torch.int32 for x in teacher_topk_indices) + assert all(x.dtype == torch.int32 for x in responses) diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/join_server.sh b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/join_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..f163a2d0ddd0ad6d1be22e10997c745195598f8e --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/join_server.sh @@ -0,0 +1,33 @@ +export PROXY_FRONTEND_PORT=15555 +export PROXY_BACKEND_PORT=15556 + +PROXY_IP="127.0.0.1" +BACKEND=vllm +CKPT_PATH="/path/to/TEACHER_MODEL/" + +wait_server_ready() { + server=$1 + ip=$2 + port=$3 + while true; do + echo "wait $server server ready at $ip:$port..." + result=`echo -e "\n" | telnet $ip $port 2> /dev/null | grep Connected | wc -l` + if [ $result -eq 1 ]; then + break + else + sleep 1 + fi + done +} + +# pkill -f "python proxy.py" +# pkill -f "python worker.py" +ps -ef | grep "python worker.py" | grep -v grep | awk -F ' ' '{print $2}' | xargs -r kill -9 + +wait_server_ready proxy $PROXY_IP $PROXY_BACKEND_PORT + +echo "teacher proxy is ready" + +nohup python worker.py --backend $BACKEND --proxy-addr $PROXY_IP:$PROXY_BACKEND_PORT --tp-size 8 --n-logprobs 256 --ckpt-path $CKPT_PATH &> worker.log & + +echo "teacher server is ready" diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/proxy.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..bdfd65ad8aec1e03f8eb79cd8ad555042f2bdcf9 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/proxy.py @@ -0,0 +1,59 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import zmq + +context = zmq.Context() + +frontend_listen_port = os.environ.get("PROXY_FRONTEND_PORT") +backend_listen_port = os.environ.get("PROXY_BACKEND_PORT") + +assert frontend_listen_port is not None, "PROXY_FRONTEND_PORT is not set" +assert backend_listen_port is not None, "PROXY_BACKEND_PORT is not set" + +# 创建前端 ROUTER 套接字并绑定到客户端连接地址 +frontend = context.socket(zmq.ROUTER) +frontend.bind(f"tcp://*:{frontend_listen_port}") + +# 创建后端 DEALER 套接字并绑定到服务端连接地址 +backend = context.socket(zmq.DEALER) +backend.bind(f"tcp://*:{backend_listen_port}") + +# 创建 poller 用于同时监听多个套接字 +poller = zmq.Poller() +poller.register(frontend, zmq.POLLIN) +poller.register(backend, zmq.POLLIN) + +print("proxy is running...") + +while True: + socks = dict(poller.poll()) + + if frontend in socks: + # 从 ROUTER 接收来自客户端的消息(multipart 消息) + parts = frontend.recv_multipart() + # print(f"收到客户端消息: {parts}") + + # 将完整的 multipart 消息转发给 DEALER + backend.send_multipart(parts) + + if backend in socks: + # 从 DEALER 接收来自服务端的回复 + reply_parts = backend.recv_multipart() + # print(f"收到服务端回复: {reply_parts}") + + # 将回复转发回原始客户端(假设第一个部分是客户端 ID) + frontend.send_multipart(reply_parts) diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/start_server.sh b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/start_server.sh new file mode 100644 index 0000000000000000000000000000000000000000..6baa29833a89a2594c48cb4cecc166840f139fe0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/start_server.sh @@ -0,0 +1,34 @@ +export PROXY_FRONTEND_PORT=15555 +export PROXY_BACKEND_PORT=15556 + +BACKEND=vllm +CKPT_PATH="/path/to/TEACHER_MODEL/" + +wait_server_ready() { + server=$1 + ip=$2 + port=$3 + while true; do + echo "wait $server server ready at $ip:$port..." + result=`echo -e "\n" | telnet $ip $port 2> /dev/null | grep Connected | wc -l` + if [ $result -eq 1 ]; then + break + else + sleep 1 + fi + done +} + +ps -ef | grep "python proxy.py" | grep -v grep | awk -F ' ' '{print $2}' | xargs -r kill -9 +ps -ef | grep "python worker.py" | grep -v grep | awk -F ' ' '{print $2}' | xargs -r kill -9 + +nohup python proxy.py &> proxy.log & + +wait_server_ready proxy localhost $PROXY_BACKEND_PORT + +echo "teacher proxy is ready" + +nohup python worker.py --backend $BACKEND --tp-size 1 --n-logprobs 256 --ckpt-path $CKPT_PATH &> worker.log & +echo "start teacher worker" + +echo "teacher server is ready" \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/utils.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..126b44a445c0313ace6dcb23f5f3c6deb4ef48b3 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/utils.py @@ -0,0 +1,61 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io + +import torch + + +def chunk_list(lst, n_chunks): + """Split a list into chunks of equal length""" + size = len(lst) // n_chunks + for i, start in enumerate(range(0, len(lst), size)): + if i == n_chunks - 1: + yield lst[start:] + return + else: + yield lst[start : start + size] + + +def serialize(data): + buffer = io.BytesIO() + torch.save(data, buffer) + return buffer.getbuffer() + + +def deserialize(message): + buffer = io.BytesIO(message) + return torch.load(buffer) + + +if __name__ == "__main__": + lst = list(range(12)) + sub_lsts = list(chunk_list(lst, 3)) + + assert len(sub_lsts) == 3 + assert sub_lsts[0] == [0, 1, 2, 3] + assert sub_lsts[1] == [4, 5, 6, 7] + assert sub_lsts[2] == [8, 9, 10, 11] + + lst = list(range(11)) + sub_lsts = list(chunk_list(lst, 3)) + assert len(sub_lsts) == 3 + assert sub_lsts[0] == [0, 1, 2] + assert sub_lsts[1] == [3, 4, 5] + assert sub_lsts[2] == [6, 7, 8, 9, 10] + + lst = list(range(11)) + sub_lsts = list(chunk_list(lst, 1)) + assert len(sub_lsts) == 1 + assert sub_lsts[0] == list(range(11)) diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/vllm_engine.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/vllm_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..a4189cb7388254730b7baeed578aea68023107a2 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/vllm_engine.py @@ -0,0 +1,244 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import random +from typing import NamedTuple + +import torch +from codetiming import Timer +from transformers import AutoConfig +from vllm import LLM, SamplingParams + +# from vllm.v1.outputs import LogprobsTensors +from vllm.v1.engine.logprobs import LogprobsProcessor + + +def _update_prompt_logprobs( + self, + prompt_logprobs_tensors, +) -> None: + """Update with prompt logprobs from EngineCore. + + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + + """ + + # Prompt logprobs are enabled. + assert self.num_prompt_logprobs is not None + assert self.prompt_logprobs is not None + + self.prompt_logprobs.append(prompt_logprobs_tensors) + + +def _update_sample_logprobs(self, logprobs_lists) -> None: + """Update with sample logprobs from EngineCore. + + Outer lists are only of len > 1 if EngineCore made + >1 tokens in prior step (e.g. in spec decoding). + + Args: + logprobs_lists: the lists of logprob tokens, logprobs, and ranks. + + """ + + assert self.num_logprobs is not None + assert self.logprobs is not None + assert self.cumulative_logprob is not None + + # token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists + + # for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, + # token_ids_lst): + + # # Detokenize (non-incrementally). + # decoded_tokens = NONES if self.tokenizer is None else ( + # convert_ids_list_to_tokens(self.tokenizer, token_ids)) + + # # Sampler puts the sampled logprob in first. + # sampled_token_logprob = logprobs[0] + # self.cumulative_logprob += sampled_token_logprob + + # # Update with the Logprob dictionary for this pos. + # self.logprobs.append( + # self._make_logprob_dict( + # logprobs, + # token_ids, + # decoded_tokens, + # rank, + # self.num_logprobs, + # )) + self.logprobs.append(logprobs_lists) + + +LogprobsProcessor._update_prompt_logprobs = _update_prompt_logprobs +LogprobsProcessor._update_sample_logprobs = _update_sample_logprobs + + +class LogprobsTensors(NamedTuple): + # [num_reqs, max_num_logprobs + 1] + logprob_token_ids: torch.Tensor + # [num_reqs, max_num_logprobs + 1] + logprobs: torch.Tensor + # [num_reqs] + selected_token_ranks: torch.Tensor + + def tolists(self): + return LogprobsTensors( + logprob_token_ids=self.logprob_token_ids.cpu(), + logprobs=self.logprobs.cpu(), + selected_token_ranks=self.selected_token_ranks.cpu(), + ) + + @staticmethod + def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors": + """Create empty LogprobsTensors on CPU.""" + + logprob_token_ids = torch.empty((num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu") + logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32) + selected_token_ranks = torch.empty(num_positions, dtype=torch.int32, device="cpu") + return LogprobsTensors( + logprob_token_ids=logprob_token_ids, + logprobs=logprobs, + selected_token_ranks=selected_token_ranks, + ) + + def slice(self, start: int, end: int): + return LogprobsTensors( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.selected_token_ranks[start:end], + ) + + +# outputs.LogprobsTensors = LogprobsTensors +# def tolists(self): +# return self + + +# LogprobsTensors.tolists = tolists +# setattr(LogprobsTensors, "slice", slice) + + +class VLLMEngine: + def __init__(self, ckpt_path, n_logprobs=0, tp_size=1): + self.n_logprobs = n_logprobs + # self.llm = LLM(ckpt_path, tensor_parallel_size=tp_size, trust_remote_code=True, + # enable_chunked_prefill=False, distributed_executor_backend="ray", + # max_logprobs=n_logprobs, gpu_memory_utilization=0.7) + self.llm = LLM( + ckpt_path, + tensor_parallel_size=tp_size, + trust_remote_code=True, + enable_chunked_prefill=False, + max_logprobs=n_logprobs, + gpu_memory_utilization=0.7, + ) + + def get_topk_logprobs(self, prompt_token_ids, temperature=0.8, max_new_tokens=1, only_response=False): + def make_sampling_params(i=None): + return SamplingParams( + temperature=temperature, + top_p=0.95, + detokenize=False, + logprobs=self.n_logprobs, + prompt_logprobs=None if only_response else self.n_logprobs, + max_tokens=max_new_tokens[i] if (i is not None) else max_new_tokens, + ) + + if isinstance(max_new_tokens, list): + assert len(prompt_token_ids) == len(max_new_tokens) + sampling_params = [make_sampling_params(i) for i in range(len(max_new_tokens))] + else: + sampling_params = make_sampling_params() + + outputs = self.llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params) + + responses, teacher_topk_logprobs, teacher_topk_indices = [], [], [] + for output in outputs: + responses.append(torch.tensor(output.outputs[0].token_ids, dtype=torch.int32)) + if self.n_logprobs > 0: + response_topk_logprobs = torch.tensor( + [x.logprobs[0] for x in output.outputs[0].logprobs], + dtype=torch.float32, + )[:, 1:] + response_topk_indices = torch.tensor( + [x.logprob_token_ids[0] for x in output.outputs[0].logprobs], + dtype=torch.int32, + )[:, 1:] + if only_response: + teacher_topk_logprobs.append(response_topk_logprobs) + teacher_topk_indices.append(response_topk_indices) + else: + prompt_topk_logprobs = output.prompt_logprobs[1].logprobs[:, 1:].to(torch.float32) + prompt_topk_indices = output.prompt_logprobs[1].logprob_token_ids[:, 1:].to(torch.int32) + teacher_topk_logprobs.append(torch.vstack([prompt_topk_logprobs, response_topk_logprobs])) + teacher_topk_indices.append(torch.vstack([prompt_topk_indices, response_topk_indices])) + + return responses, teacher_topk_logprobs, teacher_topk_indices + + # def get_response_and_topk_logprobs(self, prompt_token_ids, max_tokens=64): + # sampling_params = SamplingParams(temperature=0.8, top_p=0.95, detokenize=False, + # logprobs=self.n_logprobs, max_tokens=max_tokens) + + # outputs = self.llm.generate(prompt_token_ids=prompt_token_ids, + # sampling_params=sampling_params) + + # student_topk_logprobs, student_topk_indices = [], [] + # for output in outputs: + # student_topk_logprobs.append([]) + # student_topk_indices.append([]) + # for logprob_list in output.outputs[0].logprobs: + # student_topk_logprobs[-1].extend(logprob_list.logprobs) + # student_topk_indices[-1].extend(logprob_list.logprob_token_ids) + + # return student_topk_logprobs, student_topk_indices + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test vLLM logprob") + parser.add_argument("model_dir", help="Model directory") + parser.add_argument("--tp-size", type=int, default=1, help="TP size") + parser.add_argument("--batch-size", "-b", type=int, default=64, help="Test batch size") + parser.add_argument("--seq-len", "-s", type=int, default=3840, help="Test sequence length") + parser.add_argument("--token-file", "-t", type=str, help="Input token file") + args = parser.parse_args() + + config = AutoConfig.from_pretrained(args.model_dir) + print(f"Reading configs from {args.model_dir}: {config.vocab_size=}") + + prompt_token_ids = [] + if args.token_file: + # Init input with tokenid file + from get_batch import get_batch + + prompt_token_ids = get_batch() + else: + # Init input randomly + prompt_lens = args.batch_size * [args.seq_len] + for pl in prompt_lens: + prompt_token_ids.append([random.randint(1, config.vocab_size - 1000) for j in range(pl)]) + + engine = VLLMEngine(ckpt_path=args.model_dir, n_logprobs=256, tp_size=args.tp_size) + + with Timer(name="get_topk_logprobs", initial_text=True): + responses, teacher_topk_logprobs, teacher_topk_indices = engine.get_topk_logprobs( + prompt_token_ids, temperature=0.7, max_new_tokens=1, only_response=True + ) + # debug + import ipdb + + ipdb.set_trace() diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher/worker.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..1b284e83ae960e9bdd3a01c557ec1289bdd8c8de --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher/worker.py @@ -0,0 +1,95 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import functools + +import torch +import zmq +from codetiming import Timer +from utils import deserialize, serialize + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--proxy-addr", type=str, default="localhost:15556") + parser.add_argument("--backend", type=str, default="vllm") + parser.add_argument("--seq-len", type=int, default=3840) + parser.add_argument("--n-logprobs", type=int, default=256) + parser.add_argument("--ckpt-path", type=str, required=True) + parser.add_argument("--tp-size", type=int, default=1) + parser.add_argument("--ep-size", type=int, default=1) + parser.add_argument("--dp-size", type=int, default=1) + args = parser.parse_args() + + if args.backend == "vllm": + from vllm_engine import VLLMEngine + + engine = VLLMEngine(args.ckpt_path, args.n_logprobs, args.tp_size) + else: + raise ValueError(f"Unknown backend: {args.backend}.") + + context = zmq.Context() + socket = context.socket(zmq.REP) + # socket.bind(f"tcp://*:{port}") + socket.connect(f"tcp://{args.proxy_addr}") + + print("worker started...", flush=True) + + # TODO: 新增prefix_cache_hit监控 + + while True: + message = socket.recv() + try: + with Timer(name="deserialize", initial_text=True, logger=functools.partial(print, flush=True)): + request = deserialize(message) + except Exception as e: + print("[Server Error] Deserialize failed:", str(e), flush=True) + socket.send(serialize({"status": "error", "reason": f"Deserialize failed: {e}"})) + continue + if isinstance(request, dict) and "prompt_token_ids" in request: + prompt_token_ids = request["prompt_token_ids"] + temperature = request.get("temperature", 0.8) + max_tokens = request.get("max_tokens", 1) + only_response = request.get("only_response", False) + if isinstance(prompt_token_ids, torch.Tensor): + prompt_token_ids = prompt_token_ids.tolist() + with Timer(name="get_prompt_topk_logprobs", initial_text=True, logger=functools.partial(print, flush=True)): + ### try and sendback error + try: + responses, logps, indices = engine.get_topk_logprobs( + prompt_token_ids, temperature, max_new_tokens=max_tokens, only_response=only_response + ) + except Exception as e: + print("[Server Error] Exception occurred during generation:", str(e)) + socket.send(serialize({"status": "error", "reason": f"Generate failed: {str(e)}"})) + continue + with Timer(name="serialize", initial_text=True, logger=functools.partial(print, flush=True)): + message = serialize( + { + "status": "ok", + "teacher_topk_logprobs": logps, + "teacher_topk_indices": indices, + "responses": responses, + } + ) + with Timer(name="send", initial_text=True, logger=functools.partial(print, flush=True)): + socket.send(message) + + else: + socket.send(serialize({"status": "error", "reason": "invalid request format."})) + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/teacher_utils.py b/ICL/DAPO/verl-recipe/gkd/megatron/teacher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b2abb75dae9514d30dad7e719e902d6d984cea --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/teacher_utils.py @@ -0,0 +1,172 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions for teacher model knowledge distillation. + +Functions: + get_teacher_knowledge: Retrieve teacher model's top-k predictions and log probabilities. +""" + +import time +from types import SimpleNamespace + +import torch + +from verl import DataProto + +teacher_topk_logps_padded, teacher_topk_indices_padded = None, None + + +def get_teacher_knowledge(batch: DataProto, teacher_client, n_server_workers=1, is_async=False): + """ + Retrieve teacher model's top-k predictions and log probabilities for knowledge distillation. + + Args: + batch (DataProto): Input batch containing input_ids and attention_mask + teacher_client: Client for communicating with teacher model + n_server_workers (int): Number of parallel workers for teacher model inference + is_async (bool): Whether to use asynchronous processing + + Returns: + If is_async=True: SimpleNamespace with get() method to process futures + If is_async=False: Processed DataProto containing teacher knowledge + + Raises: + RuntimeError: If teacher model request fails + """ + + input_ids = [] + attention_mask = batch.batch["attention_mask"].to(torch.bool) + # response_length = batch.meta_info["response_length"] + + for ids, mask in zip(batch.batch["input_ids"], attention_mask, strict=False): + input_ids.append(ids[mask].tolist()) + + all_teacher_topk_logps = [] + all_teacher_topk_indices = [] + + batch_size = len(input_ids) + assert batch_size % n_server_workers == 0 + micro_batch_size = batch_size // n_server_workers + futures = [] + tik1 = time.time() + tok1 = tik1 + + def cb(future): + nonlocal tok1 + tok1 = max(tok1, time.time()) + + for i in range(0, batch_size, micro_batch_size): + fut = teacher_client.submit(input_ids[i : i + micro_batch_size]) + fut.add_done_callback(cb) + futures.append(fut) + + def handle_futures(): + for future in futures: + try: + _, teacher_topk_logps, teacher_topk_indices = future.result() + except Exception as e: + raise RuntimeError(f"Teacher request failed: {e}") from e + + all_teacher_topk_logps.extend(teacher_topk_logps) + all_teacher_topk_indices.extend(teacher_topk_indices) + + tik2 = time.time() + # teacher_topk_logps = [x.to(params_dtype) for x in all_teacher_topk_logps] + # teacher_topk_indices = [x.to(params_dtype) for x in all_teacher_topk_indices] + teacher_topk_logps, teacher_topk_indices = all_teacher_topk_logps, all_teacher_topk_indices + + real_seq_lens = torch.tensor([x.size(0) for x in teacher_topk_logps], dtype=torch.int32) + + topk = teacher_topk_logps[0].size(-1) + + logp_dtype = teacher_topk_logps[0].dtype + idx_dtype = teacher_topk_indices[0].dtype + teacher_knowledge_shape = list(batch.batch["input_ids"].shape) + [topk] + + global teacher_topk_logps_padded, teacher_topk_indices_padded + if ( + teacher_topk_logps_padded is None + or teacher_topk_logps_padded.dtype != logp_dtype + or teacher_topk_logps_padded.shape != torch.Size(teacher_knowledge_shape) + ): + teacher_topk_logps_padded = torch.zeros(*teacher_knowledge_shape, dtype=logp_dtype) + else: + teacher_topk_logps_padded.zero_() + + if ( + teacher_topk_indices_padded is None + or teacher_topk_indices_padded.dtype != idx_dtype + or teacher_topk_indices_padded.shape != torch.Size(teacher_knowledge_shape) + ): + teacher_topk_indices_padded = torch.zeros(*teacher_knowledge_shape, dtype=idx_dtype) + else: + teacher_topk_indices_padded.zero_() + + batch_size = attention_mask.size(0) + for i in range(batch_size): + teacher_topk_logps_padded[i][attention_mask[i]] = teacher_topk_logps[i] + teacher_topk_indices_padded[i][attention_mask[i]] = teacher_topk_indices[i] + + output_batch = DataProto.from_single_dict( + data={"real_seq_lens": real_seq_lens}, + ) + + output_batch.non_tensor_batch.update( + { + "teacher_topk_logps": teacher_topk_logps_padded.numpy(), + "teacher_topk_indices": teacher_topk_indices_padded.numpy(), + } + ) + + tok2 = time.time() + + output_batch.meta_info["timing"] = {"get_teacher_knowledge": (tok1 - tik1) + (tok2 - tik2)} + + return output_batch + + if is_async: + return SimpleNamespace(get=handle_futures) + else: + return handle_futures() + + +if __name__ == "__main__": + batch = DataProto.load_from_disk("gen_batch_output") + from teacher import TeacherClient + + teacher_client = TeacherClient(server_ip="10.215.192.141", server_port=15555) + output_batch = get_teacher_knowledge(batch, 2, teacher_client) + output_batch_chunks = output_batch.chunk(2) + + for data in output_batch_chunks: + topk = data.meta_info["topk"] + seq_lens = data.batch["seq_lens"] + teacher_topk_logps = data.batch["teacher_topk_logps"].view(-1, topk) + teacher_topk_indices = data.batch["teacher_topk_indices"].view(-1, topk) + + attention_mask = data.batch["attention_mask"] + batch_size, sequence_length = attention_mask.size(0), attention_mask.size(1) + teacher_topk_logps_padded = torch.zeros(batch_size, sequence_length, topk, dtype=teacher_topk_logps.dtype) + teacher_topk_indices_padded = torch.zeros(batch_size, sequence_length, topk, dtype=teacher_topk_indices.dtype) + + teacher_topk_logps_padded[attention_mask] = teacher_topk_logps[: seq_lens.sum()] + teacher_topk_indices_padded[attention_mask] = teacher_topk_indices[: seq_lens.sum()] + + data.batch["teacher_topk_logps"] = teacher_topk_logps_padded + data.batch["teacher_topk_indices"] = teacher_topk_indices_padded + + assert (data.batch["teacher_topk_logps"] == data.batch["teacher_topk_logps_padded"]).all() + assert (data.batch["teacher_topk_indices"] == data.batch["teacher_topk_indices_padded"]).all() diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/test_qwen.sh b/ICL/DAPO/verl-recipe/gkd/megatron/test_qwen.sh new file mode 100644 index 0000000000000000000000000000000000000000..058fef97a690dc33bc322f9d08dc44bfa5e5010e --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/test_qwen.sh @@ -0,0 +1,88 @@ +set -x + +# 0. download the config +# only need to download the `configuration_deepseek.py`, `config.json`, `tokenizer_config.json`, `tokenizer.json` and `generation_config.json` +# remove the `quantization_config` in the `config.json` +# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported + +# huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json + +# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main +# change the HF_MODEL_PATH to your own path +HF_MODEL_PATH=/path/to/Qwen3-0.6B +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +export NVTE_FLASH_ATTN=1 +export NVTE_DEBUG=1 +export NVTE_DEBUG_LEVEL=2 + +# 2. run the script +gsm8k_train_path=/path/to/train.parquet +gsm8k_test_path=/path/to/test.parquet +train_files=$gsm8k_train_path +test_files=$gsm8k_test_path + +# 512 H20(96GB) +NODES=1 +PP=1 +TP=1 +EP=1 +ETP=1 +INFER_TP=1 +# consider TP/ETP, and enable recompute if short of memory + +# full recompute +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +WORKING_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/config/runtime_env.yaml"} +# RAY_ADDRESS='auto' ray job submit --working-dir . -- +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m main_gkd --config-name on_policy_distill_trainer \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.train_batch_size=64 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.trust_remote_code=True \ + +teacher.server_ip=127.0.0.1 \ + +teacher.server_port=15555 \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.megatron.sequence_parallel=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.sequence_parallel=False \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + +actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + +actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=0.99 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \ + actor_rollout_ref.rollout.load_format='auto' \ + +algorithm.use_kl_in_reward=False \ + trainer.logger=['console'] \ + trainer.project_name='verl_examples' \ + trainer.experiment_name='qwen-distill' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=$NODES \ + rollout.n_gpus_per_node=4 \ + rollout.nnodes=$NODES \ + trainer.save_freq=-1 \ + trainer.test_freq=25 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + trainer.val_before_train=False \ + trainer.total_training_steps=10 \ + trainer.total_epochs=1 $@ + # +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=11 \ diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/test_qwen_sglang.sh b/ICL/DAPO/verl-recipe/gkd/megatron/test_qwen_sglang.sh new file mode 100644 index 0000000000000000000000000000000000000000..f55208d4b71933d3963e2e94dbab8a10593ae2a0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/test_qwen_sglang.sh @@ -0,0 +1,88 @@ +set -x + +# 0. download the config +# only need to download the `configuration_deepseek.py`, `config.json`, `tokenizer_config.json`, `tokenizer.json` and `generation_config.json` +# remove the `quantization_config` in the `config.json` +# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported + +# huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json + +# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main +# change the HF_MODEL_PATH to your own path +HF_MODEL_PATH=/path/to/Qwen3-0.6B +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +export NVTE_FLASH_ATTN=1 +export NVTE_DEBUG=1 +export NVTE_DEBUG_LEVEL=2 + +# 2. run the script +gsm8k_train_path=/path/to/train.parquet +gsm8k_test_path=/path/to/test.parquet +train_files=$gsm8k_train_path +test_files=$gsm8k_test_path + +# 512 H20(96GB) +NODES=1 +PP=1 +TP=1 +EP=1 +ETP=1 +INFER_TP=1 +# consider TP/ETP, and enable recompute if short of memory + +# full recompute +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ +# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ + +WORKING_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/config/runtime_env.yaml"} +# RAY_ADDRESS='auto' ray job submit --working-dir . -- +ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m main_gkd --config-name on_policy_distill_trainer \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.train_batch_size=64 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.trust_remote_code=True \ + +teacher.server_ip=127.0.0.1 \ + +teacher.server_port=15555 \ + actor_rollout_ref.model.path=$HF_MODEL_PATH \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.actor.megatron.sequence_parallel=False \ + +actor_rollout_ref.actor.megatron.override_transformer_config.sequence_parallel=False \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + +actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + +actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + +actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=0.99 \ + actor_rollout_ref.rollout.top_k=-1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \ + actor_rollout_ref.rollout.load_format='auto' \ + +algorithm.use_kl_in_reward=False \ + trainer.logger=['console'] \ + trainer.project_name='verl_examples' \ + trainer.experiment_name='qwen-distill' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=$NODES \ + rollout.n_gpus_per_node=4 \ + rollout.nnodes=$NODES \ + trainer.save_freq=-1 \ + trainer.test_freq=25 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ + trainer.val_before_train=False \ + trainer.total_training_steps=10 \ + trainer.total_epochs=1 $@ + # +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=11 \ diff --git a/ICL/DAPO/verl-recipe/gkd/megatron/test_teacher_server.py b/ICL/DAPO/verl-recipe/gkd/megatron/test_teacher_server.py new file mode 100644 index 0000000000000000000000000000000000000000..f53761e739254f6fb6bf441f99ed0674b76a8e7f --- /dev/null +++ b/ICL/DAPO/verl-recipe/gkd/megatron/test_teacher_server.py @@ -0,0 +1,36 @@ +# Copyright 2025 Individual Contributor: furunding +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +from teacher import TeacherClient + + +def main(): + teacher_client = TeacherClient("127.0.0.1", 15555) + tokens = [[random.randint(1, 99999) for _ in range(100)] for _ in range(2)] + tokens[0][40] = 128858 + _, teacher_topk_logps, teacher_topk_indices = teacher_client.submit(tokens).result() + assert all(logps.shape == (100, 256) for logps in teacher_topk_logps) + assert all(logps.dtype == torch.float32 for logps in teacher_topk_logps) + assert all(indices.shape == (100, 256) for indices in teacher_topk_indices) + assert all(indices.dtype == torch.int32 for indices in teacher_topk_indices) + import ipdb + + ipdb.set_trace() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/gvpo/README.md b/ICL/DAPO/verl-recipe/gvpo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f5170d06a49cf84fe6df64f61fde8933b0e5fc65 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/README.md @@ -0,0 +1,154 @@ +
+ +# GVPO: Group Variance Policy Optimization + +[![NeurIPS](https://img.shields.io/badge/NeurIPS-b693f9?style=for-the-badge&logo=neurips&logoColor=white)](https://neurips.cc/virtual/2025/poster/117119) +[![Arxiv](https://img.shields.io/badge/Arxiv-b31b1b?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2504.19599) +[![GitHub](https://img.shields.io/badge/Code-000000?style=for-the-badge&logo=github&logoColor=white)](https://github.com/jszkc/GVPO) +[![机器之心](https://img.shields.io/badge/机器之心-07C160?style=for-the-badge&logo=wechat&logoColor=white)](https://mp.weixin.qq.com/s/mhv0bo0PEB67jbUkZU3sXg) +[![知乎](https://img.shields.io/badge/知乎-0084FF?style=for-the-badge&logo=zhihu&logoColor=white)](https://zhuanlan.zhihu.com/p/1911487456173359632) + +
+ + +## ⭐️ Overview + +**GVPO (Group Variance Policy Optimization)** is a **reinforcement learning algorithm** designed for **post-training large language models (LLMs)**. It provides both a theoretically sound and practically useful advancement for optimizing policies. + +### 🚀 Key Highlights + +- **Unique Optimal Solution:** + GVPO guarantees convergence to a unique solution that maximizes the following objective: + + $$max_{\pi_{\theta}} \mathbb{E}_{x\sim\mathcal{D},y\sim\pi_\theta(y|x)}[R(x,y)]-\beta\mathbb{D}_{KL}[\pi_\theta(y|x)||\pi_{\theta^\prime}(y|x)]$$ + +- **No Importance Sampling:** + Improves stability by eliminating the need for importance weighting. + +- **Off-Policy Flexibility:** + Supports **diverse off-policy sampling distributions**, including **experience replay** and **human demonstrations**. + + +## 🧩 Getting Started + +### 1. Installation + +This project uses **[verl](https://verl.readthedocs.io/en/latest/start/install.html)** (v0.6.0). + +To install verl, please refer to the [official installation guide](https://verl.readthedocs.io/en/latest/start/install.html). + +--- + +### 2. Data Preparation + +We follow the GRPO training setup from the official verl [example script](https://github.com/volcengine/verl/blob/ddd86f527a4af75095e4677b02b5aa272913a088/examples/grpo_trainer/run_qwen2-7b_math.sh), which uses the following datasets: + +- `DigitalLearningGmbH/MATH-lighteval` +- `openai/gsm8k` + +To download and preprocess these datasets, run: + +```bash +python -m examples.data_preprocess.math_dataset.py +python -m examples.data_preprocess.gsm8k.py +``` + +--- + +### 3. Training + +Before launching training, ensure that the **model path** in the script is correctly set. + +To start GVPO training: + +```bash +bash recipe/gvpo/run_qwen2-7b_math_gvpo.sh +``` + + +## 📘 Documentation + +Below we summarize the main components and logic of this GVPO implementation. + + +### `gvpo_core_algos.py` + +This module defines the core **GVPO loss function**, formulated as a Mean Squared Error (MSE): + +$$\mathcal{L}_{\text{GVPO}}(\theta)=\frac{1}{2}\sum_{x, \{y_i\} } \sum_{i=1}^k [(R_\theta(x,y_i)-\overline{R_\theta(x,\{y_i\})})-(R(x,y_i)-\overline{R(x,\{y_i\})})]^2$$ + +This function aggregates statistics across GPUs to compute the group-mean log ratios. + +```python +def compute_policy_loss_gvpo(old_log_prob, log_prob, advantages, response_mask, beta, uid, device_mesh, n): + + rtheta = ((log_prob * response_mask).sum(dim=-1) - (old_log_prob * response_mask).sum(dim=-1)) * beta + r_minus_avg = (advantages * response_mask).sum(dim=-1) / response_mask.sum(dim=-1) + + process_group = device_mesh._flatten().get_group() + group_size = torch.distributed.get_world_size(group=process_group) + data = {"rtheta": rtheta.clone().detach(), "uid": uid} + data = allgather_dict_tensors(data,group_size,process_group) + + unique_uids = torch.unique(data['uid']) + means = {} + for u in unique_uids: + mask = (data['uid'] == u) + mean_val = data['rtheta'][mask].mean() + assert data['rtheta'][mask].shape[0] == n + means[u.item()] = mean_val + pg_loss = 0 + for i in range(len(rtheta)): + pg_loss += 0.5 * ((rtheta[i] - means[uid[i].item()]) - r_minus_avg[i])**2 + pg_loss = pg_loss / (n-1) + return pg_loss +``` + +--- + +### `gvpo_ray_trainer.py` + +We modify the `_balance_batch` method to ensure that **all responses from the same prompt group** are processed within the same training iteration. +This grouping is necessary for correctly computing group-level log-ratio averages. + +```python +def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): +``` +> **Note:** +> Since each iteration processes `world_size * ppo_micro_batch_size_per_gpu` responses, +> the number of rollouts per prompt (`rollout.n`) must evenly divide this total. + +--- + +### `run_qwen2-7b_math_gvpo.sh` + +This script adapts verl’s GRPO training configuration for GVPO. +The following parameters are particularly important: + +```bash +algorithm.adv_estimator=grpo +actor_rollout_ref.actor.use_kl_loss=False +actor_rollout_ref.actor.policy_loss.loss_mode="gvpo" +actor_rollout_ref.actor.gvpo_beta=0.1 +algorithm.use_kl_in_reward=False +algorithm.norm_adv_by_std_in_grpo=False +``` + +When converting a GRPO training script to GVPO, make sure to **update these parameters accordingly**. + +> **Note:** +> This training script is based on verl’s official examples. +> Some settings may differ from those described in the GVPO paper—please check carefully when reproducing results. + +## ✍️ Citation + +If you use GVPO in your research, please cite this paper once the corresponding paper is available: + +``` +@article{zhang2025gvpo, + title={GVPO: Group variance policy optimization for large language model post-training}, + author={Zhang, Kaichen and Hong, Yuzhong and Bao, Junwei and Jiang, Hongfei and Song, Yang and Hong, Dingqian and Xiong, Hui}, + journal={arXiv preprint arXiv:2504.19599}, + year={2025} +} +``` \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/gvpo/gvpo_actor_config.py b/ICL/DAPO/verl-recipe/gvpo/gvpo_actor_config.py new file mode 100644 index 0000000000000000000000000000000000000000..83564090518c0ec8338f31e44fc7e58702fb95cf --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/gvpo_actor_config.py @@ -0,0 +1,248 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from omegaconf import MISSING + +from verl.base_config import BaseConfig +from verl.trainer.config import CheckpointConfig +from verl.utils.profiler.config import ProfilerConfig +from verl.workers.config.engine import FSDPEngineConfig, McoreEngineConfig +from verl.workers.config.model import HFModelConfig +from verl.workers.config.optimizer import OptimizerConfig + +__all__ = ["PolicyLossConfig", "ActorConfig", "FSDPActorConfig", "McoreActorConfig"] + + +@dataclass +class PolicyLossConfig(BaseConfig): + """Configuration for policy loss computation. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + loss_mode (str): Loss function mode. Options: 'vanilla', 'clip-cov', 'kl-cov', 'gpg'. + clip_cov_ratio (float): Ratio of tokens to be clipped for clip-cov loss. + clip_cov_lb (float): Lower bound for clip-cov loss. + clip_cov_ub (float): Upper bound for clip-cov loss. + kl_cov_ratio (float): Ratio of tokens to be applied KL penalty for kl-cov loss. + ppo_kl_coef (float): KL divergence penalty coefficient. + """ + + loss_mode: str = "vanilla" + clip_cov_ratio: float = 0.0002 + clip_cov_lb: float = 1.0 + clip_cov_ub: float = 5.0 + kl_cov_ratio: float = 0.0002 + ppo_kl_coef: float = 0.1 + + +@dataclass +class ActorConfig(BaseConfig): + """Configuration for actor model training. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + strategy (str): Training strategy. Must be specified. + ppo_mini_batch_size (int): Mini-batch size for PPO training. + ppo_micro_batch_size (Optional[int]): Micro-batch size for PPO training. + If None, uses ppo_micro_batch_size_per_gpu. + ppo_micro_batch_size_per_gpu (Optional[int]): Micro-batch size per GPU for PPO training. + use_dynamic_bsz (bool): Whether to use dynamic batch sizing. + ppo_max_token_len_per_gpu (int): Maximum token length per GPU for PPO training. + clip_ratio (float): PPO clipping ratio for policy loss. + clip_ratio_low (float): Lower bound for PPO clipping ratio. + clip_ratio_high (float): Upper bound for PPO clipping ratio. + policy_loss (PolicyLossConfig): Configuration for policy loss computation. + clip_ratio_c (float): Clipping ratio for critic loss. + loss_agg_mode (str): Loss aggregation mode. Options: 'token-mean', 'sample-mean'. + entropy_coeff (float): Entropy coefficient for regularization. + use_kl_loss (bool): Whether to use KL divergence loss. + use_torch_compile (bool): Whether to use torch.compile for optimization. + kl_loss_coef (float): KL divergence loss coefficient. + kl_loss_type (str): Type of KL loss to use. + ppo_epochs (int): Number of PPO epochs per training step. + shuffle (bool): Whether to shuffle data during training. + checkpoint (CheckpointConfig): Configuration for checkpointing. + optim (OptimizerConfig): Configuration for optimizer. + use_fused_kernels (bool): Whether to use custom fused kernels (e.g., FlashAttention, fused MLP). + """ + + _mutable_fields = BaseConfig._mutable_fields | { + "ppo_mini_batch_size", + "ppo_micro_batch_size", + "ppo_micro_batch_size_per_gpu", + "ppo_infer_micro_batch_size_per_gpu", + } + + strategy: str = MISSING + ppo_mini_batch_size: int = 256 + ppo_micro_batch_size: Optional[int] = None # deprecate + ppo_micro_batch_size_per_gpu: Optional[int] = None + ppo_infer_micro_batch_size_per_gpu: Optional[int] = None + use_dynamic_bsz: bool = False + ppo_max_token_len_per_gpu: int = 16384 + ppo_infer_max_token_len_per_gpu: int = 16384 + clip_ratio: float = 0.2 + clip_ratio_low: float = 0.2 + clip_ratio_high: float = 0.2 + freeze_vision_tower: bool = False + policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig) + clip_ratio_c: float = 3.0 + loss_agg_mode: str = "token-mean" + entropy_coeff: float = 0 + use_kl_loss: bool = False + use_torch_compile: bool = True + kl_loss_coef: float = 0.001 + kl_loss_type: str = "low_var_kl" + ppo_epochs: int = 1 + shuffle: bool = False + checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig) + optim: OptimizerConfig = field(default_factory=OptimizerConfig) + use_fused_kernels: bool = False + profiler: ProfilerConfig = field(default_factory=ProfilerConfig) + engine: BaseConfig = field(default_factory=BaseConfig) + data_loader_seed = 1 + rollout_n: int = 1 # must be override by sampling config + model_config: HFModelConfig = field(default_factory=BaseConfig) + + gvpo_beta: float = 0.1 # GVPO specific beta + + def __post_init__(self): + """Validate actor configuration parameters.""" + assert self.strategy != MISSING + assert self.rollout_n != MISSING + if not self.use_dynamic_bsz: + if self.ppo_micro_batch_size is not None and self.ppo_micro_batch_size_per_gpu is not None: + raise ValueError( + "[actor] You have set both 'actor.ppo_micro_batch_size' AND 'actor.ppo_micro_batch_size_per_gpu'. " + "Please remove 'actor.ppo_micro_batch_size' because only '*_ppo_micro_batch_size_per_gpu' is " + "supported (the former is deprecated)." + ) + else: + assert not (self.ppo_micro_batch_size is None and self.ppo_micro_batch_size_per_gpu is None), ( + "[actor] Please set at least one of 'actor.ppo_micro_batch_size' or " + "'actor.ppo_micro_batch_size_per_gpu' if use_dynamic_bsz is not enabled." + ) + + valid_loss_agg_modes = [ + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", + "seq-mean-token-sum-norm", + ] + if self.loss_agg_mode not in valid_loss_agg_modes: + raise ValueError(f"Invalid loss_agg_mode: {self.loss_agg_mode}") + + def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None): + """Validate actor configuration with runtime parameters.""" + if not self.use_dynamic_bsz: + if train_batch_size < self.ppo_mini_batch_size: + raise ValueError( + f"train_batch_size ({train_batch_size}) must be >= " + f"actor.ppo_mini_batch_size ({self.ppo_mini_batch_size})" + ) + + sp_size = getattr(self, "ulysses_sequence_parallel_size", 1) + if self.ppo_micro_batch_size is not None: + if self.ppo_mini_batch_size % self.ppo_micro_batch_size != 0: + raise ValueError( + f"ppo_mini_batch_size ({self.ppo_mini_batch_size}) must be divisible by " + f"ppo_micro_batch_size ({self.ppo_micro_batch_size})" + ) + if self.ppo_micro_batch_size * sp_size < n_gpus: + raise ValueError( + f"ppo_micro_batch_size ({self.ppo_micro_batch_size}) * " + f"ulysses_sequence_parallel_size ({sp_size}) must be >= n_gpus ({n_gpus})" + ) + + @staticmethod + def _check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + """Validate mutually exclusive micro batch size configuration options.""" + param = "ppo_micro_batch_size" + param_per_gpu = f"{param}_per_gpu" + + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError( + f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " + f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." + ) + + +@dataclass +class McoreActorConfig(ActorConfig): + """Configuration for Megatron actor models. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + strategy (str): Training strategy set to 'megatron' for Megatron parallelism. + data_loader_seed (Optional[int]): Seed for data loader. If None, uses global seed. + load_weight (bool): Whether to load model weights from checkpoint. + megatron (dict[str, Any]): Configuration for Megatron parallelism settings. + profile (dict[str, Any]): Configuration for profiling settings. + """ + + strategy: str = "megatron" + data_loader_seed: Optional[int] = None + load_weight: bool = True + megatron: McoreEngineConfig = field(default_factory=McoreEngineConfig) + profile: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FSDPActorConfig(ActorConfig): + """Configuration for FSDP actor models. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + strategy (str): Training strategy set to 'fsdp' for Fully Sharded Data Parallel. + grad_clip (float): Gradient clipping threshold. + ulysses_sequence_parallel_size (int): Ulysses sequence parallel size for long sequences. + entropy_from_logits_with_chunking (bool): Whether to compute entropy from logits + with chunking for memory efficiency. + entropy_checkpointing (bool): Whether to use gradient checkpointing for entropy computation. + fsdp_config (dict[str, Any]): Configuration for FSDP settings. + use_remove_padding (bool): Whether to remove padding tokens in inputs during training + """ + + strategy: str = "fsdp" + grad_clip: float = 1.0 + ulysses_sequence_parallel_size: int = 1 + entropy_from_logits_with_chunking: bool = False + entropy_checkpointing: bool = False + fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig) + use_remove_padding: bool = False + profiler: ProfilerConfig = field(default_factory=ProfilerConfig) + + def __post_init__(self): + """Validate FSDP actor configuration parameters.""" + super().__post_init__() + + def validate(self, n_gpus: int, train_batch_size: int, model_config: dict = None): + """Validate FSDP actor configuration with runtime parameters.""" + super().validate(n_gpus, train_batch_size, model_config) + + if self.strategy in {"fsdp", "fsdp2"} and self.ulysses_sequence_parallel_size > 1: + if model_config and not model_config.get("use_remove_padding", False): + raise ValueError( + "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + ) diff --git a/ICL/DAPO/verl-recipe/gvpo/gvpo_core_algos.py b/ICL/DAPO/verl-recipe/gvpo/gvpo_core_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..5965661291881b6a495d6b282b705efba45fcf78 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/gvpo_core_algos.py @@ -0,0 +1,1556 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Core functions to implement PPO algorithms. +The function implemented in this file should be used by trainer with different distributed strategies to +implement PPO-like algorithms. +""" + +__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] + +from collections import defaultdict +from enum import Enum +from typing import Any, Callable, Optional + +import numpy as np +import torch +from omegaconf import DictConfig + +import verl.utils.torch_functional as verl_F +from verl.trainer.config import AlgoConfig +from verl.utils import as_torch_index, group_mean_std +from verl.utils.import_utils import deprecated +from verl.utils.torch_functional import allgather_dict_tensors +from verl.workers.config import ActorConfig + + +def compute_policy_loss_gvpo(old_log_prob, log_prob, advantages, response_mask, beta, uid, device_mesh, n): + rtheta = ((log_prob * response_mask).sum(dim=-1) - (old_log_prob * response_mask).sum(dim=-1)) * beta + r_minus_avg = (advantages * response_mask).sum(dim=-1) / response_mask.sum(dim=-1) + + process_group = device_mesh._flatten().get_group() + group_size = torch.distributed.get_world_size(group=process_group) + data = {"rtheta": rtheta.clone().detach(), "uid": uid} + data = allgather_dict_tensors(data, group_size, process_group) + + unique_uids = torch.unique(data["uid"]) + means = {} + for u in unique_uids: + mask = data["uid"] == u + mean_val = data["rtheta"][mask].mean() + assert data["rtheta"][mask].shape[0] == n + means[u.item()] = mean_val + pg_loss = 0 + for i in range(len(rtheta)): + pg_loss += 0.5 * ((rtheta[i] - means[uid[i].item()]) - r_minus_avg[i]) ** 2 + pg_loss = pg_loss / (n - 1) + return pg_loss + + +PolicyLossFn = Callable[ + [ + torch.Tensor, # old_log_prob + torch.Tensor, # log_prob + torch.Tensor, # advantages + torch.Tensor, # response_mask + str, # loss_agg_mode + Optional[DictConfig | AlgoConfig], # config + torch.Tensor | None, # rollout_log_probs + ], + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], +] + +POLICY_LOSS_REGISTRY: dict[str, PolicyLossFn] = {} + + +def register_policy_loss(name: str) -> Callable[[PolicyLossFn], PolicyLossFn]: + """Register a policy loss function with the given name. + + Args: + name (str): The name to register the policy loss function under. + + Returns: + function: Decorator function that registers the policy loss function. + """ + + def decorator(func: PolicyLossFn) -> PolicyLossFn: + POLICY_LOSS_REGISTRY[name] = func + return func + + return decorator + + +def get_policy_loss_fn(name): + """Get the policy loss with a given name. + + Args: + name: `(str)` + The name of the policy loss. + + Returns: + `(callable)`: The policy loss function. + """ + loss_name = name + if loss_name not in POLICY_LOSS_REGISTRY: + raise ValueError( + f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}" + ) + return POLICY_LOSS_REGISTRY[loss_name] + + +class AdvantageEstimator(str, Enum): + """Using an enumeration class to avoid spelling errors in adv_estimator. + + Note(haibin.lin): this enum class is immutable after creation. Extending this + enum for new estimators may not be necessary since users can always just call + `verl.trainer.ppo.core_algos.register` with string name for a custom advantage + estimator instead. + """ + + GAE = "gae" + GRPO = "grpo" + REINFORCE_PLUS_PLUS = "reinforce_plus_plus" + REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" + REMAX = "remax" + RLOO = "rloo" + OPO = "opo" + GRPO_PASSK = "grpo_passk" + GPG = "gpg" + RLOO_VECTORIZED = "rloo_vectorized" + GRPO_VECTORIZED = "grpo_vectorized" + + +ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {} + + +def register_adv_est(name_or_enum: str | AdvantageEstimator) -> Any: + """Decorator to register a advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + """ + + def decorator(fn): + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn: + raise ValueError( + f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}" + ) + ADV_ESTIMATOR_REGISTRY[name] = fn + return fn + + return decorator + + +def get_adv_estimator_fn(name_or_enum): + """Get the advantage estimator function with a given name. + + Args: + name_or_enum: `(str)` or `(AdvantageEstimator)` + The name or enum of the advantage estimator. + + Returns: + `(callable)`: The advantage estimator function. + """ + name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum + if name not in ADV_ESTIMATOR_REGISTRY: + raise ValueError(f"Unknown advantage estimator simply: {name}") + return ADV_ESTIMATOR_REGISTRY[name] + + +class AdaptiveKLController: + """ + Adaptive KL controller described in the paper: + https://arxiv.org/pdf/1909.08593.pdf + """ + + def __init__(self, init_kl_coef, target_kl, horizon): + self.value = init_kl_coef + self.target = target_kl + self.horizon = horizon + + def update(self, current_kl, n_steps): + """Update the KL coefficient based on current KL divergence. + + Args: + current_kl (float): Current KL divergence value. + n_steps (int): Number of steps taken. + """ + target = self.target + proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.horizon + self.value *= mult + + +class FixedKLController: + """Fixed KL controller.""" + + def __init__(self, kl_coef): + self.value = kl_coef + + def update(self, current_kl, n_steps): + """Update method for fixed KL controller (no-op). + + Args: + current_kl (float): Current KL divergence value (unused). + n_steps (int): Number of steps taken (unused). + """ + pass + + +def get_kl_controller(kl_ctrl): + """Factory function to create appropriate KL controller based on configuration. + + Args: + kl_ctrl: Configuration object containing KL controller settings. + + Returns: + KL controller instance (FixedKLController or AdaptiveKLController). + + Raises: + NotImplementedError: If controller type is not supported. + AssertionError: If adaptive controller horizon is not positive. + """ + if kl_ctrl.type == "fixed": + return FixedKLController(kl_coef=kl_ctrl.kl_coef) + elif kl_ctrl.type == "adaptive": + assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" + return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) + else: + raise NotImplementedError + + +@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae") +def compute_gae_advantage_return( + token_level_rewards: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + gamma: torch.Tensor, + lam: torch.Tensor, +): + """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, response_length) + values: `(torch.Tensor)` + shape is (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. + gamma is `(float)` + discounted factor used in RL + lam: `(float)` + lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + + """ + with torch.no_grad(): + nextvalues = 0 + lastgaelam = 0 + advantages_reversed = [] + gen_len = token_level_rewards.shape[-1] + + for t in reversed(range(gen_len)): + delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] + lastgaelam_ = delta + gamma * lam * lastgaelam + + # skip values and TD-error on observation tokens + nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues + lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam + + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + + returns = advantages + values + advantages = verl_F.masked_whiten(advantages, response_mask) + return advantages, returns + + +# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo") +def compute_grpo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape is (bs, response_length) + response_mask: `(torch.Tensor)` + shape is (bs, response_length) + index: `(np.ndarray)` + index array for grouping + epsilon: `(float)` + small value to avoid division by zero + norm_adv_by_std_in_grpo: `(bool)` + whether to scale the GRPO advantage + config: `(Optional[AlgoConfig])` + algorithm configuration object + + Note: + If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. + If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). + + Returns: + advantages: `(torch.Tensor)` + shape is (bs, response_length) + Returns: `(torch.Tensor)` + shape is (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + if norm_adv_by_std_in_grpo: + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + else: + scores[i] = scores[i] - id2mean[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED) +def compute_grpo_vectorized_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Vectorized GRPO(outcome-only): + For each group g: + a_i = \\frac{r_i - \\mu_g}{\\sigma_g} (or without dividing by \\sigma_g), + then broadcast the scalar across the token dimension (multiplied by response_mask).。 + """ + with torch.no_grad(): + scores = token_level_rewards.sum(dim=-1) + g = as_torch_index(index, device=scores.device) + mean_g, std_g, _ = group_mean_std(scores, g, eps=epsilon) + if norm_adv_by_std_in_grpo: + scalars = (scores - mean_g[g]) / (std_g[g] + epsilon) + else: + scalars = scores - mean_g[g] + advantages = scalars.unsqueeze(-1) * response_mask + return advantages, advantages + + +@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") +def compute_grpo_passk_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + norm_adv_by_std_in_grpo: bool = True, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for Pass@k using a GRPO-style outcome reward formulation. + Only the best response per group gets a non-zero advantage: r_max - r_second_max. + + Implemented as described in https://arxiv.org/abs/2503.19595. + + Args: + token_level_rewards: (bs, response_length) + response_mask: (bs, response_length) + index: (bs,) → group ID per sample + epsilon: float for numerical stability + config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo" + + Returns: + advantages: (bs, response_length) + returns: (bs, response_length) + """ + assert config is not None + # if True, normalize advantage by std within group + norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True) + scores = token_level_rewards.sum(dim=-1) # (bs,) + advantages = torch.zeros_like(scores) + + id2scores = defaultdict(list) + id2indices = defaultdict(list) + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + idx = index[i] + id2scores[idx].append(scores[i]) + id2indices[idx].append(i) + + for idx in id2scores: + rewards = torch.stack(id2scores[idx]) # (k,) + if rewards.numel() < 2: + raise ValueError( + f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}." + ) + topk, topk_idx = torch.topk(rewards, 2) + r_max, r_second_max = topk[0], topk[1] + i_max = id2indices[idx][topk_idx[0].item()] + advantage = r_max - r_second_max + if norm_adv_by_std_in_grpo: + std = torch.std(rewards) + advantage = advantage / (std + epsilon) + advantages[i_max] = advantage + + advantages = advantages.unsqueeze(-1) * response_mask + return advantages, advantages + + +@register_adv_est( + AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE +) # or simply: @register_adv_est("reinforce_plus_plus_baseline") +def compute_reinforce_plus_plus_baseline_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: torch.Tensor, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.stack(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2mean[index[i]] + + scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask + scores = verl_F.masked_whiten(scores, response_mask) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.RLOO) # or simply: @register_adv_est("rloo") +def compute_rloo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.stack(id2score[idx])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + response_num = len(id2score[index[i]]) + if response_num > 1: + scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( + response_num - 1 + ) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.OPO) # or simply: @register_adv_est("opo") +def compute_opo_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = response_mask.sum(dim=-1) + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2len = defaultdict(list) + id2bsl = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + id2len[index[i]].append(response_length[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2bsl[idx] = torch.tensor(0.0) + elif len(id2score[idx]) > 1: + score_tensor = torch.stack(id2score[idx]) + len_tensor = torch.stack(id2len[idx]) + id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum() + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2bsl[index[i]] + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") +def compute_reinforce_plus_plus_outcome_advantage( + token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for REINFORCE++. + This implementation is based on the paper: https://arxiv.org/abs/2501.03262 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + assert config is not None + gamma = config.gamma + with torch.no_grad(): + returns = torch.zeros_like(token_level_rewards) + running_return = 0 + + for t in reversed(range(token_level_rewards.shape[1])): + running_return = token_level_rewards[:, t] + gamma * running_return + returns[:, t] = running_return + # Reset after EOS + running_return = running_return * response_mask[:, t] + + advantages = verl_F.masked_whiten(returns, response_mask) + advantages = advantages * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.REMAX) # or simply: @register_adv_est("remax") +def compute_remax_outcome_advantage( + token_level_rewards: torch.Tensor, + reward_baselines: torch.Tensor, + response_mask: torch.Tensor, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for ReMax, operating only on Outcome reward + This implementation is based on the paper: https://arxiv.org/abs/2310.10505 + (with only one scalar reward for each response). + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + reward_baselines: `(torch.Tensor)` + shape: (bs,) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + + with torch.no_grad(): + returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) + advantages = returns - reward_baselines.unsqueeze(-1) * response_mask + + return advantages, returns + + +@register_adv_est(AdvantageEstimator.GPG) # or simply: @register_adv_est("gpg") +def compute_gpg_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + f_norm: float = 1.0, + alpha: float = 1.0, + config=None, + **kwargs, +): + """ + Compute advantage for GPG, operating only on Outcome reward + (with only one scalar reward for each response). + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + index: `(np.ndarray)` + shape: (bs,) + epsilon: (float) + f_norm: (float) + alpha: (float) + config: (dict) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + m = torch.count_nonzero(scores) + alpha = bsz / m.clamp(min=1) + + for i in range(bsz): + id2score[index[i]].append(scores[i]) + + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + scores_tensor = torch.stack(id2score[idx]) + id2mean[idx] = torch.mean(scores_tensor) + id2std[idx] = torch.std(scores_tensor) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm) + scores = scores.unsqueeze(-1) * response_mask + + return scores, scores + + +@register_adv_est(AdvantageEstimator.RLOO_VECTORIZED) # or simply: @register_adv_est("rloo_vectorized") +def compute_rloo_vectorized_outcome_advantage( + token_level_rewards: torch.Tensor, + response_mask: torch.Tensor, + index: np.ndarray, + epsilon: float = 1e-6, + config: Optional[AlgoConfig] = None, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 + + Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + config: (AlgoConfig) algorithm config + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + scores = token_level_rewards.sum(dim=-1) + + with torch.no_grad(): + inv = torch.from_numpy(np.unique(index, return_inverse=True)[1]).to(scores.device) + + c = torch.bincount(inv)[inv].to(scores.dtype) + adv = ((c * scores - torch.bincount(inv, weights=scores)[inv]) / (c - 1).clamp_min(1)) * (c > 1) + + adv = adv.unsqueeze(-1) * response_mask + + return adv, adv + + +def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): + """Compute token-level rewards with KL penalty. + + Args: + token_level_scores (torch.Tensor): Token-level reward scores. + old_log_prob (torch.Tensor): Log probabilities from current policy. + ref_log_prob (torch.Tensor): Log probabilities from reference policy. + kl_ratio (float): KL penalty coefficient. + + Returns: + torch.Tensor: Token-level rewards with KL penalty applied. + """ + kl = old_log_prob - ref_log_prob + return token_level_scores - kl * kl_ratio + + +def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): + """ + Aggregate the loss matrix into a scalar. + + Args: + loss_mat: `(torch.Tensor)`: + shape: (bs, response_length) + loss_mask: `(torch.Tensor)`: + shape: (bs, response_length) + loss_agg_mode: (str) choices: + method to aggregate the loss matrix into a scalar. + Returns: + loss: `a scalar torch.Tensor` + aggregated loss + """ + if loss_agg_mode == "token-mean": + loss = verl_F.masked_mean(loss_mat, loss_mask) + elif loss_agg_mode == "seq-mean-token-sum": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-mean": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean + loss = torch.mean(seq_losses) # seq-mean + elif loss_agg_mode == "seq-mean-token-sum-norm": + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) + loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor + # (loss_mask.shape[-1]) should ideally be constant + # throughout training to well-replicate the DrGRPO paper. + # TODO: Perhaps add user-defined normalizer argument to + # agg_loss to ensure divisor stays constant throughout. + else: + raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + + return loss + + +@deprecated("verl.trainer.ppo.core_algos.compute_policy_loss_vanilla") +def compute_policy_loss( + old_log_prob, + log_prob, + advantages, + response_mask, + cliprange=None, + cliprange_low=None, + cliprange_high=None, + clip_ratio_c=3.0, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped policy objective and related metrics for PPO. + + Adapted from + https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + clip_ratio_c (float, optional): + Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + Defaults to 3.0. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + """ + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + +@register_policy_loss("vanilla") # type: ignore[arg-type] +def compute_policy_loss_vanilla( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | AlgoConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for PPO. + + Adapted from + https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + config: `(verl.trainer.config.ActorConfig)`: + config for the actor. + rollout_log_probs: `(torch.Tensor)`: + log probabilities of actions under the rollout policy, shape (batch_size, response_length). + """ + + assert config is not None + assert not isinstance(config, AlgoConfig) + clip_ratio = config.clip_ratio # Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio + clip_ratio_c = config.get( # Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. + "clip_ratio_c", 3.0 + ) + + cliprange = clip_ratio + cliprange_low = clip_ratio_low + cliprange_high = clip_ratio_high + + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + pg_losses2 = -advantages * torch.clamp( + ratio, 1 - cliprange_low, 1 + cliprange_high + ) # - clip(ratio, 1-cliprange, 1+cliprange) * A + clip_pg_losses1 = torch.maximum( + pg_losses1, pg_losses2 + ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) + pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) + + pg_losses3 = -advantages * clip_ratio_c + clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) + pg_clipfrac_lower = verl_F.masked_mean( + torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask + ) + + pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) + + # Apply rollout importance sampling weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + +@register_policy_loss("gspo") +def compute_policy_loss_gspo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "seq-mean-token-mean", + config: Optional[DictConfig | ActorConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for GSPO. + + See https://arxiv.org/pdf/2507.18071 for more details. + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. For GSPO, it is recommended to use "seq-mean-token-mean". + """ + + assert config is not None + assert isinstance(config, ActorConfig) + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else config.clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else config.clip_ratio + + negative_approx_kl = log_prob - old_log_prob + + # compute sequence-level importance ratio: + # si(θ) = (π_θ(yi|x)/π_θold(yi|x))^(1/|yi|) = + # exp [(1/|y_i|) * Σ_t log(π_θ(y_i,t|x,y_i, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Adapted from + https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 + Args: + log_prob: `(torch.Tensor)` + shape: (bs, response_length) + advantages: `(torch.Tensor)` + shape: (bs, response_length) + response_mask: `(torch.Tensor)` + shape: (bs, response_length) + return: + pg_loss: `a scalar torch.Tensor` + policy gradient loss computed via GPG + """ + pg_losses = -log_prob * advantages + + # Apply rollout importance sampling weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) + + +@register_policy_loss("clip_cov") +def compute_policy_loss_clip_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | AlgoConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + cliprange (float, optional): + Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. + Defaults to None (must be provided). + cliprange_low (float, optional): + Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. + cliprange_high (float, optional): + Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + clip_cvo_ratio (float, optional): + Ratio for clipping the covariance. Defaults to 0.0002. + clip_cov_lb (float, optional): + Lower bound for clipping covariance. Defaults to 1.0. + clip_cov_ub (float, optional): + Upper bound for clipping covariance. Defaults to 5.0. + """ + assert config is not None + assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet" + assert config.policy_loss is not None + + clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 + cliprange = config.clip_ratio + cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange + cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange + clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 + clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 + + assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + pg_losses1 = -advantages * ratio + + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + + corr = torch.ones_like(advantages) + pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) + clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0) + + cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * ( + log_prob - verl_F.masked_mean(log_prob.detach(), response_mask) + ) + cov_all[response_mask == 0] = -torch.inf + cov_all[clip_by_origin] = -torch.inf + + clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1) + top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0) + top_k_idx = torch.nonzero(top_k_idx) + + if len(top_k_idx) > 0: + perm = torch.randperm(len(top_k_idx)) + top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]] + else: + top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long) + + corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0 + + pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) + + pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr + + # Apply rollout importance sampling weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0) + + +@register_policy_loss("kl_cov") +def compute_policy_loss_kl_cov( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | AlgoConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for Clip-Cov. + + Adapted from + https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + kl_cov_ratio (float, optional): + Ratio for selecting the top-k covariance values. Defaults to 0.0002. + ppo_kl_coef (float, optional): + Coefficient for the KL penalty term in the loss. Defaults to 1. + """ + assert config is not None + assert not isinstance(config, AlgoConfig), "passing AlgoConfig not supported yet" + assert config.policy_loss is not None + + kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002 + ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0 + + assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0." + + negative_approx_kl = log_prob - old_log_prob + abs_kl = negative_approx_kl.abs() + ratio = torch.exp(negative_approx_kl) + ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask) + pg_losses1 = -advantages * ratio + pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl + pg_losses = pg_losses1 + + all_valid = response_mask > 0 + all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0] + all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu() + all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu() + + k = min(kl_cov_ratio, len(all_valid_adv)) + + if k != 0: + cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean()) + k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio)) + large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices + + if len(large_cov_idxs) != 0: + large_cov_idxs = all_valid_idx[large_cov_idxs] + pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[ + large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1] + ] + + # Apply rollout importance sampling weights if provided + if rollout_is_weights is not None: + pg_losses = pg_losses * rollout_is_weights + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) + + +@register_policy_loss("geo_mean") +def compute_policy_loss_geo_mean( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | AlgoConfig] = None, + rollout_is_weights: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the clipped policy objective and related metrics for GMPO. + + Adapted from paper https://arxiv.org/abs/2507.20673 + https://github.com/callsys/GMPO/blob/main/train_zero_math_gmpo.py + + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + not used + """ + + assert config is not None + assert not isinstance(config, AlgoConfig) + clip_ratio = config.clip_ratio # Clipping parameter. See https://arxiv.org/abs/1707.06347. + clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio + clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio + + cliprange = clip_ratio + cliprange_low = clip_ratio_low + cliprange_high = clip_ratio_high + if cliprange_low is None: + cliprange_low = cliprange + if cliprange_high is None: + cliprange_high = cliprange + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability (uncomment it if you like) + # negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # Clipping at token-level & Clipping wider + sgn_advantage = torch.sign(advantages) + negative_approx_kl_clamp = torch.clamp(negative_approx_kl, -cliprange_low, cliprange_high) + negative_approx_kl_min = torch.min(sgn_advantage * negative_approx_kl, sgn_advantage * negative_approx_kl_clamp) + negative_approx_kl_min = sgn_advantage * negative_approx_kl_min + + # Geometric-Mean Policy Optimization + response_mask_sum = response_mask.sum(dim=-1) + ratio = torch.exp((negative_approx_kl_min * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)) + # we only support sequence level advantage for now, + # otherwise, below would be not consistent with the paper + advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8) + pg_losses = -advantage * ratio + + # Apply rollout importance sampling weights if provided + # For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level + if rollout_is_weights is not None: + # Aggregate token-level weights to sequence level using geometric mean for consistency + # Note: rollout_is_weights is always 2D regardless of rollout_is_level + seq_is_weights = torch.exp( + (torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8) + ) + pg_losses = pg_losses * seq_is_weights + + pg_loss = torch.mean(pg_losses) + + # higher: ratio is too large that need clamp to clip_high (when adv > 0) + clipped = torch.ne(negative_approx_kl, negative_approx_kl_clamp) + pg_clipfrac = verl_F.masked_mean((clipped * (advantages > 0)).float(), response_mask) + pg_clipfrac_lower = verl_F.masked_mean((clipped * (advantages < 0)).float(), response_mask) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + +def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): + """Compute categorical entropy loss (For backward compatibility) + + Args: + logits (torch.Tensor): shape is (bs, response_length, vocab_size) + response_mask (torch.Tensor): shape is (bs, response_length) + + Returns: + entropy: a scalar torch.Tensor + + """ + # compute entropy + token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) + entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + return entropy_loss + + +def compute_value_loss( + vpreds: torch.Tensor, + returns: torch.Tensor, + values: torch.Tensor, + response_mask: torch.Tensor, + cliprange_value: float, + loss_agg_mode: str = "token-mean", +): + """ + Compute the clipped value-function loss for PPO. + + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 + + Args: + vpreds (torch.FloatTensor): + Predicted values from the value head, shape (batch_size, response_length). + values (torch.FloatTensor): + Old (baseline) values from the value head, shape (batch_size, response_length). + returns (torch.FloatTensor): + Ground-truth returns, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the value loss calculation. + cliprange_value (float): + Clip range for value prediction updates. + loss_agg_mode (str, optional): + Aggregation mode for `agg_loss`. Defaults to "token-mean". + + Returns: + vf_loss (torch.FloatTensor): + A scalar tensor containing the aggregated value-function loss. + vf_clipfrac (float): + Fraction of elements where the clipped loss was used. + """ + vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) + vf_losses1 = (vpreds - returns) ** 2 + vf_losses2 = (vpredclipped - returns) ** 2 + clipped_vf_losses = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) + return vf_loss, vf_clipfrac + + +def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: + """Compute KL divergence given logprob and ref_logprob. Optionally using straight through to bind k2 on other + kl penalty compute method for unbiased KL gradient estimation. + See more description in http://joschu.net/blog/kl-approx.html + + Args: + logprob: + ref_logprob: + + Returns: + kl_estimate + """ + forward_score = kl_penalty_forward(logprob, ref_logprob, kl_penalty) + if not kl_penalty.endswith("+") or kl_penalty in ("mse", "k2"): + return forward_score + + """ + The expectation of k1 and k3 estimator is the expectaed value of KL, but the expected gradient of k1 and k3 + estimator is not the expectaed gradient of KL. On the other hand k2 estimator gives right gradient estimator, + so we use a straight through trick here if the kl_penalty method ends with '+', .e.g., k3+. + """ + backward_score = 0.5 * (logprob - ref_logprob).square() + + return backward_score - backward_score.detach() + forward_score.detach() + + +def kl_penalty_forward(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: + """Compute KL divergence given logprob and ref_logprob. + Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 + See more description in http://joschu.net/blog/kl-approx.html + + Args: + logprob: + ref_logprob: + + Returns: + kl_estimate + """ + if kl_penalty in ("kl", "k1"): + return logprob - ref_logprob + + if kl_penalty == "abs": + return (logprob - ref_logprob).abs() + + if kl_penalty in ("mse", "k2"): + return 0.5 * (logprob - ref_logprob).square() + + # J. Schulman. Approximating kl divergence, 2020. + # # URL http://joschu.net/blog/kl-approx.html. + if kl_penalty in ("low_var_kl", "k3"): + kl = ref_logprob - logprob + # For numerical stability + kl = torch.clamp(kl, min=-20, max=20) + ratio = torch.exp(kl) + kld = (ratio - kl - 1).contiguous() + return torch.clamp(kld, min=-10, max=10) + + if kl_penalty == "full": + # so, here logprob and ref_logprob should contain the logits for every token in vocabulary + raise NotImplementedError + + raise NotImplementedError + + +def compute_pf_ppo_reweight_data( + data, + reweight_method: str = "pow", + weight_pow: float = 2.0, +): + """Reweight the data based on the token_level_scores. + + Args: + data: DataProto object, containing batch, non_tensor_batch and meta_info + reweight_method: str, choices: "pow", "max_min", "max_random" + weight_pow: float, the power of the weight + + Returns: + + """ + + @torch.no_grad() + def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor: + """Compute importance weights for resampling based on scores. + + Args: + scores (torch.Tensor): Tensor of scores to compute weights from. + reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random'). + weight_pow (float): Power exponent for 'pow' method. + + Returns: + torch.Tensor: Computed importance weights. + + Raises: + ValueError: If reweight_method is not supported. + """ + if reweight_method == "pow": + weights = torch.pow(torch.abs(scores), weight_pow) + elif reweight_method == "max_min": + max_score = torch.max(scores) + min_score = torch.min(scores) + weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0) + elif reweight_method == "max_random": + max_score = torch.max(scores) + weights = torch.where(scores == max_score, 0.4, 0.1) + else: + raise ValueError(f"Unsupported reweight_method: {reweight_method}") + return weights + + scores = data.batch["token_level_scores"].sum(dim=-1) + weights = compute_weights(scores, reweight_method, weight_pow) + weights = torch.clamp(weights + 1e-8, min=1e-8) + + batch_size = scores.shape[0] + sample_indices = torch.multinomial(weights, batch_size, replacement=True) + + resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()} + + sample_indices_np = sample_indices.numpy() + resampled_non_tensor_batch = {} + for key, array in data.non_tensor_batch.items(): + if isinstance(array, np.ndarray): + resampled_non_tensor_batch[key] = array[sample_indices_np] + else: + resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np] + + resampled_meta_info = {} + for key, value in data.meta_info.items(): + if isinstance(value, list) and len(value) == batch_size: + resampled_meta_info[key] = [value[i] for i in sample_indices_np] + else: + resampled_meta_info[key] = value + + from copy import deepcopy + + resampled_data = deepcopy(data) + resampled_data.batch = type(data.batch)(resampled_batch) + resampled_data.batch.batch_size = data.batch.batch_size + resampled_data.non_tensor_batch = resampled_non_tensor_batch + resampled_data.meta_info = resampled_meta_info + + return resampled_data diff --git a/ICL/DAPO/verl-recipe/gvpo/gvpo_dp_actor.py b/ICL/DAPO/verl-recipe/gvpo/gvpo_dp_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..19275fddce3008aa56522b1978fc372889d77975 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/gvpo_dp_actor.py @@ -0,0 +1,504 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Single Process Actor +""" + +import logging +import os + +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor import DTensor + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input +from verl.utils.device import get_device_id, get_device_name +from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs +from verl.workers.actor import BasePPOActor + +from .gvpo_actor_config import ActorConfig +from .gvpo_core_algos import agg_loss, compute_policy_loss_gvpo, kl_penalty + +__all__ = ["DataParallelPPOActor"] + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class DataParallelPPOActor(BasePPOActor): + """FSDP DataParallel PPO Actor or Ref worker + + Args: + config (ActorConfig): Actor config + actor_module (nn.Module): Actor or ref module + actor_optimizer (torch.optim.Optimizer, optional): Actor optimizer. Defaults to None. + """ + + def __init__( + self, + config: ActorConfig, + actor_module: nn.Module, + actor_optimizer: torch.optim.Optimizer = None, + device_mesh=None, + config_all=None, + ): + """When optimizer is None, it is Reference Policy""" + super().__init__(config) + self.device_mesh = device_mesh + self.config_all = config_all + + self.actor_module = actor_module + self.actor_optimizer = actor_optimizer + role = "Ref" if actor_optimizer is None else "Actor" + + self.use_remove_padding = self.config.get("use_remove_padding", False) + if torch.distributed.get_rank() == 0: + print(f"{role} use_remove_padding={self.use_remove_padding}") + self.use_fused_kernels = self.config.get("use_fused_kernels", False) + if torch.distributed.get_rank() == 0: + print(f"{role} use_fused_kernels={self.use_fused_kernels}") + + self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size + self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + + if self.config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + + self.compute_entropy_from_logits = ( + torch.compile(entropy_from_logits, dynamic=True) + if self.config.get("use_torch_compile", True) # use torch compile by default + else entropy_from_logits + ) + self.device_name = get_device_name() + + def _forward_micro_batch( + self, micro_batch, temperature, calculate_entropy=False + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns: + entropy: # (bs, response_len) + log_probs: # (bs, response_len) + """ + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, "vision_config" + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) + + return entropy, log_probs + + def _optimizer_step(self): + assert self.config.grad_clip is not None + + if isinstance(self.actor_module, FSDP): + grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) + elif isinstance(self.actor_module, FSDPModule): + grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) + + if isinstance(grad_norm, DTensor): + grad_norm = grad_norm.full_tensor() + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}") + self.actor_optimizer.zero_grad() + else: + self.actor_optimizer.step() + return grad_norm + + @GPUMemoryLogger(role="dp actor", logger=logger) + def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + torch.Tensor: the log_prob tensor + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + assert not use_dynamic_bsz, "Dynamic batch size is not supported in GVPO" + + if use_dynamic_bsz: + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) + else: + micro_batches = data.split(micro_batch_size) + + log_probs_lst = [] + entropy_lst = [] + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + with torch.no_grad(): + entropy, log_probs = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + log_probs_lst.append(log_probs) + if calculate_entropy: + entropy_lst.append(entropy) + + log_probs = torch.concat(log_probs_lst, dim=0) + entropys = None + if calculate_entropy: + entropys = torch.concat(entropy_lst, dim=0) + + if use_dynamic_bsz: + log_probs = restore_dynamic_batch(log_probs, batch_idx_list) + if calculate_entropy: + entropys = restore_dynamic_batch(entropys, batch_idx_list) + + return log_probs, entropys + + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + + select_keys = [ + "responses", + "response_mask", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + "uid_tensor", + ] + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + # Include pre-computed IS weights if present in batch + # Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True + if "rollout_is_weights" in data.batch.keys(): + select_keys.append("rollout_is_weights") + + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] + + data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + mini_batches = data.split(self.config.ppo_mini_batch_size) + + on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1 + + metrics = {} + for _ in range(self.config.ppo_epochs): + for batch_idx, mini_batch in enumerate(mini_batches): + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) + else: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + + self.actor_optimizer.zero_grad() + + for micro_batch in micro_batches: + micro_batch = micro_batch.to(get_device_id()) + micro_batch_metrics = {} + model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} + response_mask = model_inputs["response_mask"] + old_log_prob = model_inputs["old_log_probs"] + advantages = model_inputs["advantages"] + uid = model_inputs["uid_tensor"] + + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + + if self.config.use_dynamic_bsz: + loss_scale_factor = response_mask.shape[0] / self.config.ppo_mini_batch_size + else: + loss_scale_factor = 1 / self.gradient_accumulation + + # all return: (bsz, response_length) + calculate_entropy = False + if entropy_coeff != 0: + calculate_entropy = True + entropy, log_prob = self._forward_micro_batch( + model_inputs, temperature=temperature, calculate_entropy=calculate_entropy + ) + + if on_policy: + old_log_prob = log_prob.detach() + else: + old_log_prob = model_inputs["old_log_probs"] + + loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") + + if loss_mode == "gvpo": + pg_loss = compute_policy_loss_gvpo( + old_log_prob=old_log_prob, + log_prob=log_prob, + advantages=advantages, + response_mask=response_mask, + beta=self.config.gvpo_beta, + uid=uid, + device_mesh=self.device_mesh, + n=self.config_all.rollout.n, + ) + else: + raise NotImplementedError("loss_mode should be gvpo") + + if entropy_coeff != 0: + entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + # compute policy loss + policy_loss = pg_loss - entropy_loss * entropy_coeff + else: + policy_loss = pg_loss + + if self.config.use_kl_loss: + ref_log_prob = model_inputs["ref_log_prob"] + # compute kl loss + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) + kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() * loss_scale_factor + micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = policy_loss * loss_scale_factor + else: + loss = policy_loss * loss_scale_factor + loss.backward() + + micro_batch_metrics.update( + { + "actor/pg_loss": pg_loss.detach().item() * loss_scale_factor, + } + ) + append_to_dict(metrics, micro_batch_metrics) + + grad_norm = self._optimizer_step() + mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, mini_batch_metrics) + self.actor_optimizer.zero_grad() + return metrics diff --git a/ICL/DAPO/verl-recipe/gvpo/gvpo_fsdp_workers.py b/ICL/DAPO/verl-recipe/gvpo/gvpo_fsdp_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1f87410d50c2e450916ec04fece51aa47501a8 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/gvpo_fsdp_workers.py @@ -0,0 +1,1935 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The main entry point to run the PPO algorithm +""" + +import asyncio +import datetime +import json +import logging +import os +import warnings +from dataclasses import asdict +from typing import Any, Optional + +import numpy as np +import psutil +import torch +import torch.distributed +import torch.distributed as dist +from codetiming import Timer +from omegaconf import DictConfig, OmegaConf, open_dict +from peft import LoraConfig, TaskType, get_peft_model +from safetensors.torch import save_file +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType + +try: + # for torch 2.5+ + from torch.distributed.tensor import DTensor +except ImportError: + from torch.distributed._tensor import DTensor + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils import hf_processor, hf_tokenizer +from verl.utils.activation_offload import enable_activation_offloading +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import ( + get_device_id, + get_device_name, + get_nccl_backend, + get_torch_device, + set_expandable_segments, +) +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + CPUOffloadPolicy, + MixedPrecisionPolicy, + apply_fsdp2, + collect_lora_params, + fsdp2_load_full_state_dict, + fsdp_version, + get_fsdp_wrap_policy, + get_init_weight_context_manager, + get_shard_placement_fn, + init_fn, + layered_summon_lora_params, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, + replace_lora_wrapper, +) +from verl.utils.import_utils import import_external_libs +from verl.utils.memory_utils import aggressive_empty_cache +from verl.utils.model import compute_position_id_with_mask, convert_weight_keys +from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer +from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max +from verl.utils.py_functional import convert_to_regular_types +from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig +from verl.workers.rollout import get_rollout_class +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + + +def create_device_mesh(world_size, fsdp_size): + if fsdp_size < 0 or fsdp_size >= world_size: + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + else: + device_mesh = init_device_mesh( + device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) + return device_mesh + + +def get_sharding_strategy(device_mesh): + from torch.distributed.fsdp import ShardingStrategy + + if device_mesh.ndim == 1: + sharding_strategy = ShardingStrategy.FULL_SHARD + elif device_mesh.ndim == 2: + sharding_strategy = ShardingStrategy.HYBRID_SHARD + else: + raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + return sharding_strategy + + +def get_vl_model_vision_tower(vl_model_instance): + """ + Util to extract Vision Tower from a VL model instance + """ + if hasattr(vl_model_instance, "model") and hasattr(vl_model_instance.model, "visual"): + # transformers >= 4.52.0 + return vl_model_instance.model.visual + elif hasattr(vl_model_instance, "visual"): + # transformers < 4.52.0 + return vl_model_instance.visual + return None + + +class ActorRolloutRefWorker(Worker, DistProfilerExtension): + """ + This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy + or a hybrid engine based on the config.rollout + """ + + def __init__(self, config: DictConfig, role: str, **kwargs): + Worker.__init__(self) + + self.config = config + import torch.distributed + + if not torch.distributed.is_initialized(): + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + torch.distributed.init_process_group( + backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", + rank=rank, + world_size=world_size, + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + + # build device mesh for FSDP + world_size = torch.distributed.get_world_size() + # TODO(sgm): support FSDP hybrid shard for larger model + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) + + # build device mesh for Ulysses Sequence Parallel + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self._lora_rank = self.config.model.get("lora_rank", 0) + self._is_lora = self._lora_rank > 0 + + self.role = role + assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] + + self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] + self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] + self._is_ref = self.role in ["ref", "actor_rollout_ref"] + self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False) + + # TODO(haibin.lin): + # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, + # it will actually convert the ProfilerConfig dataclass back to a DictConfig. + # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) + # as they provides DictConfig-like interface + # The benefit of creating the dataclass config is to perform validation during __post_init__ + if self._is_actor: + omega_profiler_config = config.actor.get("profiler", {}) + elif self._is_rollout: + # NOTE: In colocation mode, rollout config may not take effect (follow the actor config) + # This is for extendability in AsyncRL cases + omega_profiler_config = config.rollout.get("profiler", {}) + elif self._is_ref: + omega_profiler_config = config.ref.get("profiler", {}) + else: + raise ValueError( + f"Invalid role {self.role}, should be one of " + "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']" + ) + # omega_profiler_config is DictConfig + # profiler_config is a ProfilerConfig dataclass + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + + self._is_offload_param = False + self._is_offload_optimizer = False + if self._is_actor: + self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) + self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) + elif self._is_ref: + # TODO: it seems that manual offload is slowly than FSDP offload + self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) + + # normalize config + if self._is_actor: + self.config.actor.ppo_mini_batch_size *= self.config.rollout.n + self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + assert self.config.actor.ppo_mini_batch_size > 0, ( + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " + f"normalization" + ) + # micro bsz + if self.config.actor.ppo_micro_batch_size is not None: + self.config.actor.ppo_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + + if self.config.actor.ppo_micro_batch_size_per_gpu is not None: + assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + ) + + # normalize rollout config + if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: + self.config.rollout.log_prob_micro_batch_size //= ( + self.device_mesh.size() // self.ulysses_sequence_parallel_size + ) + self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + # normalize ref config + if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: + self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size + self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size + + def _build_model_optimizer( + self, + model_path, + fsdp_config: FSDPEngineConfig, + optim_config, + override_model_config, + use_remove_padding=False, + use_fused_kernels=False, + enable_gradient_checkpointing=False, + trust_remote_code=False, + use_liger=False, + role="actor", + enable_activation_offload=False, + ): + from torch import optim + from torch.distributed.fsdp import CPUOffload, MixedPrecision + from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoModelForVision2Seq, + ) + + from verl.utils.model import get_generation_config, print_model_size, update_model_config + from verl.utils.torch_dtypes import PrecisionType + + assert role in ["actor", "ref"] + + log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) + local_path = model_path + + # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect + # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + + torch_dtype = fsdp_config.get("model_dtype", None) + if torch_dtype is None: + torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 + else: + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + # override model kwargs + actor_model_config = AutoConfig.from_pretrained( + local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" + ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"): + actor_model_config.vision_config._attn_implementation = "eager" + + # patch for kimi-vl + if getattr(actor_model_config, "model_type", None) == "kimi_vl": + actor_model_config.text_config.topk_method = "greedy" + + self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) + + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_model_config) + update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) + if self.rank == 0: + print(f"Model config after override: {actor_model_config}") + + # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang + init_context = get_init_weight_context_manager( + use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + has_remote_code = hasattr(actor_model_config, "auto_map") and any( + actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values() + ) + if has_remote_code: + auto_class = next( + k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v + ) + match auto_class: + case "AutoModelForVision2Seq": + actor_module_class = AutoModelForVision2Seq + case "AutoModelForCausalLM": + actor_module_class = AutoModelForCausalLM + case "AutoModelForImageTextToText": + actor_module_class = AutoModelForImageTextToText + case _: + actor_module_class = AutoModel + else: + if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): + actor_module_class = AutoModelForVision2Seq + elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys(): + actor_module_class = AutoModelForCausalLM + elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys(): + actor_module_class = AutoModelForImageTextToText + else: + actor_module_class = AutoModel + + actor_module = actor_module_class.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=actor_model_config, + trust_remote_code=trust_remote_code, + ) + + # Apply Liger kernel to the model if use_liger is set to True + if use_liger: + from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance + + _apply_liger_kernel_to_instance(model=actor_module) + + fused_kernel_options = self.config.model.get("fused_kernel_options", None) + fused_kernels_backend = ( + fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + ) + + apply_monkey_patch( + model=actor_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_fused_kernels=use_fused_kernels, + fused_kernels_backend=fused_kernels_backend, + ) + + # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 + actor_module.to(torch_dtype) + + if enable_gradient_checkpointing: + actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if self._is_lora: + print("Applying LoRA to actor module") + actor_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), + "bias": "none", + } + actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) + + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.actor.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(actor_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[actor model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[actor model] No vision tower found.") + + torch.distributed.barrier() + + if self.rank == 0: + print_model_size(actor_module) + + log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) + + # We wrap FSDP for rollout as well + mixed_precision_config = fsdp_config.get("mixed_precision", None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + else: + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy( + module=actor_module, + config=fsdp_config.get("wrap_policy", None), + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) + + if self._is_rollout and self.config.rollout.name == "hf": + # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma + auto_wrap_policy = None + + if self.rank == 0: + print(f"wrap_policy: {auto_wrap_policy}") + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + # TODO: add transformer policy + # We force reference policy to use CPUOffload to save memory. + # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation + cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) + fsdp_strategy = self.config.actor.strategy + if fsdp_strategy == "fsdp": + actor_module_fsdp = FSDP( + actor_module, + cpu_offload=cpu_offload, + param_init_fn=init_fn, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, # zero3 + mixed_precision=mixed_precision, + sync_module_states=True, + device_mesh=self.device_mesh, + use_orig_params=self.use_orig_params, + forward_prefetch=fsdp_config.get("forward_prefetch", False), + ) + elif fsdp_strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + if role == "actor" and fsdp_config.offload_policy: + cpu_offload = CPUOffloadPolicy(pin_memory=True) + self._is_offload_param = False + self._is_offload_optimizer = False + else: + cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": cpu_offload, + "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = actor_module.state_dict() + apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) + actor_module_fsdp = actor_module + else: + raise NotImplementedError(f"not implement {fsdp_strategy}") + + if enable_activation_offload: + enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) + + log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) + + # TODO: add more optimizer args into config + if role == "actor" and optim_config is not None: + from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + + actor_optimizer = optim.AdamW( + actor_module_fsdp.parameters(), + lr=optim_config.lr, + betas=optim_config.get("betas", (0.9, 0.999)), + weight_decay=optim_config.get("weight_decay", 1e-2), + ) + + total_steps = optim_config.get("total_training_steps", 0) + num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) + lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant") + min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) + num_cycles = optim_config.get("num_cycles", 0.5) + if num_warmup_steps < 0: + num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + + if lr_scheduler_type == "constant": + actor_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps + ) + elif lr_scheduler_type == "cosine": + actor_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=actor_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + else: + raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") + + log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) + else: + actor_optimizer = None + actor_lr_scheduler = None + + return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config + + def _build_rollout(self, trust_remote_code=False): + from torch.distributed.device_mesh import init_device_mesh + + # 1. parse rollout and huggingface model config + rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) + model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig) + self.model_config = model_config + + # 2. build rollout device mesh + infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size + infer_pp = self.config.rollout.pipeline_model_parallel_size + infer_world_size = infer_tp * infer_pp + dp = self.world_size // infer_world_size + assert self.world_size % infer_world_size == 0, ( + f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}" + ) + rollout_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"] + ) + rollout_name = self.config.rollout.name + + if rollout_name == "hf": + self._register_dispatch_collect_info("rollout", dp_rank=self.rank, is_collect=True) + else: + is_collect = ( + rollout_device_mesh["infer_tp"].get_local_rank() == 0 + and rollout_device_mesh["infer_pp"].get_local_rank() == 0 + ) + self._register_dispatch_collect_info( + "rollout", dp_rank=rollout_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + + # 3. init trainer and rollout random states + self.torch_random_states = get_torch_device().get_rng_state() + gen_dp_rank = rollout_device_mesh["dp"].get_local_rank() + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) + + # 4. build rollout model + log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger) + self.rollout = get_rollout_class(rollout_config.name, rollout_config.mode)( + config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh + ) + log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger) + + # Full params + if torch.distributed.get_world_size() == 1 and fsdp_version(self.actor_module_fsdp) == 1: + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), + ) + elif fsdp_version(self.actor_module_fsdp) == 1: + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) + + # used for LoRA + self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format + self.layered_summon = self.config.rollout.get("layered_summon", False) + + # 5. switch to trainer mode + # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint. + # For sync mode, we directly switch to trainer mode here. + # For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager. + if rollout_config.mode == "sync" and self._is_actor: + loop = asyncio.get_event_loop() + loop.run_until_complete(self.trainer_mode()) + + async def rollout_mode(self): + """Context switch hybridengine to rollout mode.""" + aggressive_empty_cache(force_sync=True) + + log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger) + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger) + + peft_config = None + peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + if hasattr(peft_model, "peft_config"): # LoRA + peft_config = peft_model.peft_config.get("default", None) + params = collect_lora_params( + module=self.actor_module_fsdp, + layered_summon=self.config.rollout.get("layered_summon", False), + base_sync_done=self.base_sync_done, + ) + if not self.base_sync_done: + params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()} + else: + params = self.actor_module_fsdp.state_dict() + + params = convert_weight_keys( + params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + + # Special handling for LoRA with sleep_level=2: + # When sleep_level=2, base model weights are destroyed during each sleep cycle. + # separately collect and update LoRA weights and base model weights through their respective interfaces. + # Here: params contains LoRA weights, base_model_params contains base model weights. + if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2: + base_model_params = collect_lora_params( + module=self.actor_module_fsdp, + layered_summon=self.layered_summon, + base_sync_done=False, + ) + base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()} + base_model_params = convert_weight_keys( + base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + + log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger) + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger) + + set_expandable_segments(False) + + if peft_config is not None and self.base_sync_done: + per_tensor_param = params.items() if isinstance(params, dict) else params # Fixed: handle dict case + else: + device = get_device_id() # used when fsdp2 set cpu_offload_policy + per_tensor_param = ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in params.items() + ) + + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["weights"]) + log_gpu_memory_usage("After resume weights", logger=logger) + + if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2: + per_tensor_base_params = ( + (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) + for name, param in base_model_params.items() + ) + await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False) + del base_model_params, per_tensor_base_params + + await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done) + log_gpu_memory_usage("After update_weights", logger=logger) + del params, per_tensor_param + aggressive_empty_cache(force_sync=True) + if self.config.rollout.free_cache_engine: + await self.rollout.resume(tags=["kv_cache"]) + log_gpu_memory_usage("After resume kv_cache", logger=logger) + + self.base_sync_done = True + # important: need to manually set the random states of each tp to be identical. + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) + + async def trainer_mode(self): + """Context switch hybridengine to trainer mode.""" + if self.config.rollout.free_cache_engine: + log_gpu_memory_usage("Before rollout offload", logger=logger) + await self.rollout.release() + log_gpu_memory_usage("After rollout offload", logger=logger) + + self.actor_module_fsdp.train() + + # add empty cache after each compute + aggressive_empty_cache(force_sync=True) + + set_expandable_segments(True) + + # restore random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + from .gvpo_dp_actor import DataParallelPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + use_remove_padding = self.config.model.get("use_remove_padding", False) + use_shm = self.config.model.get("use_shm", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) + + if self._is_actor or self._is_rollout: + # we need the model for actor and rollout + if self._is_actor: + optim_config = self.config.actor.optim + fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config) + else: + optim_config = None + fsdp_config = FSDPEngineConfig() + + local_path = copy_to_local(self.config.model.path, use_shm=use_shm) + ( + self.actor_module_fsdp, + self.actor_optimizer, + self.actor_lr_scheduler, + self.actor_model_config, + ) = self._build_model_optimizer( + model_path=local_path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + enable_activation_offload=self.config.model.get("enable_activation_offload", False), + ) + + # get the original unwrapped module + if fsdp_version(self.actor_module_fsdp) == 1: + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during init", logger=logger) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + + if self._is_actor: + actor_cfg = omega_conf_to_dataclass(self.config.actor) + self.actor = DataParallelPPOActor( + config=actor_cfg, + actor_module=self.actor_module_fsdp, + actor_optimizer=self.actor_optimizer, + device_mesh=self.device_mesh, + config_all=self.config, + ) + + if self._is_rollout: + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + + if self._is_ref: + ref_model_path = self.config.model.path + ref_model = self.config.ref.get("model", None) + if ref_model is not None: + ref_model_path = ref_model.get("path", self.config.model.path) + + if self.rank == 0: + print("reference model:", ref_model_path) + local_path = copy_to_local(ref_model_path, use_shm=use_shm) + self.ref_module_fsdp = self._build_model_optimizer( + model_path=local_path, + fsdp_config=omega_conf_to_dataclass(self.config.ref.fsdp_config), + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + )[0] + OmegaConf.set_struct(self.config.ref, True) + with open_dict(self.config.ref): + self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels + self.ref_policy = DataParallelPPOActor( + config=self.config.ref, + actor_module=self.ref_module_fsdp, + device_mesh=self.device_mesh, + config_all=self.config, + ) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.actor.checkpoint, + ) + + if not self._is_actor and self._is_rollout: + # If ActorRolloutRefWorker is initialized as a standalone rollout, + # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. + + checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=None, + lr_scheduler=None, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=checkpoint_contents, + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="red", role="actor_update") + def update_actor(self, data: DataProto): + assert self._is_actor + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) + + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on actor.update_policy + + # perform training + with Timer(name="update_policy", logger=None) as timer: + metrics = self.actor.update_policy(data=data) + delta_time = timer.last + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/actor"] = ( + estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size + ) + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + + lr = self.actor_lr_scheduler.get_last_lr()[0] + metrics["actor/lr"] = lr + self.actor_lr_scheduler.step() + + # TODO: here, we should return all metrics + output = DataProto(meta_info={"metrics": metrics}) + + output = output.to("cpu") + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) + + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) + @DistProfiler.annotate(color="red", role="rollout_generate") + def generate_sequences(self, prompts: DataProto): + # Support all hardwares + assert self._is_rollout + prompts = prompts.to(get_device_id()) + + meta_info = { + "eos_token_id": self.generation_config.eos_token_id + if self.generation_config is not None + else self.tokenizer.eos_token_id, + "pad_token_id": self.generation_config.pad_token_id + if self.generation_config is not None + else self.tokenizer.pad_token_id, + } + prompts.meta_info.update(meta_info) + + timing_generate = {} + if self._is_actor: # For rollout only, we do not switch context. + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.rollout_mode()) + log_gpu_memory_usage("After switch to rollout mode", logger=logger) + + with simple_timer("generate_sequences", timing_generate): + output = self.rollout.generate_sequences(prompts=prompts) + + if self._is_actor: + loop.run_until_complete(self.trainer_mode()) + log_gpu_memory_usage("After switch to trainer mode", logger=logger) + + # We calculate the average timing across all ranks + # to make sure meta_info["timing"] is the same + timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max( + timing_generate["generate_sequences"] + ) + timing_generate = reduce_timing(timing_generate) + timing_generate.update( + { + "generation_timing/max": timing_generate_max, + "generation_timing/min": timing_generate_min, + "generation_timing/topk_ratio": timing_generate_topk_ratio, + } + ) + output.meta_info["timing"] = timing_generate + output = output.to("cpu") + + # clear kv cache + get_torch_device().empty_cache() + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") + def compute_log_prob(self, data: DataProto): + # when is_lora is True, we use the actor without lora applied to calculate the log_prob + # which is mostly used for ref log_prob calculation + assert self._is_actor + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + # Support all hardwares + from contextlib import nullcontext + + is_lora = data.meta_info.pop("is_lora", False) + adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() + # we should always recompute old_log_probs when it is HybridEngine + data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature + # perform recompute log_prob + with self.ulysses_sharding_manager: + with adapter_ctx: + output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) + output = DataProto.from_dict( + tensors={"old_log_probs": output, "entropys": entropys}, + meta_info={"temperature": self.config.rollout.temperature}, + ) + + output = output.to("cpu") + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: + self.actor.actor_module._handle.reshard(True) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) + + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") + def compute_ref_log_prob(self, data: DataProto): + if self._is_lora: + # if _is_lora, actor without lora applied is the ref + data.meta_info["is_lora"] = True + data = self.compute_log_prob(data) + # this old_log_probs is in fact ref_log_prob + data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]}) + return data + assert self._is_ref + # else: + # otherwise, the class have a standalone ref model + + micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on ref.compute_log_prob + output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) + output = DataProto.from_dict(tensors={"ref_log_prob": output}) + + output = output.to("cpu") + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1: + if fsdp_version(self.ref_policy.actor_module) == 1: + self.ref_policy.actor_module._handle.reshard(True) + elif fsdp_version(self.ref_policy.actor_module) == 2: + self.ref_policy.actor_module.reshard() + + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + from verl.utils.logger import log_with_rank + + # only support save and load ckpt for actor + assert self._is_actor + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) + dist.barrier() + + if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"): + lora_save_path = os.path.join(local_path, "lora_adapter") + peft_model = getattr(self, "actor_module", self.actor_module_fsdp) + peft_config = {} + if dist.get_rank() == 0: + os.makedirs(lora_save_path, exist_ok=True) + peft_config = asdict(peft_model.peft_config.get("default", {})) + peft_config["task_type"] = peft_config["task_type"].value + peft_config["peft_type"] = peft_config["peft_type"].value + peft_config["target_modules"] = list(peft_config["target_modules"]) + try: + if fsdp_version(self.actor_module_fsdp) > 0: + self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name()) + lora_params = layered_summon_lora_params(self.actor_module_fsdp) + if dist.get_rank() == 0: + save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")) + with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(peft_config, f, ensure_ascii=False, indent=4) + except Exception as e: + log_with_rank( + f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True + ) + + dist.barrier() + log_with_rank( + f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}", + rank=dist.get_rank(), + logger=logger, + log_only_rank_0=True, + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): + assert self._is_actor or (not self._is_actor and self._is_rollout), ( + f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got " + f"{self._is_actor} and {self._is_rollout}" + ) + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.actor_optimizer) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def start_profile(self, **kwargs) -> None: + """Start profiling for the current rank in the current training step.""" + self.profiler.start(**kwargs) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def stop_profile(self) -> None: + """Stop profiling for the current rank in the current training step.""" + self.profiler.stop() + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None: + """Manually trigger a CUDA memory snapshot dump on all ranks.""" + # Memory snapshot is now handled by the profiler system + # This method is kept for backward compatibility but delegates to profiler + if hasattr(self, "profiler") and hasattr(self.profiler, "_impl"): + try: + # Try to use the profiler's memory snapshot functionality + if hasattr(self.profiler._impl, "sampler"): + out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "." + self.profiler._impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir) + except Exception: + # silently ignore if profiler doesn't support memory snapshots + pass + + +class CriticWorker(Worker, DistProfilerExtension): + def __init__(self, config: FSDPCriticConfig): + Worker.__init__(self) + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) + ) + import torch.distributed + + self.config = config + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + self.config: FSDPCriticConfig = config + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "critic", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # set FSDP offload params + self._is_offload_param = self.config.model.fsdp_config.param_offload + self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload + + # normalize config + self.config.ppo_mini_batch_size *= self.config.rollout_n + self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + if self.config.ppo_micro_batch_size is not None: + self.config.ppo_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) + self.config.forward_micro_batch_size //= ( + torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + ) + self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size + self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + + if self.config.ppo_micro_batch_size_per_gpu is not None: + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " + f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + ) + self._is_lora = self.config.model.get("lora_rank", 0) > 0 + self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False) + + def _build_critic_model_optimizer(self, config): + # the following line is necessary + from torch import optim + from torch.distributed.fsdp import MixedPrecision + + from verl.utils.model import load_valuehead_model, print_model_size + from verl.utils.torch_dtypes import PrecisionType + + use_shm = config.model.get("use_shm", False) + local_path = copy_to_local(config.model.path, use_shm=use_shm) + # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info + # using random initialized model from any architecture. May not be the same as Actor. + + tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + + if self.config.model.get("custom_chat_template", None) is not None: + if self.processor is not None: + self.processor.chat_template = self.config.model.custom_chat_template + else: + self.tokenizer.chat_template = self.config.model.custom_chat_template + override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_config) + if self.rank == 0: + print(f"Critic overriding config {override_config_kwargs}") + + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + from transformers import AutoConfig + + critic_model_config = AutoConfig.from_pretrained( + local_path, + attn_implementation="flash_attention_2", + trust_remote_code=config.model.get("trust_remote_code", False), + ) + # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 + # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids + # Maybe support Ulysses in VisionAttention in the future and remove this patch + if self.ulysses_sequence_parallel_size > 1 and hasattr(critic_model_config, "vision_config"): + critic_model_config.vision_config._attn_implementation = "eager" + + critic_model_config.num_labels = 1 + # patch for kimi-vl + if getattr(critic_model_config, "model_type", None) == "kimi_vl": + critic_model_config.text_config.topk_method = "greedy" + + init_context = get_init_weight_context_manager( + use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + critic_model_config.classifier_dropout = 0.0 + critic_model_config.hidden_dropout = "0" + critic_model_config.summary_dropout_prob = 0.0 + + critic_module = load_valuehead_model( + local_path, + torch_dtype, + critic_model_config, + config.model.get("trust_remote_code", False), + ) + + use_remove_padding = config.model.get("use_remove_padding", False) + + apply_monkey_patch( + model=critic_module, + use_remove_padding=use_remove_padding, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + ) + + # some parameters may not in torch_dtype + critic_module.to(torch_dtype) + + if config.model.get("enable_gradient_checkpointing", False): + critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + if self._is_lora: + print("Applying LoRA to critic module") + critic_module.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + "task_type": TaskType.CAUSAL_LM, + "r": self.config.model.lora_rank, + "lora_alpha": self.config.model.lora_alpha, + "target_modules": convert_to_regular_types(self.config.model.target_modules), + "bias": "none", + } + critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) + + if self.rank == 0: + print_model_size(critic_module) + + self.critic_model_config = critic_model_config + + fsdp_config = self.config.model.fsdp_config + mixed_precision_config = fsdp_config.get("mixed_precision", None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + else: + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy( + module=critic_module, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get("lora_rank", 0) > 0, + ) + + log_gpu_memory_usage("Before critic FSDP", logger=None) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + self.use_orig_params = fsdp_config.get("use_orig_params", False) + if self.config.model.get("freeze_vision_tower", False): + vision_tower = get_vl_model_vision_tower(critic_module) + if vision_tower is not None: + vision_tower.requires_grad_(False) + self.use_orig_params = True + if self.rank == 0: + print("[critic model] Vision tower is set to not trainable.") + else: + if self.rank == 0: + print("[critic model] No vision tower found.") + + # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation + if config.strategy == "fsdp": + critic_module = FSDP( + critic_module, + param_init_fn=init_fn, + use_orig_params=self.use_orig_params, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=self.config.model.fsdp_config.forward_prefetch, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + elif config.strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True + ) + offload_policy = None + if fsdp_config.offload_policy: + self._is_offload_param = False + self._is_offload_optimizer = False + offload_policy = CPUOffloadPolicy(pin_memory=True) + + fsdp_kwargs = { + "mesh": fsdp_mesh, + "mp_policy": mp_policy, + "offload_policy": offload_policy, + "reshard_after_forward": fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = critic_module.state_dict() + apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) + fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) + else: + raise NotImplementedError(f"Unknown strategy {config.strategy}") + + if config.model.get("enable_activation_offload", False): + enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) + enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) + + log_gpu_memory_usage("After critic FSDP", logger=None) + + critic_optimizer = optim.AdamW( + critic_module.parameters(), + lr=config.optim.lr, + betas=config.optim.get("betas", (0.9, 0.999)), + weight_decay=config.optim.get("weight_decay", 1e-2), + ) + + total_steps = config.optim.get("total_training_steps", 0) + num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) + + lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant") + if num_warmup_steps < 0: + num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + if self.rank == 0: + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + + from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup + + if lr_scheduler_type == "constant": + critic_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps + ) + elif lr_scheduler_type == "cosine": + min_lr_ratio = config.optim.get("min_lr_ratio", 0.0) + num_cycles = config.optim.get("num_cycles", 0.5) + critic_lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=critic_optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + min_lr_ratio=min_lr_ratio, + num_cycles=num_cycles, + ) + else: + raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") + + return critic_module, critic_optimizer, critic_lr_scheduler + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + from verl.workers.critic import DataParallelPPOCritic + + self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( + self.config + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + log_gpu_memory_usage("After offload critic model during init", logger=logger) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.critic_optimizer) + log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) + + self.critic = DataParallelPPOCritic( + config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer + ) + + self.flops_counter = FlopsCounter(self.critic_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.critic_module, + optimizer=self.critic_optimizer, + lr_scheduler=self.critic_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.checkpoint, + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="cyan") + def compute_values(self, data: DataProto): + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + micro_batch_size = self.config.forward_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + # perform forward computation + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on critic.compute_values + values = self.critic.compute_values(data=data) + output = DataProto.from_dict(tensors={"values": values}) + + output = output.to("cpu") + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) + @DistProfiler.annotate(color="pink") + def update_critic(self, data: DataProto): + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) + + # perform forward computation + with self.ulysses_sharding_manager: + data = data.to("cpu") # data will to device with each micro batch on critic.update_critic + with Timer(name="update_critic", logger=None) as timer: + metrics = self.critic.update_critic(data=data) + delta_time = timer.last + + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + + lr = self.critic_lr_scheduler.get_last_lr()[0] + metrics["critic/lr"] = lr + self.critic_lr_scheduler.step() + + output = DataProto(batch=None, meta_info={"metrics": metrics}) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.critic_optimizer) + + output = output.to("cpu") + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + import torch + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True): + import torch + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.critic_module) + + self.checkpoint_manager.load_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load + ) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.critic_module) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.critic_optimizer) + + +# TODO(sgm): we may need to extract it to dp_reward_model.py +class RewardModelWorker(Worker, DistProfilerExtension): + """ + Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. + """ + + def __init__(self, config): + Worker.__init__(self) + + omega_profiler_config = config.get("profiler", {}) + profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) + if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: + tool_config = omega_conf_to_dataclass( + omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) + ) + else: + tool_config = None + DistProfilerExtension.__init__( + self, + DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config), + ) + + import torch.distributed + + self.config = config + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group( + backend=get_nccl_backend(), + timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + init_method=os.environ.get("DIST_INIT_METHOD", None), + ) + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # create training dispatch + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True) + + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + + # normalize config + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= torch.distributed.get_world_size() + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + + def _build_model(self, config): + # the following line is necessary + from torch.distributed.fsdp import CPUOffload + from transformers import AutoConfig, AutoModelForTokenClassification + + use_shm = config.model.get("use_shm", False) + # download the checkpoint from hdfs + local_path = copy_to_local(config.model.path, use_shm=use_shm) + + if self.config.model.input_tokenizer is None: + self._do_switch_chat_template = False + else: + self._do_switch_chat_template = True + input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm) + self.input_tokenizer = hf_tokenizer( + input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + ) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) + + trust_remote_code = config.model.get("trust_remote_code", False) + model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + model_config.num_labels = 1 + + # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + model_config.classifier_dropout = 0.0 + reward_module = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + config=model_config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + apply_monkey_patch( + model=reward_module, + use_remove_padding=config.model.get("use_remove_padding", False), + ulysses_sp_size=self.ulysses_sequence_parallel_size, + ) + + reward_module.to(torch.bfloat16) + + auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + if config.strategy == "fsdp": + reward_module = FSDP( + reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, # zero3 + sync_module_states=True, + cpu_offload=CPUOffload(offload_params=True), + forward_prefetch=self.config.model.fsdp_config.forward_prefetch, + device_mesh=self.device_mesh, + ) + elif config.strategy == "fsdp2": + assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + cpu_offload = CPUOffloadPolicy(pin_memory=True) + fsdp_kwargs = { + "mesh": fsdp_mesh, + "offload_policy": cpu_offload, + "reshard_after_forward": config.model.fsdp_config.reshard_after_forward, + "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), + } + full_state = reward_module.state_dict() + apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config) + fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload) + else: + raise NotImplementedError(f"Unknown strategy: {config.strategy}") + return reward_module + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + self.reward_module = self._build_model(config=self.config) + + def _forward_micro_batch(self, micro_batch): + from verl.utils.attention_utils import index_first_axis, pad_input, rearrange, unpad_input + from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs + + with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) + .transpose(0, 1) + .unsqueeze(1) + ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ) + reward_rmpad = output.logits + reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) + + # gather output if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + reward_rmpad = gather_outputs_and_unpad( + reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + + # pad it back + rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) + else: + output = self.reward_module( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) + rm_score = output.logits # (batch_size, seq_len, 1) + rm_score = rm_score.squeeze(-1) + + # extract the result of the last valid token + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] + return rm_score + + def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): + batch_size = data.batch.batch_size[0] + # expand as token_level_reward + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + response_length = data.batch["responses"].shape[-1] + if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] + position_ids = position_ids[:, 0, :] + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) + token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores + + # select the response part + token_level_scores = token_level_scores[:, -response_length:] + + return token_level_scores + + def _switch_chat_template(self, data: DataProto): + src_max_length = data.batch["attention_mask"].shape[-1] + + src_tokenizer = self.input_tokenizer + target_tokenizer = self.tokenizer + + rm_input_ids = [] + rm_attention_mask = [] + + for i in range(data.batch.batch_size[0]): + if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray): + raise TypeError( + f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}" + ) + + # extract raw prompt + chat: list = list(data.non_tensor_batch["raw_prompt"][i]) + + # extract response + response_ids = data.batch["responses"][i] + response_length = response_ids.shape[-1] + valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + response = src_tokenizer.decode(valid_response_ids) + # remove bos and eos + response = response.replace(src_tokenizer.eos_token, "") + + chat.append({"role": "assistant", "content": response}) + + prompt_with_chat_template = target_tokenizer.apply_chat_template( + chat, add_generation_prompt=False, tokenize=False + ) + if self.rank == 0 and i == 0: + # for debugging purpose + print(f"Switch template. chat: {prompt_with_chat_template}") + + # the maximum length is actually determined by the reward model itself + max_length = self.config.get("max_length", src_max_length) + if max_length is None: + max_length = src_max_length + + model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) + input_ids, attention_mask = verl_F.postprocess_data( + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + max_length=max_length, + pad_token_id=target_tokenizer.pad_token_id, + left_pad=False, # right padding + truncation=self.config.get("truncation", "right"), + ) # truncate from the right + + rm_input_ids.append(input_ids) + rm_attention_mask.append(attention_mask) + + rm_input_ids = torch.cat(rm_input_ids, dim=0) + rm_attention_mask = torch.cat(rm_attention_mask, dim=0) + + rm_position_ids = compute_position_id_with_mask(rm_attention_mask) + + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} + + return DataProto.from_dict(rm_inputs) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) + @DistProfiler.annotate(color="brown") + def compute_rm_score(self, data: DataProto): + import itertools + + from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + + # Support all hardwares + data = data.to(get_device_id()) + if self._do_switch_chat_template: + rm_data = self._switch_chat_template(data) + else: + rm_input_ids = data.batch["input_ids"] + rm_attention_mask = data.batch["attention_mask"] + rm_position_ids = data.batch["position_ids"] + rm_inputs = { + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, + } + rm_data = DataProto.from_dict(rm_inputs) + + # Support all hardwares + rm_data = rm_data.to(get_device_id()) + + # perform forward computation + with self.ulysses_sharding_manager: + use_dynamic_bsz = self.config.use_dynamic_bsz + if use_dynamic_bsz: + max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + else: + micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + output = [] + for micro_batch in micro_batches: + rm_score = self._forward_micro_batch(micro_batch) + output.append(rm_score) + scores = torch.cat(output, dim=0) # (batch_size) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + scores = scores[revert_indices] + + token_level_scores = self._expand_to_token_level(data, scores) + # Note that this is only the scores, may not be the final rewards used to train RL + output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1 and fsdp_version(self.reward_module) == 1: + self.reward_module._handle.reshard(True) + + output = output.to("cpu") + return output + + +# ================================= Async related workers ================================= +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def wake_up(self): + await self.rollout_mode() + return True + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + async def sleep(self): + await self.trainer_mode() + return True + + # ============================ vLLM related ============================ + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def get_zeromq_address(self): + return self.rollout.get_zeromq_address() + + # ============================ SGLang related ============================ + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def chat_completion(self, json_request): + ret = await self.rollout.chat_completion(json_request) + return ret + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) + async def generate( + self, + prompt_ids: list[int], + sampling_params: dict[str, Any], + request_id: str, + image_data: Optional[list[Any]] = None, + ) -> list[int]: + ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data) + return ret diff --git a/ICL/DAPO/verl-recipe/gvpo/gvpo_main_ppo.py b/ICL/DAPO/verl-recipe/gvpo/gvpo_main_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..21b28b53b0da93109f79f35d59f51eea200dbe5f --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/gvpo_main_ppo.py @@ -0,0 +1,413 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.experimental.dataset.sampler import AbstractSampler +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import is_cuda_available +from verl.utils.import_utils import load_extern_type + +from .gvpo_ray_trainer import RayGVPOTrainer + + +@hydra.main(config_path="config", config_name="gvpo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from .gvpo_fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + + return actor_rollout_cls, ray_worker_group_cls + + def add_critic_worker(self, config): + """Add critic worker to role mapping.""" + if config.critic.strategy in {"fsdp", "fsdp2"}: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + from .gvpo_fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + from verl.workers.megatron_workers import CriticWorker + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) + + def init_resource_pool_mgr(self, config): + """Initialize resource pool manager.""" + from verl.trainer.ppo.ray_trainer import Role + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + + self.mapping[Role.ActorRollout] = global_pool_id + self.mapping[Role.Critic] = global_pool_id + from verl.trainer.ppo.ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_worker(self, config): + """Add reward model worker if enabled.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.reward_model.enable: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from .gvpo_fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import RewardModelWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + if config.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Add reference policy worker if KL loss or KL reward is used.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) + self.mapping[Role.RefPolicy] = "global_pool" + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayGVPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True): + """Create a dataset. + + Arguments: + data_paths: List of paths to data files. + data_config: The data config. + tokenizer (Tokenizer): The tokenizer. + processor (Processor): The processor. + + Returns: + dataset (Dataset): The dataset. + """ + from torch.utils.data import Dataset + + from verl.utils.dataset.rl_dataset import RLHFDataset + + # Check if a custom dataset class is specified in the data configuration + # and if the path to the custom class is provided + if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: + # Dynamically load the custom dataset class + dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + # Verify that the custom dataset class inherits from torch.utils.data.Dataset + if not issubclass(dataset_cls, Dataset): + raise TypeError( + f"The custom dataset class '{data_config.custom_cls.name}' from " + f"'{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset" + ) + elif "datagen" in data_config and data_config.datagen.get("path", None) is not None and is_train: + # If a data generation strategy is specified, use the DynamicGenDataset class + from verl.utils.dataset.dynamicgen_dataset import DynamicGenDataset + + dataset_cls = DynamicGenDataset + print("Using DynamicGenDataset for data generation.") + else: + # Use the default RLHFDataset class if no custom class is specified + dataset_cls = RLHFDataset + print(f"Using dataset class: {dataset_cls.__name__}") + + # Instantiate the dataset using the determined dataset class + dataset = dataset_cls( + data_files=data_paths, + tokenizer=tokenizer, + processor=processor, + config=data_config, + ) + + return dataset + + +def create_rl_sampler(data_config, dataset): + """Create a sampler for the dataset. + + Arguments: + data_config: The data config. + dataset (Dataset): The dataset. + + Returns: + sampler (Sampler): The sampler. + """ + import torch + from torch.utils.data import RandomSampler, SequentialSampler + + if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: + curriculum_class = load_extern_type( + data_config.sampler.class_path, + data_config.sampler.class_name, + ) + sampler = curriculum_class( + data_source=dataset, + data_config=data_config, + ) + assert isinstance(sampler, AbstractSampler) + assert data_config.get("dataloader_num_workers", 8) == 0, ( + "If using curriculum, num_workers must be 0 to prevent data caching. " + "If the dataloader caches data before the batch is done the " + "curriculum sampler won't have the opportunity to reorder it. " + ) + + # Use a sampler to facilitate checkpoint resumption. + # If shuffling is enabled in the data configuration, create a random sampler. + elif data_config.shuffle: + train_dataloader_generator = torch.Generator() + train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) + else: + # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. + sampler = SequentialSampler(data_source=dataset) + + return sampler + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/gvpo/gvpo_ray_trainer.py b/ICL/DAPO/verl-recipe/gvpo/gvpo_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3cd139a49bc461d4055a6702b0db9fec0f6b03f9 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/gvpo_ray_trainer.py @@ -0,0 +1,61 @@ +import torch + +from verl import DataProto +from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance + + +class RayGVPOTrainer(RayPPOTrainer): + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + """Reorder the data on single controller such that each dp rank gets similar total tokens""" + attention_mask = batch.batch["attention_mask"] + batch_size = attention_mask.shape[0] + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + world_size = self.actor_rollout_wg.world_size + n = self.config.actor_rollout_ref.rollout.n + bs_per_gpu = self.config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu + assert world_size * bs_per_gpu % n == 0, ( + f"GVPO requires: world_size {world_size} * bs_per_gpu {bs_per_gpu} should be divisible by n {n}" + ) + k_partitions = batch_size // (world_size * bs_per_gpu) + phase1_seqlen_lst = [sum(global_seqlen_lst[i : i + n]) for i in range(0, len(global_seqlen_lst), n)] + if batch_size % (world_size * bs_per_gpu) == 0: + phase1_partition_lst = get_seqlen_balanced_partitions( + phase1_seqlen_lst, k_partitions=k_partitions, equal_size=True + ) + else: + if k_partitions > 0: + phase1_partition_lst = get_seqlen_balanced_partitions( + phase1_seqlen_lst[: k_partitions * world_size * bs_per_gpu / n], + k_partitions=k_partitions, + equal_size=True, + ) + else: + phase1_partition_lst = [] + phase1_partition_lst.append( + list(range(k_partitions * world_size * bs_per_gpu // n + 1, len(phase1_seqlen_lst))) + ) + + global_idx = [-1] * batch_size + for k in range(len(phase1_partition_lst)): + partition = phase1_partition_lst[k] + phase2_seqlen_lst = [global_seqlen_lst[i * n + j] for i in partition for j in range(n)] + inx = [i * n + j for i in partition for j in range(n)] + phase2_partition_lst = get_seqlen_balanced_partitions( + phase2_seqlen_lst, k_partitions=world_size, equal_size=True + ) + for i in range(len(phase2_partition_lst)): + for j in range(len(phase2_partition_lst[i])): + global_idx[i * (batch_size // world_size) + k * bs_per_gpu + j] = inx[phase2_partition_lst[i][j]] + + global_partition_lst = [ + global_idx[i * (batch_size // world_size) : (i + 1) * (batch_size // world_size)] for i in range(world_size) + ] + global_idx = torch.tensor(global_idx) + + batch.union(DataProto.from_single_dict({"uid_tensor": torch.Tensor([i // n for i in range(batch_size)])})) + batch.reorder(global_idx) + global_balance_stats = log_seqlen_unbalance( + seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix + ) + metrics.update(global_balance_stats) diff --git a/ICL/DAPO/verl-recipe/gvpo/run_qwen2-7b_math_gvpo.sh b/ICL/DAPO/verl-recipe/gvpo/run_qwen2-7b_math_gvpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..b93bd77e6fda038c59c291b46df5714e95800b88 --- /dev/null +++ b/ICL/DAPO/verl-recipe/gvpo/run_qwen2-7b_math_gvpo.sh @@ -0,0 +1,52 @@ +set -x + + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +python3 -m recipe.gvpo.gvpo_main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-Math-7B \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.policy_loss.loss_mode="gvpo" \ + actor_rollout_ref.actor.gvpo_beta=0.1 \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + algorithm.norm_adv_by_std_in_grpo=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console"]' \ + trainer.project_name='verl_gvpo_example_gsm8k_math' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.test_freq=20 \ + trainer.total_epochs=15 $@ diff --git a/ICL/DAPO/verl-recipe/infigui-g1/reward_fn.py b/ICL/DAPO/verl-recipe/infigui-g1/reward_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5db0eef48b780f6982202394ee2c4af3e9838d --- /dev/null +++ b/ICL/DAPO/verl-recipe/infigui-g1/reward_fn.py @@ -0,0 +1,388 @@ +# Copyright 2025 Individual Contributor: InfiX.ai +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import math +import re +from itertools import combinations + +FMT_RATIO = 1.0 +ACC_RATIO = 1.0 + + +# ============================================================================ +# Utility Functions +# ============================================================================ + + +def extract_think_format(predict_str: str) -> None | dict[str, str]: + """ + Check if the predicted string meets format requirements and extract thinking and answer parts. + + Args: + predict_str: The predicted string + + Returns: + If format requirements are met, returns a dictionary containing thinking and answer parts; + otherwise returns None + """ + if not predict_str or not isinstance(predict_str, str): + return None + + # Check if is at the beginning + if not predict_str.startswith(""): + return None + + # Check if there is ... format + pattern = r"(.*?)" + think_match = re.search(pattern, predict_str, re.DOTALL) + if not think_match: + return None + + if predict_str.count("") != 1 or predict_str.count("") != 1: + return None + + # Extract thinking content + think_content = think_match.group(1).strip() + if not think_content: + return None + + # Get content after + think_end_pos = predict_str.find("") + len("") + post_think_content = predict_str[think_end_pos:].strip() + + # Check if there is non-empty content after + if not post_think_content: + return None + + return {"think": think_content, "answer": post_think_content} + + +def extract_and_parse_json(input_string, wrapper): + """ + Try to extract and parse JSON from a string. + + Args: + input_string: The input string + wrapper: JSON wrapper symbols, can be '{}' or '[]' + + Returns: + Parsed JSON object, returns None if parsing fails + """ + if len(wrapper) != 2: + raise ValueError("Wrapper must be exactly two characters long") + + start_char, end_char = wrapper + start_index = input_string.find(start_char) + + if start_index == -1: + return None + + # Find the matching end character by balancing brackets/braces + balance = 1 + end_index = -1 + for i in range(start_index + 1, len(input_string)): + if input_string[i] == start_char: + balance += 1 + elif input_string[i] == end_char: + balance -= 1 + + if balance == 0: + end_index = i + break + + if end_index == -1: + return None + + json_string = input_string[start_index : end_index + 1] + + try: + return json.loads(json_string) + except json.JSONDecodeError: + return None + + +# ============================================================================ +# AER Reward Functions +# ============================================================================ + + +def _extract_verifiable_answer(answer): + """ + Extract and verify the format of the point list from the answer string. + + A valid format is a JSON list of dictionaries, where each dictionary + has a "point_2d" key with a list of two numbers as the value. + + Args: + answer: The answer string to extract points from + + Returns: + List of valid points or None if format is invalid + """ + points = extract_and_parse_json(answer, "[]") + if points is None or not isinstance(points, list): + return None + + # Verify each point in the list + for point in points: + if isinstance(point, dict) and "point_2d" in point: + point_2d = point["point_2d"] + if isinstance(point_2d, list) and len(point_2d) == 2: + continue + + # If any point is malformed, the whole answer is invalid + return None + + return points + + +def _format_reward(answer): + """ + Calculate the format reward for 'point' type data. + + This function is now primarily used as a check to see if the format is valid. + + Args: + answer: The answer string to validate + + Returns: + Tuple of (reward, is_collinear) where reward is 1.0 for valid format, 0.0 otherwise + """ + points = _extract_verifiable_answer(answer) + if points is None: + return 0.0, 0 + + points_2d = [item["point_2d"] for item in points] + if _check_collinear(points_2d): + return 0.0, 1 + + return 1.0, 0 + + +def _check_collinear(points_2d): + """ + Check if 3 or more points in the list are collinear on any straight line. + + This uses the cross-product method to avoid division and handle all line types. + + Args: + points_2d: A list of [x, y] coordinates + + Returns: + True if 3 or more points are collinear, False otherwise + """ + if len(points_2d) < 3: + return False + + # Iterate through all unique combinations of 3 points + for p1, p2, p3 in combinations(points_2d, 3): + x1, y1 = p1 + x2, y2 = p2 + x3, y3 = p3 + + # Check for collinearity using the cross-product method. + # If (y2 - y1) * (x3 - x1) == (y3 - y1) * (x2 - x1), the points are collinear. + # This is equivalent to checking if the area of the triangle formed by the points is 0. + if math.isclose((y2 - y1) * (x3 - x1), (y3 - y1) * (x2 - x1)): + return True + + return False + + +def _accuracy_reward(answer, ground_truth): + """ + Calculate the accuracy reward based on the symmetric zero-centered formula. + + The reward is in the range [-1, 1]. + + Args: + answer: The answer string containing predicted points + ground_truth: Ground truth bounding box dictionary + + Returns: + Tuple containing: + - accuracy (float): The calculated reward + - extracted_answer (str): The JSON string of the predicted points + - num_pred (int): The number of predicted points + - first_correct (int): 1 if the first predicted point is correct, 0 otherwise + """ + pred_points = _extract_verifiable_answer(answer) + + # If no valid points are extracted, this is considered a format error, return -1 reward + if pred_points is None: + return -1.0, "", 0, 0 + + num_pred = len(pred_points) + extracted_answer = json.dumps(pred_points) + + if num_pred == 0: + return -1.0, extracted_answer, 0, 0 + + # Find the rank 'k' of the first correct point + first_correct_rank = -1 + for i, item in enumerate(pred_points): + point_2d = item["point_2d"] + if ( + ground_truth["x1"] <= point_2d[0] <= ground_truth["x2"] + and ground_truth["y1"] <= point_2d[1] <= ground_truth["y2"] + ): + first_correct_rank = i + 1 # 1-based index + break + + # Calculate reward based on the zero-centered symmetric formula + accuracy = 0.0 + if first_correct_rank != -1: + # Case a: Correct point found (Positive reward space) + k = first_correct_rank + accuracy = 1.0 / math.sqrt(num_pred * k) + else: + # Case b: No correct point found (Negative reward space) + accuracy = -1.0 / num_pred + + first_correct = 1 if first_correct_rank == 1 else 0 + + return accuracy, extracted_answer, num_pred, first_correct + + +def calculate_point_reward(solution_str, ground_truth, extra_info=None, fmt_ratio=1.0, acc_ratio=1.0, **kwargs): + """ + Calculate the final reward for 'point' type data. + + Implements the full logic including format checks, collinearity checks, + and the zero-centered symmetric reward calculation. + + Args: + solution_str: The solution string from the model + ground_truth: Ground truth data + extra_info: Extra information dictionary + fmt_ratio: Format reward ratio + acc_ratio: Accuracy reward ratio + **kwargs: Additional keyword arguments + + Returns: + Dictionary containing detailed reward information + """ + if extra_info.get("no_think", False): + answer = solution_str + else: + solution_dict = extract_think_format(solution_str) + # If the overall 'think'/'answer' format is wrong, return score of -1 + if solution_dict is None: + return { + "score": -1.0, + "format": 0.0, + "accuracy": -1.0, + "pred": "", + "num_pred": 0, + "has_correct": 0, + "first_correct": 0, + "only_correct": 0, + "is_collinear": 0, + } + + answer = solution_dict["answer"] + + # Reuse _format_reward to check the format of the 'answer' part + # If it's invalid, return score of -1 + format_reward, is_collinear = _format_reward(answer) + if format_reward == 0.0: + return { + "score": -1.0, + "format": 0.0, + "accuracy": -1.0, + "pred": "", + "num_pred": 0, + "has_correct": 0, + "first_correct": 0, + "only_correct": 0, + "is_collinear": is_collinear, + } + + # If format is OK, calculate the accuracy reward + accuracy_reward, extracted_answer, num_pred, first_correct = _accuracy_reward(answer, ground_truth) + + return { + "score": fmt_ratio * format_reward + acc_ratio * accuracy_reward, + "format": format_reward, + "accuracy": accuracy_reward, + "pred": extracted_answer, + "num_pred": num_pred, + "has_correct": 1 if accuracy_reward > 0 else 0, + "first_correct": first_correct, + "only_correct": 1 if num_pred == 1 and accuracy_reward > 0 else 0, + "is_collinear": 0, + } + + +# ============================================================================ +# AER Reward Handler Registry +# ============================================================================ + +# Dictionary to map data_source to the respective reward calculation function +AER_REWARD_HANDLERS = { + "point": calculate_point_reward, +} + + +def aer_gui_reward_function(data_source, solution_str, ground_truth, extra_info=None, **kwargs): + """ + Main reward function dispatcher for the Adaptive Exploration Reward (AER) system. + + Delegates reward calculation to specific functions based on the data_source using a dictionary lookup. + + Args: + data_source: The source or type of the data (e.g., "point", "bbox") + solution_str: The solution string generated by the model + ground_truth: The ground truth data + extra_info: Any extra information passed along (optional) + **kwargs: Additional keyword arguments that might be passed from the PPO trainer config + + Returns: + Dictionary containing detailed reward information with keys: + - score: The final calculated reward score + - format: Format validation score + - accuracy: Accuracy score + - pred: Extracted prediction string + - num_pred: Number of predictions + - has_correct: Whether any correct prediction exists + - first_correct: Whether first prediction is correct + - only_correct: Whether only one correct prediction exists + - is_collinear: Whether points are collinear (for point type) + """ + handler = AER_REWARD_HANDLERS.get(data_source, None) + + if handler: + try: + return handler( + solution_str, ground_truth, extra_info=extra_info, fmt_ratio=FMT_RATIO, acc_ratio=ACC_RATIO, **kwargs + ) + except Exception as e: + logging.exception( + f"Error executing reward handler for data_source '{data_source}': {e}", + ) + return { + "score": -1.0, + "format": 0.0, + "accuracy": -1.0, + "pred": "", + "num_pred": 0, + "has_correct": 0, + "first_correct": 0, + "only_correct": 0, + "is_collinear": 0, + } # Return a default penalty score on error + else: + raise ValueError(f"Unknown data_source: '{data_source}'. No specific reward handler defined.") diff --git a/ICL/DAPO/verl-recipe/infigui-g1/run_3b.sh b/ICL/DAPO/verl-recipe/infigui-g1/run_3b.sh new file mode 100644 index 0000000000000000000000000000000000000000..811af25c7232aa89c2162e90aa1bfa2c94cad293 --- /dev/null +++ b/ICL/DAPO/verl-recipe/infigui-g1/run_3b.sh @@ -0,0 +1,55 @@ +#!/bin/bash +set -x +ulimit -n 65535 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=rloo \ + data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \ + data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=7168 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=False \ + data.truncation='error' \ + data.image_key=images \ + custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \ + custom_reward_function.name=aer_gui_reward_function \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=False \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=0 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.clip_ratio_high=0.4 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.max_num_batched_tokens=8192 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.logger=['console','wandb'] \ + trainer.project_name='infigui-g1' \ + trainer.experiment_name='3b' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=16 \ + trainer.test_freq=16 \ + trainer.total_epochs=6 diff --git a/ICL/DAPO/verl-recipe/infigui-g1/run_7b.sh b/ICL/DAPO/verl-recipe/infigui-g1/run_7b.sh new file mode 100644 index 0000000000000000000000000000000000000000..480d7bb90db67f1cb6ac7e2e49559ecf1992c58e --- /dev/null +++ b/ICL/DAPO/verl-recipe/infigui-g1/run_7b.sh @@ -0,0 +1,55 @@ +#!/bin/bash +set -x +ulimit -n 65535 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=rloo \ + data.train_files=./data/omniact_grounding_filtered/omniact_filtered_train.parquet \ + data.val_files=./data/omniact_grounding_filtered/omniact_filtered_val.parquet \ + data.train_batch_size=128 \ + data.max_prompt_length=7168 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=False \ + data.truncation='error' \ + data.image_key=images \ + custom_reward_function.path=./recipe/infigui-g1/reward_fn.py \ + custom_reward_function.name=aer_gui_reward_function \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=False \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=0 \ + actor_rollout_ref.actor.ppo_mini_batch_size=128 \ + actor_rollout_ref.actor.clip_ratio_high=0.4 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.max_num_batched_tokens=8192 \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=False \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.logger=['console','wandb'] \ + trainer.project_name='infigui-g1' \ + trainer.experiment_name='7b' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=16 \ + trainer.test_freq=16 \ + trainer.total_epochs=6 diff --git a/ICL/DAPO/verl-recipe/open_math_reasoning/README.md b/ICL/DAPO/verl-recipe/open_math_reasoning/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0f96c71fbf010ce06b776db87e107e26b92eb859 --- /dev/null +++ b/ICL/DAPO/verl-recipe/open_math_reasoning/README.md @@ -0,0 +1,68 @@ +# Open math reasoning +## Introduction +In this recipe, we perform SFT on the [open math reasoning](https://huggingface.co/datasets/nvidia/OpenMathReasoning) dataset using the new SFT trainer with backend agostic model engine. Note that our goal is not to replicate the [AIMO-2 Winning Solution](https://arxiv.org/abs/2504.16891) work, but to demonstrate a SFT demo from end to end. + +Note that you may need to modify the path as needed in the following scripts. +## Dataset Preprocessing +### Download Dataset +```bash +hf download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* --local-dir /path/to/dataset/nvidia/OpenMathReasoning +hf download math-ai/aime24 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime24 +hf download math-ai/aime25 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime25 +``` + +### Preprocess the dataset +```bash +python3 recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py --local_dataset_path /path/to/nvidia/OpenMathReasoning --local_save_dir /path/to/open_math_reasoning +``` + +### Prepare the eval dataset +```bash +python3 recipe/open_math_reasoning/prepare_eval_dataset.py --local_dataset_path /path/to/dataset --local_save_dir /path/to/eval_dataset +``` + +## Train the model using SFT +```bash +export CKPT_HOME=/path/to/ckpt +export MODEL_ID=Qwen/Qwen3-8B-Base +export TRAIN_FILES=/path/to/open_math_reasoning/cot_dataset.parquet +``` + +### FSDP backend +```bash +export BACKEND=fsdp2 +bash recipe/open_math_reasoning/run_sft_qwen3_8b.sh +``` + +### Megatron backend +```bash +export BACKEND=megatron +bash recipe/open_math_reasoning/run_sft_qwen3_8b.sh +``` + +## Eval the model +### Merge checkpoint into huggingface format +FSDP backend +```bash +python -m verl.model_merger merge --backend fsdp --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface +``` +Megatron backend +```bash +python -m verl.model_merger merge --backend megatron --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface --use_cpu_initialization +``` + +### Generate the responses +```bash +export MODEL_PATH=/path/to/ckpt/global_step_19751/huggingface +bash recipe/open_math_reasoning/run_generation.sh +``` + +### Evaluate the responses +```bash +bash recipe/open_math_reasoning/run_eval.sh +``` + +You should see the results like: +```python +{'test_score/aime24': 0.584375, 'test_score/aime25': 0.43333333333333335} +``` diff --git a/ICL/DAPO/verl-recipe/open_math_reasoning/compute_score.py b/ICL/DAPO/verl-recipe/open_math_reasoning/compute_score.py new file mode 100644 index 0000000000000000000000000000000000000000..e1907fba44b054532647b39cffc66913ef311bae --- /dev/null +++ b/ICL/DAPO/verl-recipe/open_math_reasoning/compute_score.py @@ -0,0 +1,22 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def compute_score_data_source(data_source, response, ground_truth): + from verl.utils.reward_score.math_reward import compute_score + + if data_source in ["aime24", "aime25"]: + return compute_score(response, ground_truth) + else: + raise ValueError(f"Unknown data source: {data_source}") diff --git a/ICL/DAPO/verl-recipe/open_math_reasoning/prepare_eval_dataset.py b/ICL/DAPO/verl-recipe/open_math_reasoning/prepare_eval_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..14adf2f7d91b82056e241c7c1c0cb9684ab7c012 --- /dev/null +++ b/ICL/DAPO/verl-recipe/open_math_reasoning/prepare_eval_dataset.py @@ -0,0 +1,96 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# prepare eval dataset including AIME'24, AIME'25 + +# hf download math-ai/aime24 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime24 +# hf download math-ai/aime25 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime25 + +import os + +import datasets + +from verl.utils.reward_score.math_reward import remove_boxed + +instruction_following = "Please reason step by step, and put your final answer within \\boxed{}." + + +def make_map_fn(data_source): + def process_fn(example, idx): + question_raw = example.pop("problem") + + question = question_raw + " " + instruction_following + + if "solution" not in example: + example["solution"] = example["answer"] + + answer_raw = example.pop("solution") + + example.clear() + + try: + solution = remove_boxed(answer_raw) + except Exception: + solution = answer_raw + + data = { + "data_source": data_source, + "prompt": [ + { + "role": "user", + "content": question, + } + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "index": idx, + "answer": answer_raw, + "question": question_raw, + }, + } + return data + + return process_fn + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.") + parser.add_argument( + "--local_save_dir", default="~/data/math-ai", help="The save directory for the preprocessed dataset." + ) + + args = parser.parse_args() + + if args.local_dataset_path is not None: + aime24_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime24") + aime25_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime25") + else: + aime24_dataset_path = "math-ai/aime24" + aime25_dataset_path = "math-ai/aime25" + + aime24_dataset = datasets.load_dataset(aime24_dataset_path, split="test") + aime25_dataset = datasets.load_dataset(aime25_dataset_path, split="test") + + aime24_dataset = aime24_dataset.map(function=make_map_fn("aime24"), with_indices=True) + aime25_dataset = aime25_dataset.map(function=make_map_fn("aime25"), with_indices=True) + + local_save_dir = os.path.expanduser(args.local_save_dir) + os.makedirs(local_save_dir, exist_ok=True) + + aime24_dataset.to_parquet(os.path.join(local_save_dir, "aime24_test.parquet")) + aime25_dataset.to_parquet(os.path.join(local_save_dir, "aime25_test.parquet")) diff --git a/ICL/DAPO/verl-recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py b/ICL/DAPO/verl-recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..b799be52c7e6bc571706c08d6dd13a3e4e6254bf --- /dev/null +++ b/ICL/DAPO/verl-recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py @@ -0,0 +1,72 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \ + --local-dir /path/to/nvidia/OpenMathReasoning +huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \ + --local-dir /opt/tiger/nvidia/OpenMathReasoning +""" + +import argparse +import os + +import datasets + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.") + parser.add_argument( + "--local_save_dir", + default="~/data/open_math_reasoning", + help="The save directory for the preprocessed dataset.", + ) + + args = parser.parse_args() + local_dataset_path = args.local_dataset_path + + data_source = "nvidia/OpenMathReasoning" + + if local_dataset_path is not None: + dataset = datasets.load_dataset(local_dataset_path, split="cot") + else: + dataset = datasets.load_dataset(data_source, split="cot") + + def make_map_fn(split): + def process_fn(example, idx): + question = example.pop("problem") + solution = example.pop("generated_solution") + + extra_info = {} + for key, value in example.items(): + extra_info[key] = value + example.clear() + + data = { + "messages": [ + {"role": "user", "content": question, "loss_mask": 0}, + {"role": "assistant", "content": solution, "loss_mask": 1}, + ], + "extra_info": extra_info, + } + return data + + return process_fn + + # filter out data where the problem_type is not has_answer_extracted + dataset = dataset.filter(lambda example: example["problem_type"] == "has_answer_extracted") + dataset = dataset.map(function=make_map_fn("cot"), with_indices=True) + local_save_dir = os.path.expanduser(args.local_save_dir) + os.makedirs(local_save_dir, exist_ok=True) + dataset.to_parquet(os.path.join(local_save_dir, "cot_dataset.parquet")) diff --git a/ICL/DAPO/verl-recipe/open_math_reasoning/run_generation.sh b/ICL/DAPO/verl-recipe/open_math_reasoning/run_generation.sh new file mode 100644 index 0000000000000000000000000000000000000000..40187c75264a884dd5945d413bbb700db45a3628 --- /dev/null +++ b/ICL/DAPO/verl-recipe/open_math_reasoning/run_generation.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +MODEL_PATH=${MODEL_PATH:-/path/to/ckpt/global_step_19751/huggingface} + +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +NNODES=${NNODES:-1} +OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_8b_gen_test.parquet} +GEN_TP=${GEN_TP:-1} # Default tensor parallel size to 2 + +aime24_test_path=${HOME}/data/math-ai/aime24_test.parquet +aime25_test_path=${HOME}/data/math-ai/aime25_test.parquet +train_files="['$aime24_test_path', '$aime25_test_path']" + +python3 -m verl.trainer.main_generation_server \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.rollout.temperature=1.0 \ + actor_rollout_ref.rollout.top_p=0.7 \ + actor_rollout_ref.rollout.prompt_length=2048 \ + actor_rollout_ref.rollout.response_length=20480 \ + actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=32 \ + data.train_files="$train_files" \ + data.prompt_key=prompt \ + +data.output_path="${OUTPUT_PATH}" \ + + + diff --git a/ICL/DAPO/verl-recipe/open_math_reasoning/run_sft_qwen3_8b.sh b/ICL/DAPO/verl-recipe/open_math_reasoning/run_sft_qwen3_8b.sh new file mode 100644 index 0000000000000000000000000000000000000000..ec564a1d602d8c265b276cffae4588667905a676 --- /dev/null +++ b/ICL/DAPO/verl-recipe/open_math_reasoning/run_sft_qwen3_8b.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"} + +TRAIN_FILES=${TRAIN_FILES:-/path/to/cot_dataset.parquet} + +backend=${BACKEND:-fsdp} + +project_name=verl_sft_test + +RESUME_MODE=auto +MODEL_ID=${MODEL_ID:-Qwen/Qwen3-8B-Base} + +SP_SIZE=${SP_SIZE:-8} +FSDP_SIZE=${FSDP_SIZE:-16} +FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"} + +TP_SIZE=${TP_SIZE:-8} +PP_SIZE=${PP_SIZE:-2} +VPP_SIZE=${VPP_SIZE:-null} +CP_SIZE=${CP_SIZE:-1} + +PAD_MODE=${PAD_MODE:-no_padding} + +USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} + +FSDP_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=2e-5 \ + optim.lr_warmup_steps_ratio=0.01 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_ratio=0.1 \ + optim.warmup_style=cosine \ + engine.ulysses_sequence_parallel_size=${SP_SIZE} \ + engine.strategy=${FSDP_STRATEGY} \ + engine.fsdp_size=${FSDP_SIZE}" + + +MEGATRON_ENGINE_CONFIG="\ + engine=${backend} \ + optim=${backend} \ + optim.lr=2e-5 \ + optim.lr_warmup_steps_ratio=0.01 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.lr_warmup_init=0 \ + optim.lr_decay_style=cosine \ + optim.min_lr=2e-6 \ + engine.tensor_model_parallel_size=${TP_SIZE} \ + engine.pipeline_model_parallel_size=${PP_SIZE} \ + engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.use_mbridge=True" + +if [ "$backend" = "fsdp" ]; then + ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" + echo "Using fsdp engine" + exp_name=nvidia-openmathreasoning-qwen3-8b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1008a1 +else + ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" + echo "Using megatron engine" + exp_name=nvidia-openmathreasoning-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-megatron-1018a1 +fi + +CKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}} +mkdir -p "${CKPT_HOME}" + +torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \ + ${ENTRYPOINT} \ + data.train_files="${TRAIN_FILES}" \ + data.train_batch_size=96 \ + data.max_length=32768 \ + data.pad_mode=${PAD_MODE} \ + data.truncation=error \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=65536 \ + data.messages_key=messages \ + model.path=$MODEL_ID \ + model.use_remove_padding=${USE_REMOVE_PADDING} \ + ${ENGINE_CONFIG} \ + trainer.test_freq=-1 \ + trainer.save_freq=4000 \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPT_HOME}" \ + trainer.resume_mode=${RESUME_MODE} \ + trainer.max_ckpt_to_keep=5 \ + checkpoint.save_contents=[model,optimizer,extra] \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/prime/main_prime.py b/ICL/DAPO/verl-recipe/prime/main_prime.py new file mode 100644 index 0000000000000000000000000000000000000000..39d20de4326bbea9957326405903abd3c5c02ecb --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/main_prime.py @@ -0,0 +1,163 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.ppo.utils import need_reference_policy +from verl.utils.config import validate_config + +from .prime_ray_trainer import RayPRIMETrainer + + +@hydra.main(config_path="config", config_name="prime_trainer", version_base=None) +def main(config): + run_prime(config) + + +def run_prime(config, compute_score=None): + if not ray.is_initialized(): + default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}} + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + # this is for local ray cluster + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + ray.get(main_task.remote(config, compute_score)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +def main_task(config, compute_score=None): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_local_path_from_hdfs + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.fsdp_workers import ActorRolloutRefWorker + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker + + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + } + + # use reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + if config.reward_model.enable: + from .prime_fsdp_workers import PRIMERewardModelWorker + + role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # validate config + # TODO: Additional config checks can be added with proper function under prime recipe + validate_config( + config=config, + use_reference_policy=need_reference_policy(role_worker_mapping), + use_critic=False, + ) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_tokenizer + + tokenizer = hf_tokenizer(local_path) + reward_manager_name = config.reward_model.get("reward_manager", "naive") + if reward_manager_name == "naive": + from verl.workers.reward_manager import NaiveRewardManager + + reward_manager_cls = NaiveRewardManager + elif reward_manager_name == "prime": + from verl.workers.reward_manager import PrimeRewardManager + + reward_manager_cls = PrimeRewardManager + else: + raise NotImplementedError + reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) + + # Note that we always use function-based RM for validation + val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RayPRIMETrainer( + config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) + trainer.init_workers() + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/prime/prime_dp_rm.py b/ICL/DAPO/verl-recipe/prime/prime_dp_rm.py new file mode 100644 index 0000000000000000000000000000000000000000..67cd82e66f5d73645c9f0bd1a42799261916c6bb --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/prime_dp_rm.py @@ -0,0 +1,402 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implement a multiprocess PPOCritic +""" + +import itertools + +import torch +import torch.distributed +from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input +from torch import nn, optim +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.utils.device import get_device_name +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs + +from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm + +__all__ = ["DataParallelPRIMERewardModel"] + + +class DataParallelPRIMERewardModel: + def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): + self.config = config + self.reward_module = reward_module + self.ref_module = ref_module + self.reward_optimizer = reward_optimizer + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + print(f"Reward model use_remove_padding={self.use_remove_padding}") + self.use_fused_kernels = self.config.model.get("use_fused_kernels", False) + print(f"Reward model use_fused_kernels={self.use_fused_kernels}") + + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + + def _forward_micro_batch(self, micro_batch, prompt_length): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + + num_actions = micro_batch["input_ids"].shape[-1] - prompt_length + max_positions = micro_batch["attention_mask"][:, prompt_length:].sum(-1) + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) + output = self.reward_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False, + return_dict=True, + ) + + if self.use_fused_kernels: + rm_log_labels = output.log_probs.squeeze(0) # (total_nnz,) + rm_log_labels = rm_log_labels.to(torch.float32) + + else: + rm_output_logits = output.logits.squeeze(0) + rm_log_labels = verl_F.logprobs_from_logits( + logits=rm_output_logits, + labels=input_ids_rmpad_rolled, + ) + + if self.ulysses_sequence_parallel_size > 1: + rm_log_labels = gather_outputs_and_unpad( + rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + rm_log_labels = pad_input( + hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1)[:, -num_actions - 1 : -1] + + else: + output = self.reward_module( + input_ids=micro_batch["input_ids"], + attention_mask=micro_batch["attention_mask"], + position_ids=micro_batch["position_ids"], + use_cache=False, + return_dict=True, + ) + + if self.use_fused_kernels: + rm_log_labels = output.log_probs[:, :-1] # (bsz, seq_length) + rm_log_labels = rm_log_labels.to(torch.float32) + + else: + rm_output_logits = output.logits + rm_log_prob = torch.nn.functional.log_softmax( + rm_output_logits[:, :-1, :], dim=-1 + ) # (batch_size, seq_length, vocab_size) + rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze( + -1 + ) # (batch, seq_length) + + if self.ref_module is not None: + # do not have to pad again + with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding: + ref_output = self.ref_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + use_cache=False, + return_dict=True, + ) + + if self.use_fused_kernels: + ref_log_labels = ref_output.log_probs.squeeze(0) # (total_nnz,) + ref_log_labels = ref_log_labels.to(torch.float32) + + else: + ref_output_logits = ref_output.logits.squeeze(0) + ref_log_labels = verl_F.logprobs_from_logits( + logits=ref_output_logits, labels=input_ids_rmpad_rolled + ) + + ref_log_labels = gather_outputs_and_unpad( + ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + ref_log_labels = pad_input( + hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1)[:, -num_actions - 1 : -1] + else: + ref_output = self.ref_module( + input_ids=micro_batch["input_ids"], + attention_mask=micro_batch["attention_mask"], + position_ids=micro_batch["position_ids"], + use_cache=False, + return_dict=True, + ) + + if self.use_fused_kernels: + ref_log_labels = ref_output.log_probs[:, :-1] # (batch_size, seq_length) + ref_log_labels = ref_log_labels.to(torch.float32) + + else: + ref_output_logits = ref_output.logits + ref_log_prob = torch.nn.functional.log_softmax( + ref_output_logits[:, :-1, :], dim=-1 + ) # (batch_size, seq_length, vocab_size) + ref_log_labels = ref_log_prob.gather( + dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1) + ).squeeze(-1) # (batch, seq_length) + + else: + ref_log_labels = micro_batch["old_log_probs"] + + ref_log_labels.to(rm_log_labels.dtype) + q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q + + # trim unnecessary logprobs here + for i in range(micro_batch["input_ids"].shape[0]): + q[i, max_positions[i] :] = 0 + + # reward computation does not need gradient. only q needs + with torch.no_grad(): + # generalized estimation of r should go before the reward filling. r means process reward for policy + # model, or the advantage of reward model. + lam = self.config.get("lambda", 0.0) + beta = self.config.model.get("beta_train", 0.05) + if lam == 0.0: + r = q * beta + else: + # reward coefficient takes no effect here + acc = micro_batch["acc"] + q_ = q * beta + r = torch.zeros_like(q) + lastgaelam = 0 + # change the last token and mask out all paddings to make this process easier if we rely on + # outcome reward to calculate V + for i in range(q.shape[0]): + if self.config.prime_use_gt: + q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum() + q_[i, max_positions[i] :] = 0 + + for t in reversed(range(num_actions)): + delta = q_[:, t] + lastgaelam = delta + lam * lastgaelam + r[:, t] = lastgaelam + + token_level_score = torch.zeros_like(q) + + if self.config.prime_granularity == "token": + for i in range(micro_batch["input_ids"].shape[0]): + token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1] + elif self.config.prime_granularity == "whole": + for i in range(micro_batch["input_ids"].shape[0]): + token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]] + else: + raise NotImplementedError + + return token_level_score, q + + def _optimizer_step(self): + assert self.config.model.optim.grad_clip is not None + + if isinstance(self.reward_module, FSDP): + grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip + ) + self.reward_optimizer.step() + return grad_norm + + def prime_norm(self, token_level_scores): + if self.config.prime_norm == "batch_norm": + reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1]) + token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6) + return token_level_scores + + def compute_rm_score(self, data: DataProto): + self.reward_module.eval() + self.ref_module.eval() + micro_batch_size = data.meta_info["micro_batch_size"] + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"] + batch = data.select(batch_keys=select_keys).batch + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] + + if use_dynamic_bsz: + # split using dynamic bsz + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + else: + micro_batches = batch.split(micro_batch_size) + + rm_scores_lst = [] + q_lst = [] + for micro_batch in micro_batches: + with torch.no_grad(): + rm_score, q = self._forward_micro_batch(micro_batch, prompt_length) + rm_scores_lst.append(rm_score) + q_lst.append(q) + rm_scores = torch.concat(rm_scores_lst, dim=0) + q = torch.concat(q_lst, dim=0) + + rm_scores = self.prime_norm(rm_scores) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + rm_scores = rm_scores[revert_indices] + + return ( + rm_scores, + q.detach(), + { + "reward_model/reward": rm_scores.sum(dim=-1).mean().item(), + "reward_model/raw_reward": q.sum(dim=-1).mean().item(), + }, + ) + + def update_rm(self, data: DataProto): + # make sure we are in training mode + self.reward_module.train() + metrics = {} + + beta = self.config.model.get("beta_train", 0.05) + + select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "acc", "prompts"] + + for key in ["Q_bc", "acc_bc"]: + if key in data.batch.keys(): + select_keys.append(key) + + batch = data.select(batch_keys=select_keys).batch + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + dataloader = batch.split(self.config.mini_batch_size) + + rm_scores_lst = [] + q_lst = [] + + for batch_idx, data in enumerate(dataloader): + # split batch into micro_batches + mini_batch = data + if self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) + self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu + + self.reward_optimizer.zero_grad() + + for data in micro_batches: + data = data.to(get_device_name()) + attention_mask = data["attention_mask"] + acc = data["acc"] + + prompt_ids = data["prompts"] + prompt_length = prompt_ids.shape[-1] + + response_mask = attention_mask[:, prompt_length:] + + rm_score, q = self._forward_micro_batch(data, prompt_length) + + rm_scores_lst.append(rm_score) + q_lst.append(q.detach()) + + if self.config.model.loss_type == "ce": + dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta) + elif self.config.model.loss_type == "dpo": + # the implementation of dpo is actually detached, which means we have to know the average + # value of w/l reward before the update. + dpo_loss = compute_detach_dpo_loss_rm( + q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta + ) + elif self.config.model.loss_type == "bon_acc": + # change the original distribution of each sample to BoN distribution, then update reward model + dpo_loss = compute_detach_dpo_loss_rm( + q, + acc, + Q_bc=data["Q_bc"], + acc_bc=data["acc_bc"], + response_mask=response_mask, + beta=beta, + bon_mode="bon_acc", + ) + elif self.config.model.loss_type == "bon_rm": + dpo_loss = compute_detach_dpo_loss_rm( + q, + acc, + Q_bc=data["Q_bc"], + acc_bc=data["acc_bc"], + response_mask=response_mask, + beta=beta, + bon_mode="bon_rm", + ) + else: + raise NotImplementedError + + data = {"reward_model/dpo_loss": dpo_loss.detach().item()} + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size) + else: + loss = dpo_loss / self.gradient_accumulation + + loss.backward() + + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {"reward_model/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, data) + self.reward_optimizer.zero_grad() + + rm_scores = torch.cat(rm_scores_lst, dim=0) + q = torch.concat(q_lst, dim=0) + + rm_scores = self.prime_norm(rm_scores) + + metrics.update( + { + "reward_model/reward": rm_scores.sum(dim=-1).mean().item(), + "reward_model/raw_reward": q.sum(dim=-1).mean().item(), + } + ) + + return rm_scores, metrics diff --git a/ICL/DAPO/verl-recipe/prime/prime_fsdp_workers.py b/ICL/DAPO/verl-recipe/prime/prime_fsdp_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..410d97e7494b947bd6a955c196f86d0968702059 --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/prime_fsdp_workers.py @@ -0,0 +1,381 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import warnings + +import torch +import torch.distributed +from omegaconf import OmegaConf +from torch.distributed.device_mesh import init_device_mesh + +from verl import DataProto +from verl.models.transformers.monkey_patch import apply_monkey_patch +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils import hf_tokenizer +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.device import get_device_id, get_device_name, get_nccl_backend +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_local_path_from_hdfs +from verl.utils.fsdp_utils import ( + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) +from verl.utils.import_utils import import_external_libs +from verl.utils.profiler import log_gpu_memory_usage +from verl.workers.config.optimizer import build_optimizer +from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +from .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class PRIMERewardModelWorker(Worker): + def __init__(self, config): + super().__init__() + import torch.distributed + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=get_nccl_backend()) + self.config = config + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + # set FSDP offload params + self._is_offload_param = self.config.model.fsdp_config.param_offload + self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload + + # normalize config + self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 + + def _build_reward_ref_model_optimizer(self, config): + # the following line is necessary + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import MixedPrecision + + from verl.utils.model import print_model_size + from verl.utils.torch_dtypes import PrecisionType + + local_path = copy_local_path_from_hdfs(config.model.path) + + tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) + + override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + override_config_kwargs = { + "bos_token_id": self.tokenizer.bos_token_id, + "eos_token_id": self.tokenizer.eos_token_id, + "pad_token_id": self.tokenizer.pad_token_id, + } + override_config_kwargs.update(override_config) + if self.rank == 0: + print(f"Reward model overriding config {override_config_kwargs}") + + torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") + torch_dtype = PrecisionType.to_dtype(torch_dtype) + + from transformers import AutoConfig, AutoModelForCausalLM + + trust_remote_code = False + reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + reward_model_config.num_labels = 1 + + init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings) + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + reward_model_config.classifier_dropout = 0.0 + reward_model_config.hidden_dropout = "0" + reward_module = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=local_path, + torch_dtype=torch_dtype, + config=reward_model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + fused_kernel_options = config.model.get("fused_kernel_options", None) + fused_kernels_backend = ( + fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None + ) + + apply_monkey_patch( + model=reward_module, + ulysses_sp_size=self.ulysses_sequence_parallel_size, + use_remove_padding=config.model.get("use_remove_padding", False), + use_fused_kernels=config.model.get("use_fused_kernels", False), + fused_kernels_backend=fused_kernels_backend, + ) + + # some parameters may not in torch_dtype + reward_module.to(torch_dtype) + + if config.model.get("enable_gradient_checkpointing", False): + reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + if self.rank == 0: + print_model_size(reward_module) + + self.reward_model_config = reward_model_config + + fsdp_config = self.config.model.fsdp_config + mixed_precision_config = fsdp_config.get("mixed_precision", None) + if mixed_precision_config is not None: + param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) + reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) + buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) + else: + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) + + auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) + + log_gpu_memory_usage("Before reward model FSDP", logger=None) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + reward_model_config.classifier_dropout = 0.0 + reward_model_config.hidden_dropout = "0" + ref_module = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path), + torch_dtype=torch_dtype, + config=reward_model_config, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + # some parameters may not in torch_dtype + ref_module.to(torch_dtype) + + reward_module = FSDP( + reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + + log_gpu_memory_usage("After reward FSDP", logger=None) + + ref_module = FSDP( + ref_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, + mixed_precision=mixed_precision, + sync_module_states=True, + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None, + ) + + reward_optimizer = build_optimizer(reward_module.parameters(), config.model.optim) + + total_steps = config.model.optim.get("total_training_steps", 0) + num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1)) + if num_warmup_steps < 0: + num_warmup_steps_ratio = config.model.optim.get("lr_warmup_steps_ratio", 0.0) + num_warmup_steps = int(num_warmup_steps_ratio * total_steps) + + print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + + from verl.utils.torch_functional import get_constant_schedule_with_warmup + + reward_lr_scheduler = get_constant_schedule_with_warmup( + optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps + ) + + return reward_module, ref_module, reward_optimizer, reward_lr_scheduler + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + from .prime_dp_rm import DataParallelPRIMERewardModel + + self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = ( + self._build_reward_ref_model_optimizer(config=self.config) + ) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.reward_optimizer) + + self.rm = DataParallelPRIMERewardModel( + config=self.config, + reward_module=self.reward_module, + ref_module=self.ref_module, + reward_optimizer=self.reward_optimizer, + ) + + self.flops_counter = FlopsCounter(self.reward_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.reward_module, + optimizer=self.reward_optimizer, + lr_scheduler=self.reward_lr_scheduler, + tokenizer=self.tokenizer, + ) + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def compute_rm_score(self, data: DataProto): + data = data.to(get_device_name()) + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.reward_module) + load_fsdp_model_to_gpu(self.ref_module) + torch.distributed.barrier() + micro_batch_size = self.config.micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz + # perform forward computation + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data=data) + rm_scores, q, metrics = self.rm.compute_rm_score(data=data) + + prompt_length = data.batch["prompts"].shape[-1] + response_mask = data.batch["attention_mask"][:, prompt_length:] + acc = data.batch["acc"] + + dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]) + dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) + + metrics["reward_model/dpo_acc"] = dpo_acc.detach().item() + metrics["reward_model/dpo_acc_abs"] = dpo_acc_abs.detach().item() + + output = DataProto.from_dict(tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + output = output.to("cpu") + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) + torch.distributed.barrier() + torch.cuda.empty_cache() + return output + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def update_rm(self, data: DataProto): + data = data.to(get_device_name()) + if self._is_offload_param: + load_fsdp_model_to_gpu(self.ref_module) + load_fsdp_model_to_gpu(self.reward_module) + torch.distributed.barrier() + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id()) + + # perform forward computation + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data=data) + + rm_scores, metrics = self.rm.update_rm(data=data) + + self.reward_lr_scheduler.step() + lr = self.reward_lr_scheduler.get_last_lr()[0] + metrics["rm/lr"] = lr + + prompt_length = data.batch["prompts"].shape[-1] + response_mask = data.batch["attention_mask"][:, prompt_length:] + acc = data.batch["acc"] + + dpo_acc_before = compute_dpo_accuracy( + rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"] + ) + dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) + + metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item() + metrics["reward_model/dpo_acc_abs_before"] = dpo_acc_abs.detach().item() + + output = DataProto.from_dict(tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + offload_fsdp_model_to_cpu(self.ref_module) + torch.distributed.barrier() + torch.cuda.empty_cache() + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.reward_optimizer) + output = output.to("cpu") + return output + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + import torch + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.reward_module) + + self.checkpoint_manager.save_checkpoint( + local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep + ) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def load_checkpoint(self, local_path, del_local_after_load=True): + import torch + + if self._is_offload_param: + load_fsdp_model_to_gpu(self.reward_module) + + self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load) + + torch.distributed.barrier() + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.reward_module) diff --git a/ICL/DAPO/verl-recipe/prime/prime_ray_trainer.py b/ICL/DAPO/verl-recipe/prime/prime_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..df53b0ec296621295722265f35a77af4c87ba4ae --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/prime_ray_trainer.py @@ -0,0 +1,597 @@ +# Copyright 2024 PRIME team and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import os +import statistics +import uuid +from copy import deepcopy +from pprint import pprint + +import numpy as np +import torch +from omegaconf import OmegaConf, open_dict + +from verl import DataProto +from verl.single_controller.ray import RayWorkerGroup +from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.metric_utils import _compute_response_info +from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager +from verl.trainer.ppo.utils import Role, WorkerType +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn +from verl.utils.metric import reduce_metrics +from verl.utils.profiler.performance import simple_timer + +from . import prime_core_algos + + +def compute_advantage(data: DataProto, adv_estimator, config): + if adv_estimator == "rloo": + responses = data.batch["responses"] + response_length = responses.size(-1) + attention_mask = data.batch["attention_mask"] + response_mask = attention_mask[:, -response_length:] + advantages, returns = prime_core_algos.compute_rloo_advantage_return( + data, response_mask, config.actor_rollout_ref.rollout.n, config + ) + data.batch["advantages"] = advantages + data.batch["returns"] = returns + data.batch["response_mask"] = response_mask + else: + raise NotImplementedError + return data + + +def compute_data_metrics(batch, use_critic=True): + advantages = batch.batch["advantages"] + returns = batch.batch["returns"] + + max_response_length = batch.batch["responses"].shape[-1] + + prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() + response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + + response_info = _compute_response_info(batch) + prompt_length = response_info["prompt_length"] + response_length = response_info["response_length"] + + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + + if use_critic: + values = batch.batch["values"] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + + metrics = { + # adv + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + # returns + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + **( + { + # values + "critic/values/mean": torch.mean(valid_values).detach().item(), + "critic/values/max": torch.max(valid_values).detach().item(), + "critic/values/min": torch.min(valid_values).detach().item(), + # vf explained var + "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } + if use_critic + else {} + ), + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) + .detach() + .item(), + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + } + return metrics + + +def compute_response_mask(data: DataProto): + responses = data.batch["responses"] + response_length = responses.size(1) + attention_mask = data.batch["attention_mask"] + return attention_mask[:, -response_length:] + + +def compute_timing_metrics(batch, timing_raw): + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() + num_response_tokens = torch.sum(response_info["response_length"]).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + "gen": num_response_tokens, + **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, + } + + return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, + **{ + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) + }, + } + + +class RayPRIMETrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + reward_fn=None, + val_reward_fn=None, + device_name="cuda", + ): + # assert get_torch_device().is_available(), 'cuda must be available on driver' + + super().__init__( + config, + tokenizer, + role_worker_mapping, + resource_pool_manager, + ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + device_name=device_name, + ) + + self.use_critic = False + + def _create_dataloader(self, *args, **kwargs): + from torch.utils.data import DataLoader, RandomSampler, SequentialSampler + + # TODO: we have to make sure the batch size is divisible by the dp size + self.train_dataset = RLHFDataset( + data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data + ) + # use sampler for better ckpt resume + if self.config.data.shuffle: + train_dataloader_generator = torch.Generator() + seed = self.config.data.get("seed") + if seed is not None: + train_dataloader_generator.manual_seed(seed) + sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) + else: + sampler = SequentialSampler(data_source=self.train_dataset) + + self.train_dataloader = DataLoader( + dataset=self.train_dataset, + batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor), + drop_last=True, + collate_fn=collate_fn, + sampler=sampler, + ) + + self.val_dataset = RLHFDataset( + data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data + ) + self.val_dataloader = DataLoader( + dataset=self.val_dataset, + batch_size=len(self.val_dataset), + shuffle=True, + drop_last=True, + collate_fn=collate_fn, + ) + + assert len(self.train_dataloader) >= 1 + assert len(self.val_dataloader) >= 1 + + print(f"Size of train dataloader: {len(self.train_dataloader)}") + print(f"Size of val dataloader: {len(self.val_dataloader)}") + + # inject total_training_steps to actor/critic optim_config. This is hacky. + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f"Total training steps: {self.total_training_steps}") + + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + self.config.critic.optim.total_training_steps = total_training_steps + + def _save_checkpoint(self): + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + print(f"local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") + ) + self.actor_rollout_wg.save_checkpoint( + actor_local_path, + actor_remote_path, + self.global_steps, + ) + + if self.use_rm: + reward_local_path = os.path.join(local_global_step_folder, "reward") + reward_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward") + ) + self.rm_wg.save_checkpoint( + reward_local_path, + reward_remote_path, + self.global_steps, + ) + + # save dataloader + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + import dill + + torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill) + + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + temp_tracker = local_latest_checkpointed_iteration + ".tmp" + with open(temp_tracker, "w") as f: + f.write(str(self.global_steps)) + os.replace(temp_tracker, local_latest_checkpointed_iteration) + + def _load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("Training from scratch") + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"Load from checkpoint folder: {global_step_folder}") + # set global step + self.global_steps = int(global_step_folder.split("global_step_")[-1]) + + print(f"Setting global step to {self.global_steps}") + print(f"Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + reward_path = os.path.join(global_step_folder, "reward") + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load rm + if self.use_rm: + self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) + + # load dataloader, + # TODO: from remote not implemented yet + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + self.train_dataloader = torch.load(dataloader_local_path) + if isinstance(self.train_dataloader.dataset, RLHFDataset): + self.train_dataloader.dataset.resume_dataset_state() + + def compute_reward(self, batch: DataProto, n_samples: int): + update_style = self.config.reward_model.model.get("update", "none") + reward_output_metrics = {} + if update_style == "none": # only run forward + reward_output = self.rm_wg.compute_rm_score(batch) + elif update_style == "after": # update and directly return the reward + reward_output = self.rm_wg.update_rm(batch) + elif update_style == "before": # update reward model, and then run forward + reward_output = self.rm_wg.update_rm(batch) + if "metrics" in reward_output.meta_info.keys(): + reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) + + reward_output = self.rm_wg.compute_rm_score(batch) + elif update_style == "reverse": # run forward to calculate statistics, then update reward model + reward_output = self.rm_wg.compute_rm_score(batch) + + # broadcast q and acc tensor to each result + bc_td = DataProto.from_dict( + tensors={ + "Q_bc": reward_output.batch["q"] + .sum(dim=-1) + .view(-1, n_samples) + .unsqueeze(1) + .expand(-1, n_samples, -1) + .reshape(-1, n_samples), + "acc_bc": batch.batch["acc"] + .view(-1, n_samples) + .unsqueeze(1) + .expand(-1, n_samples, -1) + .reshape(-1, n_samples), + } + ) + batch = batch.union(bc_td) + reward_output = self.rm_wg.update_rm(batch) + else: + raise NotImplementedError + + return reward_output, reward_output_metrics + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to + construct the PPO dataflow. The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # we start from step 1 + self.global_steps += 1 + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + with simple_timer("step", timing_raw): + # generate a batch + with simple_timer("gen", timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == "remax": + with simple_timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + rm_scores, _ = self.compute_reward(batch, 1) + reward_baseline_tensor = rm_scores.batch.get( + "rm_scores", rm_scores.batch.get("acc_bc", None) + ) + if reward_baseline_tensor is None: + raise ValueError( + "Neither 'rm_scores' nor 'acc_bc' found in reward model output for baseline." + ) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + # verify + with simple_timer("verify", timing_raw): + scores = self.reward_fn.verify(batch) + metrics["acc"] = statistics.mean(scores) + + # filter the batch. 1/oversample_factor samples will be kept. + # If there is a filter, prompts passing it will be prioritized. + + batch = self.filter_and_downsample(scores, batch) + batch.meta_info["n"] = self.config.actor_rollout_ref.rollout.n + n_samples = self.config.actor_rollout_ref.rollout.n + + # recompute old_log_probs + with simple_timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = compute_response_mask(batch) + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with simple_timer("ref", timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + with simple_timer("adv", timing_raw): + if self.use_rm: + reward_output, reward_output_metrics = self.compute_reward(batch, n_samples) + batch = batch.union(reward_output) + if "metrics" in reward_output.meta_info.keys(): + reward_output_metrics.update(reduce_metrics(reward_output.meta_info["metrics"])) + metrics.update(reward_output_metrics) + + # compute advantages, executed on the driver process + batch = compute_advantage( + batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config + ) + + # update actor + with simple_timer("update_actor", timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and self.global_steps % self.config.trainer.test_freq == 0 + ): + with simple_timer("testing", timing_raw): + val_metrics: dict = self._validate() + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: + with simple_timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + self.global_steps += 1 + + if self.global_steps >= self.total_training_steps: + # perform validation after training + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f"Final validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if ( + self.config.trainer.save_freq > 0 + and (self.global_steps - 1) % self.config.trainer.save_freq != 0 + ): + with simple_timer("save_checkpoint", timing_raw): + self._save_checkpoint() + return + + def filter_and_downsample(self, scores, batch: DataProto): + """ + downsample the batch according to oversample_factor + samples passing the filters will be prioritized + """ + n_samples = int(self.config.actor_rollout_ref.rollout.n) + reward_matrix = torch.tensor(scores).reshape(-1, n_samples) + + filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool) + + if self.config.data.filter_accuracy: + acc_tensor = torch.mean(reward_matrix, dim=-1) + filter_mask[ + (acc_tensor > self.config.data.accuracy_upper_bound) + | (acc_tensor < self.config.data.accuracy_lower_bound) + ] = False + + if self.config.data.filter_truncate: + length_matrix = ( + batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :] + .sum(dim=-1) + .reshape(-1, n_samples) + ) + length_tensor = torch.max(length_matrix, dim=-1)[0] + filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False + + reorder_index = torch.argsort(filter_mask, descending=True) + reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) + batch.reorder( + reorder_index[: int(len(batch) // self.config.data.oversample_factor)] + ) # this operation is inplace + + return batch diff --git a/ICL/DAPO/verl-recipe/prime/run_prime_qwen.sh b/ICL/DAPO/verl-recipe/prime/run_prime_qwen.sh new file mode 100644 index 0000000000000000000000000000000000000000..145f31b7bada41456f2b5b069016a51eeab82602 --- /dev/null +++ b/ICL/DAPO/verl-recipe/prime/run_prime_qwen.sh @@ -0,0 +1,64 @@ +set -x + + +gsm8k_train_path=$HOME/data/gsm8k/train.parquet +gsm8k_test_path=$HOME/data/gsm8k/test.parquet + +# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data +math_train_path=$HOME/data/math/train.parquet +math_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path', '$math_train_path']" +test_files="['$gsm8k_test_path', '$math_test_path']" + +model_path=PRIME-RL/Eurus-2-7B-SFT +# model_path=Qwen/Qwen2.5-0.5B-Instruct + +python3 -m recipe.prime.main_prime \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=64 \ + data.val_batch_size=6312 \ + data.max_prompt_length=1024 \ + data.max_response_length=3072 \ + data.filter_overlong_prompts=True \ + data.filter_accuracy=True \ + data.accuracy_lower_bound=0.2 \ + data.accuracy_upper_bound=0.8 \ + data.oversample_factor=4 \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + algorithm.adv_estimator=rloo \ + algorithm.use_kl_in_reward=True \ + algorithm.kl_penalty=kl \ + algorithm.kl_ctrl.kl_coef=0.001 \ + reward_model.model.path=$model_path \ + reward_model.micro_batch_size_per_gpu=1 \ + reward_model.model.update=before \ + reward_model.model.beta_train=0.05 \ + reward_model.model.optim.lr=1e-6 \ + reward_model.model.optim.grad_clip=10.0 \ + reward_model.model.input_tokenizer=null \ + reward_model.mini_batch_size=64 \ + trainer.val_before_train=False \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='prime_example' \ + trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=64 \ + trainer.test_freq=64 \ + trainer.total_epochs=15 $@ diff --git a/ICL/DAPO/verl-recipe/pyproject.toml b/ICL/DAPO/verl-recipe/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..1426b507b0e06d6025f004248848cef730b53325 --- /dev/null +++ b/ICL/DAPO/verl-recipe/pyproject.toml @@ -0,0 +1,113 @@ +# ------------------------------- +# build-system +# ------------------------------- +[build-system] +requires = [ + "setuptools>=61.0", + "wheel" +] +build-backend = "setuptools.build_meta" + +# ------------------------------- +# project (PEP 621 metadata) +# ------------------------------- +[project] +name = "verl" +# We'll mark the version as "dynamic" because it's read from the file "verl/version/version" +# (PEP 621 calls this "dynamic version"). +# The actual version is specified in the [tool.setuptools.dynamic] section below. +dynamic = ["version", "dependencies", "optional-dependencies", "authors", "urls"] + +description = "verl: Volcano Engine Reinforcement Learning for LLM" +license = {text = "Apache-2.0"} # Changed from file to text format +readme = {file = "README.md", content-type = "text/markdown"} +requires-python = ">=3.10" + +# ------------------------------- +# tool.ruff - Linting configuration +# ------------------------------- +[tool.ruff] +# Note: While the formatter will attempt to format lines such that they remain within the line-length, +# it isn't a hard upper bound, and formatted lines may exceed the line-length. +line-length = 120 +exclude = ["tests/workers/rollout/test_sglang_async_rollout_sf_tools.py", "scripts/legacy_model_merger.py"] + +[tool.ruff.lint] +isort = {known-first-party = ["verl"]} +# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # isort + "I", + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # `.log()` statement uses f-string + "G004", + # X | None for type annotations + "UP045", + # deprecated import + "UP035", +] + +# ------------------------------- +# tool.mypy - typechecking config +# ------------------------------- +[tool.mypy] +pretty = true +ignore_missing_imports = true +explicit_package_bases = true +follow_imports = "skip" + +# Blanket silence +ignore_errors = true + +[[tool.mypy.overrides]] +module = [ +"verl.trainer.config.algorithm", +"verl.trainer.ppo.core_algos", +"verl.trainer.ppo.reward", +"verl.workers.reward_manager", +"verl.workers.reward_manager.*", +] +ignore_errors = false + +# ------------------------------- +# tool.setuptools - Additional config +# ------------------------------- +[tool.setuptools] +# True means `setuptools` will attempt to include all relevant files in package_data automatically. +# This corresponds to `include_package_data=True` in setup.py. +include-package-data = true + +# We read the version from a file in 'verl/version/version' +[tool.setuptools.dynamic] +version = {file = "verl/version/version"} + +# If you need to mimic `package_dir={'': '.'}`: +[tool.setuptools.package-dir] +"" = "." + +# If you need to include specific non-Python data (like YAML files or version file): +# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']} +[tool.setuptools.package-data] +verl = [ + "version/*", + "trainer/config/*.yaml", + "trainer/config/*/*.yaml", +] diff --git a/ICL/DAPO/verl-recipe/qat/README.md b/ICL/DAPO/verl-recipe/qat/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2fe7d6cfb182aac6e3f499048636d55de4d38190 --- /dev/null +++ b/ICL/DAPO/verl-recipe/qat/README.md @@ -0,0 +1,213 @@ +# NVFP4 QAT (Quantization-Aware Training) + +This module provides **NVFP4 W4A16 Quantization-Aware Training (QAT)**, enabling seamless integration between FSDP distributed training and vLLM inference engine. This allows **NVFP4 quantized** inference during training without causing KL divergence explosion. + +**Dependency**: Requires `vllm==0.15.0`. + +--- + +## Data Preparation + +QAT Recipe reuses the DAPO dataset. Download the data before running: + +```bash +bash recipe/dapo/prepare_dapo_data.sh +``` + +--- + +## Quick Start + +### Qwen3-30B-A3B-Base W4A16 Full Quantization + +```bash +bash recipe/qat/run_qwen3_30b_w4a16.sh +``` + +### Qwen3-30B-A3B-Base W4A16 FFN-only Quantization + +```bash +bash recipe/qat/run_qwen3_30b_w4a16_FFN_only.sh +``` + +--- + +## Key Parameters + +### QAT Configuration + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `qat.enable` | Enable QAT | `False` | +| `qat.mode` | Quantization mode | `"w4a16"` | +| `qat.group_size` | Quantization group size | `16` | +| `qat.ignore_patterns` | Layer name patterns to ignore (supports regex) | `["lm_head", "embed_tokens", "re:.*mlp.gate$"]` | +| `qat.quantization_config_path` | vLLM quantization config JSON path | `recipe/qat/config/nvfp4_w4a16.json` | + +### YAML Configuration Examples + +**Full Quantization**: + +```yaml +actor_rollout_ref: + actor: + qat: + enable: true + mode: "w4a16" + group_size: 16 + ignore_patterns: + - "lm_head" + - "embed_tokens" + - "re:.*mlp.gate$" + quantization_config_path: "recipe/qat/config/nvfp4_w4a16.json" +``` + +**FFN-only Quantization** (exclude Attention Linear layers): + +```yaml +actor_rollout_ref: + actor: + qat: + enable: true + mode: "w4a16" + group_size: 16 + ignore_patterns: + - "lm_head" + - "embed_tokens" + - "re:.*mlp.gate$" + - "re:.*self_attn.*" + quantization_config_path: "recipe/qat/config/nvfp4_w4a16.json" +``` + +--- + +## Implementation Overview + +### Module Structure + +``` +verl/utils/qat/ +├── __init__.py # Module entry point +├── core.py # Replaces nn.Linear → QATLinear, sets up scale fusion +├── linear.py # Defines QATLinear (fake quantization layer with Triton FP4 kernels) +├── quantizer.py # Packs weights to NVFP4 format for vLLM rollout +└── vllm_patch.py # vLLM dynamic weight loading patches +``` + +### QATLinear + +`QATLinear` is the core fake-quantized linear layer: + +- Inherits from `nn.Linear`, fully compatible with FSDP +- Uses high-performance Triton kernels for FP4 E2M1 quantization +- Supports STE (Straight-Through Estimator) for backpropagation + +``` +Forward: + 1. weight_fq = fake_quantize(weight) # Triton FP4 quantize → dequantize + 2. output = F.linear(x, weight_fq, bias) + +Backward: + 1. grad passes directly to original weight (STE) +``` + +### Scale Fusion + +To match vLLM's inference optimization, scales of related layers need to be fused: + +- **QKV Fusion**: `q_proj`, `k_proj`, `v_proj` share the same `global_scale` +- **Gate/Up Fusion**: `gate_proj`, `up_proj` share the same `global_scale` + +Fusion is implemented via `_fusion_siblings_ref` weak references, automatically handled in `_fake_quantize_weight()`. + +### vLLM Dynamic Weight Loading Patches + +PPO training requires multiple weight updates to vLLM. Native vLLM does not support repeated loading of quantized weights. + +### Data Flow + +``` +┌─────────────────────────┐ ┌─────────────────────────┐ +│ Training Phase │ │ Rollout Phase │ +│ (FSDP + QATLinear) │ │ (vLLM + NVFP4) │ +├─────────────────────────┤ ├─────────────────────────┤ +│ • forward: fake_quant │ │ 1. Get FSDP full params │ +│ • backward: STE │ │ 2. QATQuantizer quant │ +│ • optimizer.step() │ ──► │ 3. load_weights() │ +│ │ │ 4. vLLM Marlin inference│ +└─────────────────────────┘ └─────────────────────────┘ +``` + +--- + +## Experimental Results + +All experiments were conducted on B300. + +### Experiment 1: Qwen3-8B-Base QAT Comparison + +Comparing W4A16 quantized training with and without QAT: + +| Config | Description | Color | +|--------|-------------|-------| +| BF16 | Baseline, full precision training | Brown | +| W4A16 (no QAT) | Directly quantize BF16 weights and send to vLLM | Purple | +| W4A16 + QAT | Use Fake Quantization during training | Orange | + +**Conclusions**: +- Without QAT, `rollout_corr/kl` is two orders of magnitude higher than BF16, grows rapidly during training, and eventually crashes +- With QAT, KL divergence remains consistent with BF16 + + + +### Experiment 2: Qwen3-8B-Base Quantization Strategy Comparison + +Comparing different quantization strategies: + +| Config | Description | Color | +|--------|-------------|-------| +| BF16 | Baseline | Brown | +| W4A16 + QAT (Full) | Full quantization | Orange | +| W4A16 + QAT (FFN-only) | FFN-only quantization | Red | + +**Conclusions**: +- Online evaluation (first figure): During training, each configuration uses its own rollout precision (BF16 rollout for BF16, W4A16 rollout for W4A16), so the online metrics are not directly comparable across precisions. Under this setting, BF16 > FFN-only(W4A16) > Full(W4A16). +- Offline evaluation (second figure, AIME24/25): When all trained checkpoints are evaluated uniformly using BF16 precision, all three configurations achieve similar accuracy, indicating that QAT training does not degrade the model's inherent capability. + + + + +### Experiment 3: Qwen3-30B-A3B-Base QAT Validation + +Validating QAT effectiveness on larger models. Results are consistent with the 8B experiments. + +| Config | Color | +|--------|-------| +| BF16 | Brown | +| W4A16 + QAT (Full) | Orange | +| W4A16 + QAT (FFN-only) | Red | + + + +**Memory Analysis** + +Analyzed memory impact of NVFP4 during Rollout phase for Qwen3-30B-A3B-Base. + +Config: vLLM rollout settings with `gpu_memory_utilization=0.90`, `max_num_batched_tokens=32768`, `max_num_seqs=256`, `TP=1`. + +| Metric | BF16 | W4A16 + QAT (Full) | Change | +|--------|------|---------------------|--------| +| Weight | 56.88 GiB | 16.89 GiB | -39.99 GiB (↓70.3%) | +| KV Cache | 181.26 GiB | 221.26 GiB | +40.00 GiB (↑22.1%) | +| Peak Activation | 2.64 GiB | 2.64 GiB | - | +| Non-torch Memory | 0.14 GiB | 0.14 GiB | - | +| CUDAGraph Memory | -1.34 GiB | -1.16 GiB | +0.18 GiB | + +**Conclusion**: NVFP4 W4A16 reduces weight memory by **70.3%** (from 56.88 GiB to 16.89 GiB), freeing up ~40 GiB for additional KV Cache capacity. + +--- + +## Future Improvements + +- **W4A4 Mode**: W4A4 logic is included in the code, but currently has KL divergence issues and is not usable +- **Large-scale Models**: Due to VeRL FSDP design limitations, QAT experiments on Qwen-235B and other very large models are not yet supported diff --git a/ICL/DAPO/verl-recipe/qat/config/dapo_qat_trainer.yaml b/ICL/DAPO/verl-recipe/qat/config/dapo_qat_trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b10cf431138396ddc4d64b1a6542aa924172dc0a --- /dev/null +++ b/ICL/DAPO/verl-recipe/qat/config/dapo_qat_trainer.yaml @@ -0,0 +1,66 @@ +# DAPO trainer config with QAT (Quantization-Aware Training) support +# This config extends dapo_trainer.yaml with QAT-specific settings +# +# QAT Modes: +# - w4a16: Weight-only 4-bit quantization (NVFP4) +# - w4a4: Weight + Activation 4-bit quantization (NVFP4) +# +# Usage: +# python -m verl.trainer.main_ppo \ +# --config-path recipe/qat/config \ +# --config-name dapo_qat_trainer \ +# actor_rollout_ref.actor.qat.mode=w4a16 \ +# actor_rollout_ref.actor.qat.quantization_config_path=recipe/qat/config/nvfp4_w4a16.json + +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + gen_batch_size: ${data.train_batch_size} + +reward_model: + reward_manager: dapo + overlong_buffer: + enable: False + len: 0 + penalty_factor: 0.0 + log: False + +algorithm: + filter_groups: + _target_: verl.trainer.config.FilterGroupsConfig + enable: False + metric: null + max_num_gen_batches: 0 + +trainer: + project_name: verl-dapo-qat + +# ============================================================ +# QAT Configuration (overrides defaults in dp_actor.yaml) +# ============================================================ +actor_rollout_ref: + actor: + qat: + # Enable QAT + enable: true + # Quantization mode: "w4a16" (weight-only) or "w4a4" (weight + activation) + mode: "w4a16" + # Quantization group size (NVFP4 requires 16) + group_size: 16 + # Layers to skip (not quantized) + ignore_patterns: + - "lm_head" + - "embed_tokens" + - "re:.*mlp.gate$" + # Activation observer for W4A4: "static_minmax", "memoryless_minmax", "minmax" + activation_observer: "static_minmax" + # vLLM quantization config JSON path (relative to project root) + # W4A16: recipe/qat/config/nvfp4_w4a16.json + # W4A4: recipe/qat/config/nvfp4_w4a4.json + quantization_config_path: null # Specify in run script diff --git a/ICL/DAPO/verl-recipe/qat/run_qwen3_30b_w4a16.sh b/ICL/DAPO/verl-recipe/qat/run_qwen3_30b_w4a16.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb6a51ffc98452702ccf3e0166f7dbb8f88cb0d8 --- /dev/null +++ b/ICL/DAPO/verl-recipe/qat/run_qwen3_30b_w4a16.sh @@ -0,0 +1,169 @@ +#!/usr/bin/env bash +# NVFP4 QAT W4A16 (Weight-only FP4) for Qwen3-30B-A3B-Base +set -euxo pipefail +current_dir="$(dirname "$(readlink -f "$0")")" + +project_name='DAPO-NVFP4-QAT' +exp_name=${exp_name:-'DAPO-Qwen3-30B-A3B-W4A16'} + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +# Rollout Correction parameters (for quantized rollout) +rollout_is=token # token-level importance sampling +rollout_is_threshold=2.0 +rollout_rs=null # response-level sampling (null = disabled) +rollout_rs_threshold=null +# rollout_rs_threshold_lower and rollout_token_veto_threshold removed in new verl version + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +gen_prompt_bsz=$((train_prompt_bsz * 2)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# QAT Configuration +qat_enable=True +qat_mode=w4a16 # w4a16 for weight-only FP4 +qat_config_path="${qat_config_path:-"${WORKING_DIR}/recipe/qat/config/nvfp4_w4a16.json"}" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=1.0 + +# Performance +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=1 + +export VERL_LOGGING_LEVEL=DEBUG +export VERL_PPO_LOGGING_LEVEL=DEBUG +export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_CONFIGURE_LOGGING=1 +export VLLM_USE_V1=1 +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export TORCH_DIST_TIMEOUT=4000 + +RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + --config-path "${WORKING_DIR}/recipe/qat/config" \ + --config-name dapo_qat_trainer \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.qat.enable=${qat_enable} \ + actor_rollout_ref.actor.qat.mode=${qat_mode} \ + actor_rollout_ref.actor.qat.quantization_config_path=${qat_config_path} \ + 'actor_rollout_ref.actor.qat.ignore_patterns=["lm_head", "embed_tokens", "re:.*mlp.gate$"]' \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$(( 1024 * 32 )) \ + actor_rollout_ref.rollout.max_num_seqs=256 \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 diff --git a/ICL/DAPO/verl-recipe/qat/run_qwen3_30b_w4a16_FFN_only.sh b/ICL/DAPO/verl-recipe/qat/run_qwen3_30b_w4a16_FFN_only.sh new file mode 100644 index 0000000000000000000000000000000000000000..76da5b43ad085c1cf2bcff028ce93269acaad822 --- /dev/null +++ b/ICL/DAPO/verl-recipe/qat/run_qwen3_30b_w4a16_FFN_only.sh @@ -0,0 +1,169 @@ +#!/usr/bin/env bash +# NVFP4 QAT W4A16 (Weight-only FP4) for Qwen3-30B-A3B-Base +set -euxo pipefail +current_dir="$(dirname "$(readlink -f "$0")")" + +project_name='DAPO-NVFP4-QAT' +exp_name=${exp_name:-'DAPO-Qwen3-30B-A3B-W4A16-FFN-ONLY'} + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +# Rollout Correction parameters (for quantized rollout) +rollout_is=token # token-level importance sampling +rollout_is_threshold=2.0 +rollout_rs=null # response-level sampling (null = disabled) +rollout_rs_threshold=null +# rollout_rs_threshold_lower and rollout_token_veto_threshold removed in new verl version + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024)) +max_response_length=$((1024 * 20)) +enable_overlong_buffer=True +overlong_buffer_len=512 +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +enable_filter_groups=True +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=32 +gen_prompt_bsz=$((train_prompt_bsz * 2)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +WORKING_DIR=${WORKING_DIR:-"${PWD}"} +RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-1} + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# QAT Configuration +qat_enable=True +qat_mode=w4a16 # w4a16 for weight-only FP4 +qat_config_path="${qat_config_path:-"${WORKING_DIR}/recipe/qat/config/nvfp4_w4a16.json"}" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=1.0 + +# Performance +sp_size=1 +use_dynamic_bsz=True +actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) +infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) +offload=True +gen_tp=1 + +export VERL_LOGGING_LEVEL=DEBUG +export VERL_PPO_LOGGING_LEVEL=DEBUG +export VLLM_LOGGING_LEVEL=DEBUG +export VLLM_CONFIGURE_LOGGING=1 +export VLLM_USE_V1=1 +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export TORCH_DIST_TIMEOUT=4000 + +RAY_ADDRESS='http://127.0.0.1:8265' ray job submit --runtime-env="${RUNTIME_ENV}" \ + --working-dir "${WORKING_DIR}" \ + -- python3 -m verl.trainer.main_ppo \ + --config-path "${WORKING_DIR}/recipe/qat/config" \ + --config-name dapo_qat_trainer \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.90 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.qat.enable=${qat_enable} \ + actor_rollout_ref.actor.qat.mode=${qat_mode} \ + actor_rollout_ref.actor.qat.quantization_config_path=${qat_config_path} \ + 'actor_rollout_ref.actor.qat.ignore_patterns=["lm_head", "embed_tokens", "re:.*mlp.gate$", "re:.*self_attn.*"]' \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$(( 1024 * 32 )) \ + actor_rollout_ref.rollout.max_num_seqs=256 \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k="${top_k}" \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + reward_model.reward_manager=dapo \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger='["console","wandb"]' \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=5 \ + trainer.total_epochs=1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 diff --git a/ICL/DAPO/verl-recipe/r1/README.md b/ICL/DAPO/verl-recipe/r1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ddd23bcc3abe7560c50af6082a2a2bdb6601fe39 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/README.md @@ -0,0 +1,26 @@ +# DeepSeek R1 Reproduction + +This recipe is under development, if you are interested, checkout the TODO list and join this project! https://github.com/volcengine/verl/issues/708 + +## Reproducing Evaluation + +Eval Results of DS-R1-Distill-Qwen2.5-1.5B (k=8) + +Dataset | Test Results | Reported +-- | -- | -- +GPQA Diamond | 35.3 | 33.8 +LiveCodeBench | 16.9 | 16.9 +AIME 2024 | 30.4 | 28.9 +CNMO 2024 (en) | 45.1 | - +CNMO 2024 (zh) | 41.0 | - + +--- + +Eval Results (DS-R1) + +Dataset | Test Results (k=1) | Test Results (k=4) | Reported +-- | -- | -- | -- +GPQA Diamond | 67.7 | 69.6 | 71.5 +LiveCodeBench | 64.7 | 63.1 | 65.9 +AIME 2024 | 86.7 | 79.2 | 79.8 +CNMO 2024 | 75.0 | 78.5 | 78.8 diff --git a/ICL/DAPO/verl-recipe/r1/__init__.py b/ICL/DAPO/verl-recipe/r1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/DAPO/verl-recipe/r1/data_process.py b/ICL/DAPO/verl-recipe/r1/data_process.py new file mode 100644 index 0000000000000000000000000000000000000000..fb41c814371aa21e4f08af449b43c0a4e5753634 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/data_process.py @@ -0,0 +1,203 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the dataset to parquet format +""" + +import argparse +import os +from functools import partial + +from datasets import concatenate_datasets, load_dataset + +from verl.utils.hdfs_io import copy, makedirs + + +def example_map_fn(example, idx, process_fn, data_source, ability, split): + question, solution = process_fn(example) + data = { + "data_source": data_source, + "prompt": [{"role": "user", "content": question}], + "ability": ability, + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": {"split": split, "index": idx}, + } + return data + + +def build_aime2024_dataset(): + def process_aime2024(example): + return example["Problem"], str(example["Answer"]) + + data_source = "Maxwell-Jia/AIME_2024" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = load_dataset(data_source, split="train") + map_fn = partial( + example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test" + ) + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) + return dataset + + +def build_gpqa_dimond_dataset(): + import random + + GPQA_QUERY_TEMPLATE = ( + "Answer the following multiple choice question. The last line of your response should be of the following " + "format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before " + "answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" + ) + + def process_gpqa_diamond(example): + choices = [example["Incorrect Answer 1"], example["Incorrect Answer 2"], example["Incorrect Answer 3"]] + random.shuffle(choices) + gold_index = random.randint(0, 3) + choices.insert(gold_index, example["Correct Answer"]) + query_prompt = GPQA_QUERY_TEMPLATE.format( + A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"] + ) + gold_choice = "ABCD"[gold_index] + return query_prompt, gold_choice + + data_source = "Idavidrein/gpqa" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + + dataset = load_dataset(data_source, "gpqa_diamond", split="train") + map_fn = partial( + example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test" + ) + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) + return dataset + + +def build_cnmo2024_dataset(): + def process_cnmo2024(example): + return example["question"], example["answer"] + + data_source = "opencompass/LiveMathBench" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + + dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test") + map_fn_en = partial( + example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test" + ) + dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names) + + dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test") + map_fn_zh = partial( + example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test" + ) + dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names) + + dataset = concatenate_datasets([dataset_en, dataset_zh]) + return dataset + + +def build_livecodebench_dataset(): + import base64 + import json + import pickle + import zlib + + def process_livecodebench(example): + # Construct Query Prompt + # From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140 + query_prompt = ( + f"You will be given a question (problem specification) and will generate a correct Python program " + f"that matches the specification and passes all tests.\n\nQuestion: {example['question_content']}\n\n" + ) + if example["starter_code"]: + query_prompt += ( + f"You will use the following starter code to write the solution to the problem and enclose your " + f"code within delimiters.\n```python\n{example['starter_code']}\n```" + ) + else: + query_prompt += ( + "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test " + "on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python " + "program runs, it reads the inputs, runs the algorithm and writes output to STDOUT." + "```python\n# YOUR CODE HERE\n```" + ) + + # Construct test cases + public_test_cases = json.loads(example["public_test_cases"]) + try: + private_test_cases = json.loads(example["private_test_cases"]) + except Exception as e: + print(f"Error loading private test cases: {e}") + private_test_cases = json.loads( + pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))) + ) + full_test_cases = public_test_cases + private_test_cases + + metadata = json.loads(example["metadata"]) + test_cases = { + "inputs": [t["input"] for t in full_test_cases], + "outputs": [t["output"] for t in full_test_cases], + "fn_name": metadata.get("func_name", None), + } + text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8") + return query_prompt, text_cases_compressed + + data_source = "livecodebench/code_generation_lite" + print(f"Loading the {data_source} dataset from huggingface...", flush=True) + dataset = load_dataset(data_source, split="test") + # R1 Evaluation use LiveCodeBench 24.08-25.01 + dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00") + map_fn = partial( + example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test" + ) + + dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8) + return dataset + + +TASK2DATA = { + "aime2024": build_aime2024_dataset, + "gpqa_diamond": build_gpqa_dimond_dataset, + "cnmo2024": build_cnmo2024_dataset, + "livecodebench": build_livecodebench_dataset, +} +SUPPORTED_TASKS = TASK2DATA.keys() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/r1") + parser.add_argument("--hdfs_dir", default=None) + parser.add_argument("--tasks", default="all") + + args = parser.parse_args() + + if args.tasks.lower() == "all": + args.tasks = SUPPORTED_TASKS + else: + args.tasks = [task.strip() for task in args.tasks.split(",") if task.strip()] + for task in args.tasks: + if task not in SUPPORTED_TASKS: + raise NotImplementedError(f"{task} has not been supported.") + + datasets = [] + for task in args.tasks: + datasets.append(TASK2DATA[task]()) + test_dataset = concatenate_datasets(datasets) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/ICL/DAPO/verl-recipe/r1/main_eval.py b/ICL/DAPO/verl-recipe/r1/main_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0e735a1a582071bde0a9eaf6681085c8b4272c --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/main_eval.py @@ -0,0 +1,81 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Offline evaluate the performance of a generated file using reward model and ground truth verifier. +The input is a parquet file that contains N generated sequences and (optional) the ground truth. + +""" + +from collections import defaultdict + +import hydra +import numpy as np +import pandas as pd +import ray +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.utils.fs import copy_to_local + + +@ray.remote +def process_item(config, data_source, response_lst, reward_data): + reward_fn = get_custom_reward_fn(config) + ground_truth = reward_data["ground_truth"] + score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] + return data_source, np.mean(score_lst) + + +@hydra.main(config_path="config", config_name="evaluation", version_base=None) +def main(config): + local_path = copy_to_local(config.data.path) + dataset = pd.read_parquet(local_path) + responses = dataset[config.data.response_key] + data_sources = dataset[config.data.data_source_key] + reward_model_data = dataset[config.data.reward_model_key] + + total = len(dataset) + + # Initialize Ray + if not ray.is_initialized(): + ray.init(**OmegaConf.to_container(config.ray_kwargs.get("ray_init", {}))) + + # evaluate test_score based on data source + data_source_reward = defaultdict(list) + + # Create remote tasks + remote_tasks = [ + process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) + ] + + # Process results as they come in + with tqdm(total=total) as pbar: + while len(remote_tasks) > 0: + # Use ray.wait to get completed tasks + done_ids, remote_tasks = ray.wait(remote_tasks) + for result_id in done_ids: + data_source, score = ray.get(result_id) + data_source_reward[data_source].append(score) + pbar.update(1) + + metric_dict = {} + for data_source, rewards in data_source_reward.items(): + metric_dict[f"test_score/{data_source}"] = np.mean(rewards) + + print(metric_dict) + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/r1/reward_score.py b/ICL/DAPO/verl-recipe/r1/reward_score.py new file mode 100644 index 0000000000000000000000000000000000000000..9aeced911412327bb36dc65b159e0db5222a59b2 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/reward_score.py @@ -0,0 +1,30 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def reward_func(data_source, solution_str, ground_truth, extra_info=None): + if data_source in ["Maxwell-Jia/AIME_2024", "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]: + from recipe.r1.tasks import math_reward + + return math_reward.compute_score(solution_str, ground_truth) + elif data_source == "Idavidrein/gpqa": + from recipe.r1.tasks import gpqa + + return gpqa.compute_score(solution_str, ground_truth) + elif data_source in ["livecodebench/code_generation_lite", "livecodebench/code_generation"]: + from recipe.r1.tasks import livecodebench + + return livecodebench.compute_score(solution_str, ground_truth) + else: + raise NotImplementedError diff --git a/ICL/DAPO/verl-recipe/r1/tasks/__init__.py b/ICL/DAPO/verl-recipe/r1/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce90c5eb352d85c59105c0dc85b5f1dd576f095 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/tasks/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/DAPO/verl-recipe/r1/tasks/gpqa.py b/ICL/DAPO/verl-recipe/r1/tasks/gpqa.py new file mode 100644 index 0000000000000000000000000000000000000000..65b37e91662923f2e1acef29297213addfcc50f3 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/tasks/gpqa.py @@ -0,0 +1,25 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +# Extraction Template from https://github.com/openai/simple-evals/blob/90e3e821cabba2aeb6be651dcb662b253df04225/common.py#L25 +ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?" + + +def compute_score(solution_str, ground_truth) -> float: + match = re.search(ANSWER_PATTERN_MULTICHOICE, solution_str) + extracted_answer = match.group(1) if match else None + score = 1.0 if extracted_answer == ground_truth else 0.0 + return score diff --git a/ICL/DAPO/verl-recipe/r1/tasks/livecodebench.py b/ICL/DAPO/verl-recipe/r1/tasks/livecodebench.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cbab681d7ee1b2a830b879b528a4170dc1faf7 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1/tasks/livecodebench.py @@ -0,0 +1,72 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import multiprocessing +import pickle +import zlib + +# Reuse `run_test` for convenience +from verl.utils.reward_score.prime_code.testing_util import run_test + + +def _temp_run(in_outs, generation, debug, result, metadata_list, timeout): + res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout) + result.append(res) + metadata_list.append(metadata) + + +def check_correctness(in_outs, generation, timeout, debug=True): + """Check correctness of code generation with a global timeout. + The global timeout is to catch some extreme/rare cases not handled by the timeouts + inside `run_test`""" + + manager = multiprocessing.Manager() + result = manager.list() + metadata_list = manager.list() + p = multiprocessing.Process( + target=_temp_run, + args=(in_outs, generation, debug, result, metadata_list, timeout), + ) + p.start() + p.join(timeout=(timeout + 1) * len(in_outs["inputs"]) + 5) + if p.is_alive(): + p.kill() + if not result: + # consider that all tests failed + result = [[-1 for i in range(len(in_outs["inputs"]))]] + if debug: + print("global timeout") + return result[0], metadata_list[0] + + +def compute_score(completion, test_cases): + solution = completion.split("```python")[-1].split("```")[0] + + # extract test cases + try: + in_outs = json.loads(test_cases) + except Exception as e: + print(f"Error loading test cases: {e}") + in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode("utf-8"))))) + + success = False + try: + res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False) + success = all(map(lambda x: x is True, res)) + except Exception: + pass + + return success diff --git a/ICL/DAPO/verl-recipe/r1_ascend/deepscaler.py b/ICL/DAPO/verl-recipe/r1_ascend/deepscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..07d43346092b1951c355b99c3e2abc94e92fd4ae --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/deepscaler.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from mathruler.grader import extract_boxed_content, grade_answer + + +def compute_score(data_source, solution_str, ground_truth, extra_info=None): + solution_str = solution_str.strip() + pattern = re.compile(r".*.*.*\\boxed\{.*\}.*", re.DOTALL) + format_match = re.fullmatch(pattern, solution_str) + score = 0.0 + if format_match: + score += 0.33 + + extract_output = extract_boxed_content(solution_str) + if grade_answer(extract_output, ground_truth): + score += 0.67 + + return score diff --git a/ICL/DAPO/verl-recipe/r1_ascend/engine_core.py b/ICL/DAPO/verl-recipe/r1_ascend/engine_core.py new file mode 100644 index 0000000000000000000000000000000000000000..7936390eeb2a94382397696c6adcdf10cc8c36bc --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/engine_core.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/v1/engine/core.py + +import logging +import os +import time + +from vllm.config import VllmConfig +from vllm.v1.core.kv_cache_utils import get_kv_cache_config, unify_kv_cache_configs +from vllm.v1.engine.core import EngineCore +from vllm.v1.kv_cache_interface import KVCacheConfig + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def _initialize_kv_caches(self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + start = time.time() + + # Get all kv cache needed by the model + kv_cache_specs = self.model_executor.get_kv_cache_specs() + + # Profiles the peak memory usage of the model to determine how much + # memory can be allocated for kv cache. + available_gpu_memory = self.model_executor.determine_available_memory() + + assert len(kv_cache_specs) == len(available_gpu_memory) + # Get the kv cache tensor size + self.kv_cache_configs = [ + get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, available_gpu_memory_one_worker) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker in zip( + kv_cache_specs, available_gpu_memory, strict=False + ) + ] + + # Since we use a shared centralized controller, we need the + # `kv_cache_config` to be consistent across all workers to make sure + # all the memory operators can be applied to all workers. + unify_kv_cache_configs(self.kv_cache_configs) + + # All workers have the same kv_cache_config except layer names, so use + # an arbitrary one to initialize the scheduler. + assert all([cfg.num_blocks == self.kv_cache_configs[0].num_blocks for cfg in self.kv_cache_configs]) + num_gpu_blocks = self.kv_cache_configs[0].num_blocks + num_cpu_blocks = 0 + scheduler_kv_cache_config = self.kv_cache_configs[0] + + # Initialize kv cache and warmup the execution + self.model_executor.initialize_from_config(self.kv_cache_configs) + + elapsed = time.time() - start + logger.info(("init engine (profile, create kv cache, warmup model) took %.2f seconds"), elapsed) + return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config + + +EngineCore._initialize_kv_caches = _initialize_kv_caches diff --git a/ICL/DAPO/verl-recipe/r1_ascend/json_to_parquet.py b/ICL/DAPO/verl-recipe/r1_ascend/json_to_parquet.py new file mode 100644 index 0000000000000000000000000000000000000000..84a54a6af0d0da4aab81ffbb9d8a8c98a985306c --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/json_to_parquet.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import os +import random + +import pandas as pd + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="deepscaler.json to parquet file", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--output_dir", type=str, required=True, help="output dir for train/test parquet file") + parser.add_argument("--json_path", type=str, default="./deepscaler.json", help="path of deepscaler.json") + parser.add_argument("--train_data_ratio", type=float, default=0.9, help="ratio of train data") + parser.add_argument("--seed", type=int, default=42, help="random seed") + return parser.parse_args() + + +def validate_arguments(args): + if not os.path.exists(args.json_path): + raise FileNotFoundError(f"File not found: {args.json_path}") + if not 0 < args.train_data_ratio < 1: + raise ValueError("Train data ratio should be between 0 and 1") + os.makedirs(args.output_dir, exist_ok=True) + + +def convert_json_to_parquet(json_path, train_data_ratio, output_dir, seed): + random.seed(seed) + + with open(json_path, encoding="utf-8") as f: + original_data = json.load(f) + + converted_data = [] + for item in original_data: + r1_template = ( + "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. " + "The assistant first thinks about the reasoning process in the mind " + "and then provides the user with the answer. " + "The reasoning process and answer are enclosed within " + "and tags, respectively, i.e., " + " reasoning process here answer here . " + "Put your final answer within \\\\boxed{}. " + ) + converted_item = { + "data_source": "deepscaler", + "prompt": [{"content": r1_template, "role": "system"}, {"content": item["problem"], "role": "user"}], + "ability": "math", + "reward_model": {"ground_truth": item["answer"], "style": "rule"}, + "extra_info": {"answer": item["solution"]}, + } + converted_data.append(converted_item) + + split_index = int(len(converted_data) * train_data_ratio) + train_data = converted_data[:split_index] + test_data = converted_data[split_index:] + + for item in train_data: + item["split"] = "train" + for item in test_data: + item["split"] = "test" + all_data = train_data + test_data + df = pd.DataFrame(all_data) + train_df = df[df["split"] == "train"] + test_df = df[df["split"] == "test"] + del train_df["split"] + del test_df["split"] + + train_df.to_parquet(os.path.join(output_dir, "train.parquet"), engine="pyarrow", index=False) + test_df.to_parquet(os.path.join(output_dir, "test.parquet"), engine="pyarrow", index=False) + logger.info( + f"Json to parquet success! Total num {len(all_data)}, train num {len(train_data)}, test num {len(test_data)}", + flush=True, + ) + + +def main(): + try: + args = parse_arguments() + validate_arguments(args) + convert_json_to_parquet(args.json_path, args.train_data_ratio, args.output_dir, args.seed) + except Exception as e: + logger.error(f"[ERROR]: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/r1_ascend/main_ppo.py b/ICL/DAPO/verl-recipe/r1_ascend/main_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..12ac47d233ee226b983b49b1c188c464a5211b93 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/main_ppo.py @@ -0,0 +1,151 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import logging +import os + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.main_ppo import TaskRunner as TaskRunnerBase +from verl.utils.device import auto_set_device, is_cuda_available + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + # Automatically set `config.trainer.device = npu` when running on Ascend NPU. + auto_set_device(config) + + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + logger.info(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = ray.remote(num_cpus=1)(TaskRunner).remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +# please make sure main_task is not scheduled on head +class TaskRunner(TaskRunnerBase): + """Ray remote class for executing distributed PPO training tasks. + + This class encapsulates the main training logic and runs as a Ray remote actor + to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation + """ + + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} + + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import AsyncActorRolloutRefWorker + + # NPU-ADAPTATION: Modify the Megatron worker entry point and rewrite some functions. + from .megatron_workers import ActorRolloutRefWorker + # NPU-ADAPTATION END + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + + return actor_rollout_cls, ray_worker_group_cls + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/r1_ascend/megatron_workers.py b/ICL/DAPO/verl-recipe/r1_ascend/megatron_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1ca2901e16f3f57aabd99a7e7103bb59fe6cbd --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/megatron_workers.py @@ -0,0 +1,55 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from https://github.com/volcengine/verl/blob/main/verl/workers/megatron_workers.py +""" +The main entry point to run the PPO algorithm +""" + +import torch +from mindspeed.core.megatron_basic.requirements_basic import dummy_compile + +# NPU-ADAPTATION: Save the original and dummy copies of `torch.compile`. +from mindspeed.patch_utils import MindSpeedPatchesManager +from omegaconf import DictConfig + +from verl.workers.megatron_workers import ActorRolloutRefWorker as ARRWorker +from verl.workers.rollout import base + +MindSpeedPatchesManager.patches_info["torch.compile"].remove_patch() +TRUE_COMPILE = torch.compile +DUMMY_COMPILE = dummy_compile +# NPU-ADAPTATION END + + +base._ROLLOUT_REGISTRY[("vllm", "sync")] = "recipe.r1_ascend.vllm_rollout_spmd.vLLMRollout" + + +class ActorRolloutRefWorker(ARRWorker): + def __init__(self, config: DictConfig, role: str, **kwargs): + super().__init__(config, role) + + def _build_rollout(self, *args, **kwargs): + """ + Build the rollout with temporary reversion to true torch.compile. + """ + # Temporarily restore true torch.compile for the rollout build + torch.compile = TRUE_COMPILE + + # Call parent method with original torch.compile + super()._build_rollout(*args, **kwargs) + + # Revert to dummy_compile after rollout is built + torch.compile = DUMMY_COMPILE diff --git a/ICL/DAPO/verl-recipe/r1_ascend/run_deepseekv3_671b_grpo_megatron_npu.sh b/ICL/DAPO/verl-recipe/r1_ascend/run_deepseekv3_671b_grpo_megatron_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..44ac1e2e57a0af9a57ae6ee1518e4094f5044995 --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/run_deepseekv3_671b_grpo_megatron_npu.sh @@ -0,0 +1,122 @@ +set -x + +project_name='GRPO' +exp_name='DeepSeekV3-671B-GRPO-Megatron-256rank-gbs512' + +NNODES=16 +NPUS_PER_NODE=16 + +adv_estimator=grpo + +kl_coef=0.0 +use_kl_loss=True +kl_loss_coef=0.001 + +max_prompt_length=$((1024 * 1)) +max_response_length=$((1024 * 2)) +max_num_batched_tokens=1024 +ppo_mini_batch_size=512 + +train_prompt_bsz=512 +n_resp_per_prompt=16 + +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +CONFIG_PATH=${CONFIG_PATH:-"${RAY_DATA_HOME}/verl/trainer/config"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/DeepSeek-V3-hf"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/DeepseekV3-dist-ckpts"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/deepscaler/train.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/deepscaler/test.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Performance Related Parameter +offload=True +max_num_seqs=64 +gen_tp=2 + +# Currently, it is necessary to enable `enable_chunked_prefill` in the script. +# However, in vLLM ascend, this configuration is off by default and does not take effect. +python3 -m recipe.r1_ascend.main_ppo \ + --config-path="${CONFIG_PATH}" \ + --config-name='ppo_megatron_trainer.yaml' \ + custom_reward_function.path=recipe/r1_ascend/deepscaler.py \ + custom_reward_function.name=compute_score \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.truncation='error' \ + data.filter_overlong_prompts=True \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.max_num_seqs=${max_num_seqs} \ + actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.megatron.expert_model_parallel_size=32 \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=1 \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=8 \ + actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ + actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=block \ + actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=8 \ + actor_rollout_ref.actor.load_weight=True \ + actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.actor.megatron.dist_checkpointing_path="${CKPTS_DIR}" \ + actor_rollout_ref.actor.megatron.param_offload=${offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.load_weight=True \ + actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ + actor_rollout_ref.ref.megatron.dist_checkpointing_path="${CKPTS_DIR}" \ + actor_rollout_ref.ref.megatron.param_offload=${offload} \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.free_cache_engine=True \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=5 \ + trainer.save_freq=-1 \ + trainer.total_epochs=1 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.multi_head_latent_attention=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_flash_attn=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.pipeline_num_transformer_layers=[[6],[8],[8],[8],[8],[8],[8],[7]] \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type='alltoall' \ + +actor_rollout_ref.actor.megatron.override_transformer_config.moe_grouped_gemm=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_rotary_pos_emb=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.rope_scaling_type='yarn' \ + +actor_rollout_ref.actor.megatron.override_transformer_config.yarn_scaling_factor=40 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.rope_scaling_mscale=1.0 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.rope_scaling_mscale_all_dim=1.0 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_swiglu=True \ + +actor_rollout_ref.actor.megatron.override_transformer_config.seq_length=2048 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=6 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=7 \ + +actor_rollout_ref.actor.megatron.override_transformer_config.swap_optimizer=True $@ \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/r1_ascend/vllm_parallel_state.py b/ICL/DAPO/verl-recipe/r1_ascend/vllm_parallel_state.py new file mode 100644 index 0000000000000000000000000000000000000000..82eb41757c5a905313671abf4e1aad9e7f7e614e --- /dev/null +++ b/ICL/DAPO/verl-recipe/r1_ascend/vllm_parallel_state.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Adapted from https://gitcode.com/Ascend/MindSpeed-RL/blob/2.1.0/mindspeed_rl/utils/utils.py + +import logging +import os +import re +import socket +import subprocess + +import torch +import vllm.envs as envs +from vllm.distributed import parallel_state as vllm_ps +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def _get_ip_by_ifname(): + """ + Get IPv4 address by interface name (e.g. eth0, en0) + returns IP string on success, and None on failure + """ + try: + # Execute `ifconfig` and capture its output + ifname = os.environ.get("HCCL_SOCKET_IFNAME", 0) + if ifname: + output = subprocess.check_output(["ifconfig", ifname], stderr=subprocess.STDOUT).decode() + # Match IPv4 addresses using regex, and exclude 127.0.0.1 + matches = re.findall(r"inet (?:addr:)?((?:\d{1,3}\.){3}\d{1,3})", output) + for ip in matches: + if ip != "127.0.0.1": + return ip + return None + except subprocess.CalledProcessError: + return None + + +def _get_current_node_ip() -> str: + try: + # Create UDP socket (Only used to get info of interface). + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + # Connect to external address, without actual communication. + s.connect(("8.8.8.8", 80)) # Google DNS Server. + local_ip = s.getsockname()[0] + except Exception: + local_ip = _get_ip_by_ifname() + if not local_ip: + # Fallback to iterative search on failure. + local_ip = "127.0.0.1" + hostname = socket.gethostname() + for addr in socket.getaddrinfo(hostname, None): + ip = addr[4][0] + if not ip.startswith("::"): + local_ip = ip + break + return local_ip + + +def get_cluster_info(): + # Ensure initialization of distributed env. + if not torch.distributed.is_initialized(): + raise RuntimeError("Distributed environment not initialized") + + world_size = torch.distributed.get_world_size() + + # Get IP address of current node. + ip_address = _get_current_node_ip() + + # Collect IP addresses of all ranks. + ip_list = [None] * world_size + torch.distributed.all_gather_object(ip_list, ip_address) + + return ip_list + + +### init DP group ranks for vLLM ascend +def init_parallel_state(tensor_parallel_size): + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size: int = torch.distributed.get_world_size() + distributed_init_method = "env://" + backend = "hccl" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + + initialize_model_parallel(tensor_parallel_size) + logger.info( + f"[DEBUG]: RANK[{rank}]: TP group: {vllm_ps._TP.ranks}\n" + f"[DEBUG]: RANK[{rank}]: PP group: {vllm_ps._PP.ranks}\n" + f"[DEBUG]: RANK[{rank}]: DP group: {vllm_ps._DP.ranks}\n" + f"[DEBUG]: RANK[{rank}]: EP group: {vllm_ps._EP.ranks}\n" + ) + + os.environ["VLLM_DP_RANK"] = str(vllm_ps._DP.rank_in_group) + envs.VLLM_DP_RANK = int(os.environ["VLLM_DP_RANK"]) + + ip_list = get_cluster_info() + + rank_0 = vllm_ps._DP.ranks[0] + index = rank_0 + os.environ["VLLM_DP_MASTER_PORT"] = str(int(os.environ.get("MASTER_PORT")) + 1 + index) + os.environ["VLLM_DP_MASTER_IP"] = ip_list[rank_0] + envs.VLLM_DP_MASTER_PORT = int(os.environ["VLLM_DP_MASTER_PORT"]) + envs.VLLM_DP_MASTER_IP = os.environ["VLLM_DP_MASTER_IP"] diff --git a/ICL/DAPO/verl-recipe/rep_exp/reward_manager/elliptical_reward_manager.py b/ICL/DAPO/verl-recipe/rep_exp/reward_manager/elliptical_reward_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..83040ac1c02dcb39c9db312cc02a4ed14df13bee --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/reward_manager/elliptical_reward_manager.py @@ -0,0 +1,138 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +import torch + +from verl import DataProto +from verl.workers.reward_manager import NaiveRewardManager, register + +from ..reward_score import default_compute_score + + +@register("elliptical") +class EllipticalRewardManager(NaiveRewardManager): + """The reward manager.""" + + def __init__( + self, + tokenizer, + num_examine, + compute_score=None, + reward_fn_key="data_source", + beta: int = 1.0, + turn_off_elliptical_if_none_correct: bool = False, + turn_off_elliptical_if_some_correct: bool = False, + turn_off_elliptical_if_all_correct: bool = False, + turn_off_elliptical_if_rollout_incorrect: bool = False, + alpha: float = 1.0, + ) -> None: + """ + Initialize the NaiveRewardManager instance. + + Args: + tokenizer: The tokenizer used to decode token IDs into text. + num_examine: The number of batches of decoded responses to print to the console for debugging purpose. + compute_score: A function to compute the reward score. If None, `default_compute_score` will be used. + reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to + "data_source". + """ + super().__init__(tokenizer, num_examine, default_compute_score, reward_fn_key) + self.beta = beta + self.turn_off_elliptical_if_none_correct = turn_off_elliptical_if_none_correct + self.turn_off_elliptical_if_some_correct = turn_off_elliptical_if_some_correct + self.turn_off_elliptical_if_all_correct = turn_off_elliptical_if_all_correct + self.turn_off_elliptical_if_rollout_incorrect = turn_off_elliptical_if_rollout_incorrect + self.alpha = alpha + + def __call__(self, data: DataProto, return_dict=False): + if "rm_scores" not in data.batch: + # this means we're doing validation, so we don't need to compute the elliptical reward + return super().__call__(data, return_dict=return_dict) + + reward_extra_info = defaultdict(list) + + intrinsic_reward_tensor = data.batch["rm_scores"] + data.pop(batch_keys=["rm_scores"]) + + extrinsic_reward_result = super().__call__(data, return_dict=True) + extrinsic_reward_tensor = extrinsic_reward_result["reward_tensor"] + extrinsic_reward_extra_info = extrinsic_reward_result["reward_extra_info"] + + self._maybe_turn_off_elliptical(data, extrinsic_reward_tensor, intrinsic_reward_tensor) + + reward_tensor = self.alpha * extrinsic_reward_tensor + self.beta * intrinsic_reward_tensor + + # Intrinsic reward extra info + reward_extra_info["intrinsic_reward"] = intrinsic_reward_tensor.numpy() + reward_extra_info["beta_scaled_intrinsic_reward"] = self.beta * intrinsic_reward_tensor.numpy() + reward_extra_info["extrinsic_reward"] = extrinsic_reward_tensor.numpy() + reward_extra_info["alpha_scaled_extrinsic_reward"] = self.alpha * extrinsic_reward_tensor.numpy() + reward_extra_info["total_reward"] = reward_tensor.numpy() + + # Update with extrinsic reward extra info + reward_extra_info.update(extrinsic_reward_extra_info) + + if return_dict: + return { + "reward_tensor": reward_tensor, + "reward_extra_info": reward_extra_info, + } + else: + return reward_tensor + + def _maybe_turn_off_elliptical( + self, data: DataProto, extrinsic_reward_tensor: torch.Tensor, intrinsic_reward_tensor: torch.Tensor + ) -> None: + """ + Potentially turn off the elliptical reward for samples that have one of the following properties: + (1) any of the rollouts have the correct answer + (2) all of the rollouts have the correct answer + + Args: + data (DataProto): The data proto containing the batch data. + extrinsic_reward_tensor (torch.Tensor): The extrinsic reward tensor. + intrinsic_reward_tensor (torch.Tensor): The intrinsic reward tensor. + + Returns: + None + """ + if self.turn_off_elliptical_if_rollout_incorrect: + mask = extrinsic_reward_tensor.sum(dim=-1) == 0 + intrinsic_reward_tensor[mask] = 0.0 + + visited_uids = set() + for uid in data.non_tensor_batch["uid"]: + if uid in visited_uids: + continue + + visited_uids.add(uid) + mask = torch.from_numpy(data.non_tensor_batch["uid"] == uid) + + # Potentially turn off elliptical if **no** rollout has the correct answer + if self.turn_off_elliptical_if_none_correct and extrinsic_reward_tensor[mask].sum() == 0: + intrinsic_reward_tensor[mask] = 0.0 + + # Potentially turn off elliptical if **some** rollouts have the correct answer + if ( + self.turn_off_elliptical_if_some_correct + and extrinsic_reward_tensor[mask].sum() > 0 + and extrinsic_reward_tensor[mask].sum() < mask.sum() + ): + intrinsic_reward_tensor[mask] = 0.0 + + # Potentially turn off elliptical if **all** rollouts have the correct answer + if self.turn_off_elliptical_if_all_correct and extrinsic_reward_tensor[mask].sum() == mask.sum(): + intrinsic_reward_tensor[mask] = 0.0 diff --git a/ICL/DAPO/verl-recipe/rep_exp/train_elliptical.sh b/ICL/DAPO/verl-recipe/rep_exp/train_elliptical.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf1a78f587e417b6700ea88d8de047fe962e7f1e --- /dev/null +++ b/ICL/DAPO/verl-recipe/rep_exp/train_elliptical.sh @@ -0,0 +1,104 @@ +TASK=${1} # math, gsm8k, dapo-with-aime24 +SPARSE_DIM=${2} # the original paper used 32 for math/gsm8k, 128 for dapo-with-aime24 +BETA=${3} # 0.01 +SEED=${4} + +train_path=$HOME/data/${TASK}/train.parquet +dev_path=$HOME/data/${TASK}/dev.parquet + +train_files="['$train_path']" +dev_files="['$dev_path']" + +# Adjust things a bit for dapo-aime training since it has longer generations +# and hence is slower and consumes more memory +if [ ${TASK} == "dapo-with-aime24" ]; then + TEST_FREQ=10 + SAVE_FREQ=10 + TRAIN_BATCH_SIZE=512 + PPO_MINI_BATCH_SIZE=128 + + MAX_PROMPT_LENGTH=$((1024 * 2)) + MAX_RESPONSE_LENGTH=$((1024 * 8)) + MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH)) + GPU_MEMORY_UTILIZATION=0.5 + PPO_MICRO_BATCH_SIZE_PER_GPU=8 + REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU=16 +else + TEST_FREQ=20 + SAVE_FREQ=20 + TRAIN_BATCH_SIZE=1024 + PPO_MINI_BATCH_SIZE=256 + + MAX_PROMPT_LENGTH=1024 + MAX_RESPONSE_LENGTH=1024 + MAX_NUM_BATCHED_TOKENS=8192 + GPU_MEMORY_UTILIZATION=0.6 + PPO_MICRO_BATCH_SIZE_PER_GPU=16 + REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU=32 +fi + +OFFLINE=True + +PYTHONUNBUFFERED=1 TRANSFORMERS_OFFLINE=${OFFLINE} python3 -u -m rep_exp.main_rep_exp \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_files" \ + data.val_files="$dev_files" \ + data.train_batch_size=$TRAIN_BATCH_SIZE \ + data.max_prompt_length=$MAX_PROMPT_LENGTH \ + data.max_response_length=$MAX_RESPONSE_LENGTH \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH_SIZE \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$PPO_MICRO_BATCH_SIZE_PER_GPU \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.ppo_epochs=1 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.mode=sync \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.max_num_batched_tokens=$MAX_NUM_BATCHED_TOKENS \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=$GPU_MEMORY_UTILIZATION \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + reward_model.enable=True \ + reward_model.model.path=Qwen/Qwen2.5-7B-Instruct \ + reward_model.model.use_remove_padding=False \ + reward_model.model.fsdp_config.param_offload=True \ + reward_model.micro_batch_size_per_gpu=$REWARD_MODEL_MICRO_BATCH_SIZE_PER_GPU \ + reward_model.model.input_tokenizer=null \ + reward_model.elliptical.enable=True \ + reward_model.elliptical.sparse_dim=$SPARSE_DIM \ + reward_model.elliptical.reward_type=leverage \ + reward_model.elliptical.randomize_sparse_matrix=True \ + reward_model.elliptical.normalization=none \ + reward_model.elliptical.persist_covariance=False \ + reward_model.reward_manager=elliptical \ + reward_model.reward_kwargs.elliptical.beta=$BETA \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_none_correct=True \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_some_correct=False \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_all_correct=False \ + reward_model.reward_kwargs.elliptical.turn_off_elliptical_if_rollout_incorrect=False \ + actor_rollout_ref.actor.loss_agg_mode=token-mean \ + actor_rollout_ref.actor.use_kl_loss=True \ + algorithm.norm_adv_by_std_in_grpo=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='rep-exp' \ + trainer.experiment_name="${TASK}_elliptical_seed_${SEED}_beta_${BETA}_sparse_dim_${SPARSE_DIM}" \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=$SAVE_FREQ \ + trainer.test_freq=$TEST_FREQ \ + trainer.total_epochs=1000 \ + trainer.resume_mode=disable \ + trainer.resume_from_path='' \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/retool/README.md b/ICL/DAPO/verl-recipe/retool/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dab7c67becccd408b1de84720709ae6452b4557d --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/README.md @@ -0,0 +1,62 @@ +# Retool +[ReTool: Reinforcement Learning for Strategic Tool Use in LLMs](https://arxiv.org/abs/2504.11536) + +## Overview +- Base model: [Qwen/Qwen2.5-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-32B-Instruct) +- SFT dataset: [JoeYing/ReTool-SFT](https://huggingface.co/datasets/JoeYing/ReTool-SFT) +- RL dataset: [BytedTsinghua-SIA/DAPO-Math-17k](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k) +- Val dataset: [yentinglin/aime_2025](https://huggingface.co/datasets/yentinglin/aime_2025) + +## How it works + + Retool's workflow is divided into two key phases: + +1. Cold Start and Supervised Fine Tuning (SFT) + + The data generation pipeline builds a high-quality dataset containing code-enhanced inference trajectories, and supervised fine-tuning enables the model to master basic Tool call (e.g., code execution) and analysis of the execution results. + +2. Dynamic Interaction and Policy Optimization (RL). + + With the verl Reinforcement Learning framework, the model dynamically inserts code blocks during inference and interacts with the sandbox environment in real-time, generating a hybrid trajectory of natural language thinking and code snippets, sending the code to the sandbox for asynchronous execution when code termination markers are detected, and the execution results (success outputs/errors) are fed back to the model for guiding the subsequent inference. This "think-execute-feedback" cycle, together with the design of rewards based on the accuracy of the final answer, enables the model to independently optimize the Tool call strategy, and improves the reasoning efficiency and computational accuracy. + +## Installation + +```bash +pip install verl==0.6.1 +``` + +## SFT +1. Data preparation +```bash +python3 recipe/retool/retool_sft_preprocess.py +``` + +2. Training +```bash +bash recipe/retool/run_qwen2-32b_sft.sh +``` + +After 6 epoches, validation metrics: +- val-core/aime_2025/acc/mean@30: 0.24 +- val-aux/num_turns/mean: 7.2 + +## RL + +### GRPO +```bash +bash recipe/retool/run_qwen2-32b_dapo.sh +``` + +After 150 steps, validation metrics: +- val-core/aime_2025/acc/mean@30: 0.6 +- val-aux/num_turns/mean: 10 + +### PPO + +```bash +bash recipe/retool/run_qwen2-32b_ppo.sh +``` + +After 250 steps, validation metrics: +- val-core/aime_2025/acc/mean@30: 0.55 +- val-aux/num_turns/mean: 8.3 diff --git a/ICL/DAPO/verl-recipe/retool/retool.py b/ICL/DAPO/verl-recipe/retool/retool.py new file mode 100644 index 0000000000000000000000000000000000000000..223b36399d0f3ab52a6914f5725b716fbd5c7a91 --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/retool.py @@ -0,0 +1,120 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re +from typing import Any + +import datasets + +from verl.tools.base_tool import OpenAIFunctionToolSchema +from verl.tools.sandbox_fusion_tools import SandboxFusionTool +from verl.utils.dataset import RLHFDataset +from verl.utils.reward_score import math_dapo +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__name__) + + +class CustomSandboxFusionTool(SandboxFusionTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) + + @rollout_trace_op + async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: + code = parameters["code"] + matches = self.code_pattern.findall(code) + if matches: + code = matches[0].strip() + + # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script + lines = code.split("\n") + for i, line in reversed(list(enumerate(lines))): + if line == "": + continue + if not lines[i].startswith("print"): + lines[i] = f"print({line})" + break + code = "\n".join(lines) + + timeout = parameters.get("timeout", self.default_timeout) + language = parameters.get("language", self.default_language) + if not isinstance(code, str): + code = str(code) + + result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + # sandbox has no score or metrics, use Nones + return result, None, None + + +answer_format = """\nThe answer format must be: \\boxed{'The final answer goes here.'}""" + + +class CustomRLHFDataset(RLHFDataset): + """Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets.""" + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + dataframe = datasets.load_dataset(parquet_file)["train"] + data_source = "/".join(parquet_file.split("/")[-2:]) + if data_source in ["Maxwell-Jia/AIME_2024", "yentinglin/aime_2025"]: + dataframe = dataframe.map( + self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names + ) + else: + dataframe = dataframe.map(self.map_fn2, num_proc=16) + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + def map_fn(self, row: dict, *, data_source: str = None): + if data_source == "Maxwell-Jia/AIME_2024": + problem, answer = row["Problem"], row["Answer"] + elif data_source == "yentinglin/aime_2025": + problem, answer = row["problem"], row["answer"] + + prompt = problem + answer_format + data = { + "data_source": data_source.split("/")[1].lower(), # aime_2024, aime_2025 + "prompt": [{"role": "user", "content": prompt}], + "ability": "MATH", + "reward_model": {"ground_truth": str(answer)}, + "agent_name": "tool_agent", + } + return data + + def map_fn2(self, row: dict): + content = row["prompt"][0]["content"] + row["prompt"][0]["content"] = content + answer_format + row["agent_name"] = "tool_agent" + return row + + +def compute_score(data_source, solution_str, ground_truth, extra_info, **kwargs): + # use \\boxed{...} answer + result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True) + + # encourage model to call tools + num_turns = extra_info["num_turns"] + if result["score"] < 0: + tool_call_reward = (num_turns - 2) / 2 * 0.1 + result["score"] = min(-0.6, result["score"] + tool_call_reward) + + if result["pred"] is None: + result["pred"] = "" + + return result diff --git a/ICL/DAPO/verl-recipe/retool/retool_sft_preprocess.py b/ICL/DAPO/verl-recipe/retool/retool_sft_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..57d53c91c6290675771ccd4932ffab142fb4335f --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/retool_sft_preprocess.py @@ -0,0 +1,136 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages. +""" + +import json +import os +import re +from typing import Any + +import datasets +from omegaconf import OmegaConf + +code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) + + +def extract_code_message(content: str) -> tuple[dict[str, Any], str]: + start, stop = "", "" + i = content.find(start) + if i == -1: + return None, content + j = content.find(stop) + assert j > i + + code = content[i + len(start) : j] + matches = code_pattern.findall(code) + if matches: + code = matches[0].strip() + + message = { + "role": "assistant", + "content": content[:i].strip(), + "tool_calls": [ + { + "type": "function", + "function": { + "name": "code_interpreter", + "arguments": {"code": code}, + }, + }, + ], + } + return message, content[j + len(stop) :] + + +def extract_answer_message(content: str) -> tuple[dict[str, Any], str]: + start, stop = "", "" + i = content.find(start) + if i == -1: + return None, content + j = content.find(stop) + assert j > i + + answer = content[:i] + content[i + len(start) : j] + message = { + "role": "assistant", + "content": answer.strip(), + } + return message, content[j + len(stop) :] + + +def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]: + start, stop = "", "" + i = content.find(start) + if i == -1: + return None, content + j = content.find(stop) + assert j > i + + interpreter = content[i + len(start) : j] + message = { + "role": "tool", + "content": interpreter.strip(), + } + return message, content[j + len(stop) :] + + +def process(row: dict, *, tools: str): + messages = [] + + # extract problem + content = row["messages"][0]["content"] + start = "*user question:*" + i = content.find(start) + assert i != -1 + prompt = content[i + len(start) :].replace("", "").replace("", "").strip() + messages.append( + { + "role": "user", + "content": prompt, + } + ) + + # extract multi turns + content = row["messages"][1]["content"] + role = "assistant" + while len(content) > 0: + if role == "assistant": + message, content = extract_code_message(content) + if message is None: + message, content = extract_answer_message(content) + assert message is not None + messages.append(message) + role = "tool" + else: + message, content = extract_interpreter_message(content) + assert message is not None + messages.append(message) + role = "assistant" + + tools = json.loads(tools) + return {"messages": messages, "tools": tools} + + +if __name__ == "__main__": + tools_config_file = "recipe/retool/sandbox_fusion_tool_config.yaml" + tools_config = OmegaConf.load(tools_config_file) + tool_schema = OmegaConf.to_container(tools_config["tools"][0]["tool_schema"]) + tools = json.dumps([tool_schema]) + + data = datasets.load_dataset("JoeYing/ReTool-SFT")["train"] + data = data.map(process, fn_kwargs={"tools": tools}) + save_path = os.path.expanduser("~/ReTool-SFT/data/train-00000-of-00001.parquet") + data.to_parquet(save_path) diff --git a/ICL/DAPO/verl-recipe/retool/run_gpt_oss_ppo.sh b/ICL/DAPO/verl-recipe/retool/run_gpt_oss_ppo.sh new file mode 100644 index 0000000000000000000000000000000000000000..6519a45ca73d20d903e09a35a79cf37481faa518 --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/run_gpt_oss_ppo.sh @@ -0,0 +1,125 @@ +set -x + +# ================= data/model/tool ================= +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k +aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024 +aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025 +actor_model_path=lmsys/gpt-oss-20b-bf16 +critic_model_path=$actor_model_path + +train_files="['$dapo_math_17k']" +test_files="['$aime_2025']" + +# tool +tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml + +# wandb +project_name=wuxibin_retool +experiment_name=gpt-oss-20b-bf16_ppo +default_local_dir=$DATA_ROOT/checkpoint/$experiment_name + +# ================= algorithm ================= +adv_estimator=gae + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_turns=8 +max_prompt_length=2048 +max_response_length=16384 +actor_lr=1e-6 +critic_lr=2e-6 +gae_gamma=1.0 +gae_lam=1.0 + +critic_warmup=20 + +train_batch_size=512 +ppo_mini_batch_size=512 +n_resp_per_prompt_val=30 + +# ================= perfomance ================= +infer_tp=4 # vllm +train_sp=4 # train + +offload=True + +actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 2 )) +critic_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 )) + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + algorithm.gamma=$gae_gamma \ + algorithm.lam=$gae_lam \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + +data.apply_chat_template_kwargs.reasoning_effort=medium \ + data.truncation='error' \ + data.custom_cls.path=recipe/retool/retool.py \ + data.custom_cls.name=CustomRLHFDataset \ + custom_reward_function.path=recipe/retool/retool.py \ + custom_reward_function.name=compute_score \ + actor_rollout_ref.model.path=$actor_model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ + actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \ + actor_rollout_ref.rollout.multi_turn.format=gpt-oss \ + +actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend=triton \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.val_kwargs.top_p=1.0 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ + critic.optim.lr=$critic_lr \ + critic.model.use_remove_padding=True \ + critic.model.path=$critic_model_path \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \ + critic.ulysses_sequence_parallel_size=$train_sp \ + critic.model.fsdp_config.param_offload=$offload \ + critic.model.fsdp_config.optimizer_offload=$offload \ + trainer.critic_warmup=$critic_warmup \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=True \ + trainer.log_val_generations=100 \ + trainer.nnodes=2 \ + trainer.save_freq=30 \ + trainer.default_local_dir=$default_local_dir \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ diff --git a/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_dapo.sh b/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_dapo.sh new file mode 100644 index 0000000000000000000000000000000000000000..2df380da24cfe3872665e407489d4450aeff20c9 --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_dapo.sh @@ -0,0 +1,107 @@ +set -x + +# ================= data/model/tool ================= +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k +aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024 +aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025 +model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-32b-instruct/global_step_372 + +train_files="['$dapo_math_17k']" +test_files="['$aime_2025']" + +# tool +tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml + +# wandb +project_name=wuxibin_retool +experiment_name=qwen2.5-32b_dapo +default_local_dir=$DATA_ROOT/checkpoint/$experiment_name + +# ================= algorithm ================= +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_turns=8 +max_prompt_length=2048 +max_response_length=16384 +actor_lr=1e-6 + +train_batch_size=512 +ppo_mini_batch_size=64 +n_resp_per_prompt=16 +n_resp_per_prompt_val=30 + +# ================= perfomance ================= +infer_tp=4 # vllm +train_sp=8 # train +offload=True + +actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 )) +log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.custom_cls.path=recipe/retool/retool.py \ + data.custom_cls.name=CustomRLHFDataset \ + custom_reward_function.path=recipe/retool/retool.py \ + custom_reward_function.name=compute_score \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ + actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=True \ + trainer.log_val_generations=100 \ + trainer.nnodes=2 \ + trainer.save_freq=30 \ + trainer.default_local_dir=$default_local_dir \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ diff --git a/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_ppo.sh b/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_ppo.sh new file mode 100644 index 0000000000000000000000000000000000000000..1e3ef2cd7fbb0bafe12bcdb95ce9696060cc8e2b --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_ppo.sh @@ -0,0 +1,123 @@ +set -x + +# ================= data/model/tool ================= +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k +aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024 +aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025 +actor_model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-32b-instruct/global_step_372 +critic_model_path=$actor_model_path + +train_files="['$dapo_math_17k']" +test_files="['$aime_2025']" + +# tool +tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml + +# wandb +project_name=wuxibin_retool +experiment_name=qwen2.5-32b_ppo +default_local_dir=$DATA_ROOT/checkpoint/$experiment_name + +# ================= algorithm ================= +adv_estimator=gae + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_turns=8 +max_prompt_length=2048 +max_response_length=16384 +actor_lr=1e-6 +critic_lr=2e-6 +gae_gamma=1.0 +gae_lam=1.0 + +critic_warmup=20 + +train_batch_size=1024 +ppo_mini_batch_size=256 +n_resp_per_prompt_val=30 + +# ================= perfomance ================= +infer_tp=4 # vllm +train_sp=4 # train + +offload=True + +actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 2 )) +critic_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 )) + + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + algorithm.gamma=$gae_gamma \ + algorithm.lam=$gae_lam \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.custom_cls.path=recipe/retool/retool.py \ + data.custom_cls.name=CustomRLHFDataset \ + custom_reward_function.path=recipe/retool/retool.py \ + custom_reward_function.name=compute_score \ + actor_rollout_ref.model.path=$actor_model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ + actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ + critic.optim.lr=$critic_lr \ + critic.model.use_remove_padding=True \ + critic.model.path=$critic_model_path \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_max_token_len_per_gpu=$critic_max_token_len_per_gpu \ + critic.ulysses_sequence_parallel_size=$train_sp \ + critic.model.fsdp_config.param_offload=$offload \ + critic.model.fsdp_config.optimizer_offload=$offload \ + trainer.critic_warmup=$critic_warmup \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=True \ + trainer.log_val_generations=100 \ + trainer.nnodes=2 \ + trainer.save_freq=30 \ + trainer.default_local_dir=$default_local_dir \ + trainer.test_freq=5 \ + trainer.total_epochs=1 $@ diff --git a/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_sft.sh b/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..d218b0e7eb87bdd5236fbb5eed61e8f2ccd306f8 --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/run_qwen2-32b_sft.sh @@ -0,0 +1,40 @@ +#!/bin/bash +set -x + +nnodes=2 +nproc_per_node=8 +master_addr= +master_port= + +experiment_name=multiturn-sft-qwen-2.5-32b-instruct +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +TRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet +EVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet +MODEL_PATH=$HDFS_ROOT/model/Qwen2.5-32B-Instruct +SAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name + +torchrun --nnodes=$nnodes \ + --nproc_per_node=$nproc_per_node \ + --master-addr=$master_addr \ + --master-port=$master_port \ + --node-rank=$node_rank \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$TRAIN_DATA \ + data.val_files=$EVAL_DATA \ + data.max_length=16384 \ + data.train_batch_size=32 \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.multiturn.tools_key=tools \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=$MODEL_PATH \ + model.strategy=fsdp \ + trainer.default_local_dir=$SAVE_PATH \ + trainer.project_name=wuxibin-multiturn-sft \ + trainer.experiment_name=$experiment_name \ + trainer.logger='["console","wandb"]' \ + trainer.total_epochs=6 \ + ulysses_sequence_parallel_size=4 \ + use_remove_padding=true \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/retool/run_qwen2_7b_dapo.sh b/ICL/DAPO/verl-recipe/retool/run_qwen2_7b_dapo.sh new file mode 100644 index 0000000000000000000000000000000000000000..f1187a3d12d3b308e41c27e801319dc839691f8c --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/run_qwen2_7b_dapo.sh @@ -0,0 +1,109 @@ +set -x + +export VLLM_USE_V1=1 + +# ================= data/model/tool ================= +HDFS_ROOT=${HDFS_ROOT:-$PWD} +DATA_ROOT=${DATA_ROOT:-$PWD} + +dapo_math_17k=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k +aime_2024=$DATA_ROOT/dataset/Maxwell-Jia/AIME_2024 +aime_2025=$DATA_ROOT/dataset/yentinglin/aime_2025 +model_path=$HDFS_ROOT/checkpoint/multiturn-sft-qwen-2.5-7b-instruct/global_step_372 + +train_files="['$dapo_math_17k']" +test_files="['$aime_2025', '$aime_2024']" + +# tool +tool_config_path=recipe/retool/sandbox_fusion_tool_config.yaml + +# wandb +project_name=retool +experiment_name=qwen2.5-7b_dapo +default_local_dir=$DATA_ROOT/checkpoint/$experiment_name + +# ================= algorithm ================= +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_turns=16 +max_prompt_length=2048 +max_response_length=16384 +actor_lr=1e-6 + +train_batch_size=64 +ppo_mini_batch_size=16 +n_resp_per_prompt=16 +n_resp_per_prompt_val=30 + +# ================= perfomance ================= +infer_tp=4 # vllm +train_sp=4 # train +offload=True + +actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 )) +log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 )) + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.custom_cls.path=recipe/retool/retool.py \ + data.custom_cls.name=CustomRLHFDataset \ + custom_reward_function.path=recipe/retool/retool.py \ + custom_reward_function.name=compute_score \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ + actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \ + actor_rollout_ref.rollout.multi_turn.format=hermes \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \ + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ + trainer.logger=['console','wandb'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=True \ + trainer.log_val_generations=20 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.default_local_dir=$default_local_dir \ + trainer.test_freq=10 \ + trainer.total_epochs=1 $@ diff --git a/ICL/DAPO/verl-recipe/retool/run_qwen2_7b_sft_npu.sh b/ICL/DAPO/verl-recipe/retool/run_qwen2_7b_sft_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..ba203e6bee001198f1bd2b20b415397980162aaf --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/run_qwen2_7b_sft_npu.sh @@ -0,0 +1,36 @@ +#!/bin/bash +set -x + +nnodes=1 +nproc_per_node=8 + +project_name=retool_sft +experiment_name=multiturn-sft-qwen-2.5-7b-instruct + +TRAIN_DATA=PATH/TO/ReTool-SFT/data/train-00000-of-00001.parquet +EVAL_DATA=PATH/TO/ReTool-SFT/data/train-00000-of-00001.parquet +MODEL_PATH=PATH/TO/Qwen2.5-7B-Instruct +SAVE_PATH=PATH/TO/checkpoint/$experiment_name + +torchrun --nnodes=$nnodes \ + --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$TRAIN_DATA \ + data.val_files=$EVAL_DATA \ + data.max_length=16384 \ + data.train_batch_size=64 \ + data.multiturn.enable=true \ + data.multiturn.messages_key=messages \ + data.multiturn.tools_key=tools \ + data.micro_batch_size_per_gpu=8 \ + model.partial_pretrain=$MODEL_PATH \ + model.strategy=fsdp \ + trainer.default_local_dir=$SAVE_PATH \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.logger='["console"]' \ + trainer.total_epochs=6 \ + trainer.save_freq=10 \ + trainer.device=npu \ + ulysses_sequence_parallel_size=4 \ + use_remove_padding=true \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/retool/sandbox_fusion_tool_config.yaml b/ICL/DAPO/verl-recipe/retool/sandbox_fusion_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71b10e50ec95bbb42cca13ffd218a44b8d759ef0 --- /dev/null +++ b/ICL/DAPO/verl-recipe/retool/sandbox_fusion_tool_config.yaml @@ -0,0 +1,24 @@ +tools: + - class_name: "recipe.retool.retool.CustomSandboxFusionTool" + config: + sandbox_fusion_url: "http://localhost:8080/run_code" + num_workers: 128 + enable_global_rate_limit: true + rate_limit: 128 + default_timeout: 30 + default_language: "python" + memory_limit_mb: 1024 + type: native + + tool_schema: + type: "function" + function: + name: "code_interpreter" + description: "A tool for executing code." + parameters: + type: "object" + properties: + code: + type: "string" + description: "The code to execute." + required: ["code"] diff --git a/ICL/DAPO/verl-recipe/spin/README.md b/ICL/DAPO/verl-recipe/spin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa7d3ab7eaade84f4faeace2d706b129b4417f23 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/README.md @@ -0,0 +1,179 @@ +# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models + +This repository hosts a `verl` recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory. + +**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models: + +1. **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations. +2. **Two-Player Game Setup:** A game involving two players acted by a single LLM. +3. **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration. + +Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) + +[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)] + +verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20) + +--- + +## Key Function (compute_online_dpo_loss) and Related works +SPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). + +This `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data. + +Specifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets. + +**Reference Papers:** +* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) +* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) +* [Somethings are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) +* [Iterative preference learning from human feedback: Bridging theory and practice for rlhf under kl-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023) +* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024) +* [Direct language model alignment from online ai feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024) + + +## Our Online DPO Implementation + +Our `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include: + +* **No Critic:** Unlike PPO, we omit the value function critic. +* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline. +* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems). +* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences. +* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles. + +--- +## Algorithm + +This recipe implements an Online algorithm adapted to the `verl` Reinforcement Learning framework, which provides an alternative to PPO for fine-tuning language models. + +**Online Loop:** Instead of maximizing a scalar reward signal in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training: + +1. **Generation:** The current model generates multiple responses for each prompt in a batch. +2. **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on the math problem). +3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model. + +**Connection with SPIN:** +Instead of only using a fixed target data distribution, the online generation loop in step 2 will dynamically change the target data distribution by using a certain Preference Labeling method (rule-based ranking on the math problem by selecting the better one in this recipe). This explores the direction mentioned in SPIN's paper Section 7 about "dynamically changing target data distribution" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling. + +--- + +## Reproduce the Experiment (Example Setup) + +The following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct. + +1. **Setup Environment (Example using Docker):** + ```bash + # Start a container with GPU access and shared memory + docker run -it --name spin_test --gpus all \ + --shm-size=32g \ + --ipc=host \ + -v /path/to/host/.cache:/root/.cache \ + -e HF_TOKEN= \ + lmsysorg/sglang:latest \ + /bin/bash + + # Inside the container or on your host machine: + # Ensure /tmp is writable + mkdir -p /tmp + chmod 1777 /tmp + + # Install Python 3.10 (if not present) and venv + sudo apt update + sudo apt install -y python3.10 python3.10-venv tmux + python3 -m ensurepip --upgrade + + # Create and activate a virtual environment + python3 -m venv ~/.python/spin_env + source ~/.python/spin_env/bin/activate + + # Install uv (fast package installer) + python3 -m pip install uv + ``` + +2. **Install verl and Dependencies:** + ```bash + # Clone the verl repository and checkout the spin branch + cd ~ + git clone git@github.com:volcengine/verl.git && cd verl + + # Install flash-attn (handle potential build issues) + python3 -m uv pip install wheel packaging + python3 -m uv pip install flash-attn --no-build-isolation --no-deps + + # Install verl with sglang extras + python3 -m uv pip install -e ".[sglang]" + ``` + *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.* + +3. **Login & Download Data/Model:** + ```bash + # Login to Weights & Biases (optional, for logging) + export WANDB_API_KEY= + # wandb login + + # Download the GSM8K dataset + python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k # Adjusted path + + # Download the base model (Example: Qwen2.5-3B-Instruct) + huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct + ``` + +4. **Configure:** + * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node). + * Pay attention to `actor_rollout_ref.model`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`. + +5. **Run Training:** + ```bash + # Set CUDA visible devices (adjust based on your hardware and config) + export CUDA_VISIBLE_DEVICES=0,1,2,3 + + # Launch the training script (e.g., test.sh or a custom script) + # Ensure test.sh points to the correct config and main script + bash recipe/spin/run_spin.sh + ``` + +--- + +## Configuration + +* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`). +* Key configuration sections: + * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths. + * `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler). + * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function. + * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`. + * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor). + +--- + +## Key Files + +* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`. +* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop. +* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP. +* `dp_actor.py`: Contains the actor class, including the DPO policy update logic. +* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`. +* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe. +* `run_spin.sh` (or similar): Example bash script for launching a training run. +* `README.md`: This file. + +--- + +## Acknowledgement + +We sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO): + +* [Zixiang Chen](https://sites.google.com/view/zxchen) +* [Yuhao Yang](https://github.com/yhyang201) +* [Yifan Zhang](https://github.com/yifanzhang-pro) +* [Yongan Xiang](https://github.com/BearBiscuit05) +* [Junrong Lin](https://github.com/ocss884) +* [Yuxuan Tong](https://github.com/tongyx361) +* [Guangming Shen](https://github.com/PeterSH6) +* [Biao He](https://www.linkedin.com/in/biao-he/) +* [Qingquan Song](https://qingquansong.github.io/) +* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/) +* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) + +--- diff --git a/ICL/DAPO/verl-recipe/spin/dp_actor.py b/ICL/DAPO/verl-recipe/spin/dp_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..e95e6f9721f10ae56601daf46441edbaf1ed6af6 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/dp_actor.py @@ -0,0 +1,288 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import math +from collections import defaultdict + +import numpy as np +import torch +from recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps + +from verl import DataProto +from verl.utils.device import get_device_name +from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches +from verl.workers.actor import DataParallelPPOActor + +__all__ = ["DataParallelPPOActor"] + + +class SPINDataParallelPPOActor(DataParallelPPOActor): + def compute_log_prob(self, data: DataProto) -> torch.Tensor: + """Compute the log probability of the responses given input_ids, attention_mask and position_ids + + Args: + data (DataProto): a DataProto containing keys + + ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the + concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. + + ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. + + ``responses``: tensor of shape [batch_size, response_length]. torch.int64. + + Returns: + torch.Tensor: the log_prob tensor + """ + # set to eval + self.actor_module.eval() + + micro_batch_size = data.meta_info["micro_batch_size"] + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error + use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] + + select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] + batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + if has_multi_modal_inputs: + num_micro_batches = data.batch.batch_size[0] // micro_batch_size + non_tensor_select_keys = ["multi_modal_inputs"] + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif use_dynamic_bsz: + # split using dynamic bsz + max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) + else: + micro_batches = batch.split(micro_batch_size) + + log_probs_lst = [] + for micro_batch in micro_batches: + if isinstance(micro_batch, DataProto): + micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} + + with torch.no_grad(): + _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) + log_probs_lst.append(log_probs) + log_probs = torch.concat(log_probs_lst, dim=0) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + log_probs = log_probs[revert_indices] + + return log_probs + + def update_policy_dpo_with_ref(self, data: DataProto): + """ + Performs the DPO update step using pre-calculated reference log probs + from an external, periodically updated reference model. + """ + self.actor_module.train() # Ensure training mode + + # --- Retrieve necessary data --- + try: + # Expects batch prepared by fit_dpo loop, including reference log probs + batch_td = data.batch + chosen_labels = batch_td["chosen_labels"] + rejected_labels = batch_td["rejected_labels"] + # ... other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ... + + # === Get PRE-CALCULATED reference log probs from input data === + reference_chosen_logps = batch_td["reference_chosen_logps"] # Should be sequence-level logps + reference_rejected_logps = batch_td["reference_rejected_logps"] # Should be sequence-level logps + # ============================================================ + + # Get DPO params from meta_info + # beta = data.meta_info.get('dpo_beta', 0.1) # Default beta + beta = self.config.get("dpo_beta", 0.1) # Default beta + loss_type = data.meta_info.get("dpo_loss_type", "sigmoid") + label_smoothing = data.meta_info.get("dpo_label_smoothing", 0.0) + # reference_free should now be False as we provide ref logps + reference_free = data.meta_info.get("reference_free", False) # Default False + + except KeyError as e: + print(f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}") + print(f"Available keys in data.batch: {list(batch_td.keys())}") # Debug print + return {} # Return empty metrics on error + except Exception as e_data: + print(f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}") + return {} + + # --- Micro-batching Setup --- + micro_batch_size = self.config.get("ppo_micro_batch_size_per_gpu") + if micro_batch_size is None: + # Fallback or default if not set, or raise error + micro_batch_size = 1 # Example fallback, adjust as needed + print(f"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}") + # raise ValueError("Config 'ppo_micro_batch_size_per_gpu' must be set.") + + # Ensure chosen_input_ids exists before getting shape + if "chosen_input_ids" not in batch_td: + print("ERROR: 'chosen_input_ids' not found in batch_td for DPO update.") + return {} + bsz = batch_td["chosen_input_ids"].shape[0] + + if bsz == 0: + print("Warning: DPO batch size is 0 in update_policy_dpo. Skipping update.") + return {"actor/dpo_loss": 0.0, "actor/grad_norm": 0.0} # Return zero metrics if batch is empty + + num_micro_batches = math.ceil(bsz / micro_batch_size) + gradient_accumulation_steps = num_micro_batches + + # --- Metrics Accumulation --- + total_loss = 0.0 + accumulated_metrics = defaultdict(list) + metrics = {} # Final metrics dict + + # --- Zero Gradients --- + self.actor_optimizer.zero_grad(set_to_none=True) + + # --- Micro-batch Loop --- + for i in range(num_micro_batches): + start_idx = i * micro_batch_size + end_idx = min(start_idx + micro_batch_size, bsz) + if start_idx >= end_idx: + continue + + # Slice the full DPO batch into micro-batches + # Important: Slice ALL required tensors, including labels and inputs + micro_batch_chosen_labels = chosen_labels[start_idx:end_idx] + micro_batch_rejected_labels = rejected_labels[start_idx:end_idx] + micro_batch_chosen_inputs = { + "input_ids": batch_td["chosen_input_ids"][start_idx:end_idx], + "attention_mask": batch_td["chosen_attention_mask"][start_idx:end_idx], + } + if "chosen_position_ids" in batch_td: + micro_batch_chosen_inputs["position_ids"] = batch_td["chosen_position_ids"][start_idx:end_idx] + + micro_batch_rejected_inputs = { + "input_ids": batch_td["rejected_input_ids"][start_idx:end_idx], + "attention_mask": batch_td["rejected_attention_mask"][start_idx:end_idx], + } + if "rejected_position_ids" in batch_td: + micro_batch_rejected_inputs["position_ids"] = batch_td["rejected_position_ids"][start_idx:end_idx] + + # Determine autocast dtype + autocast_dtype = torch.bfloat16 # Or get dynamically from config/FSDP settings + # --- Autocast Forward Pass --- + with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype): + # --- Step 1: Forward pass for CURRENT policy log probs (with grad) --- + policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False) + policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False) + + # --- Step 2: Calculate CURRENT policy log probs using get_batch_logps --- + policy_chosen_logps = get_batch_logps( + policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False + ) + policy_rejected_logps = get_batch_logps( + policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False + ) + + # --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) --- + # Slice the full batch reference logps for the current micro-batch + micro_ref_chosen_logps = reference_chosen_logps[start_idx:end_idx] + micro_ref_rejected_logps = reference_rejected_logps[start_idx:end_idx] + # --- The ActorAsRef calculation block is REMOVED --- + + # --- Step 4: Calculate DPO Logits and Loss --- + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps # Uses pre-calculated values + logits = pi_logratios - ref_logratios # DPO logits + + loss = compute_online_dpo_loss( + policy_chosen_logps=policy_chosen_logps, # Has grad + policy_rejected_logps=policy_rejected_logps, # Has grad + reference_chosen_logps=micro_ref_chosen_logps, # No grad (from input) + reference_rejected_logps=micro_ref_rejected_logps, # No grad (from input) + beta=beta, + label_smoothing=label_smoothing, + loss_type=loss_type, + reference_free=reference_free, # Should be False now + ) + + # --- Scale loss for gradient accumulation --- + scaled_loss = loss / gradient_accumulation_steps + + # --- Accumulate Metrics --- + total_loss += loss.item() # Unscaled loss + accumulated_metrics["actor/dpo_loss_batch"].append(loss.item()) + accumulated_metrics["actor/dpo_logits_batch"].append(logits.mean().item()) + # Accumulate policy and reference log probs/ratios if needed for debugging + accumulated_metrics["actor/policy_chosen_logps_batch"].append(policy_chosen_logps.mean().item()) + accumulated_metrics["actor/policy_rejected_logps_batch"].append(policy_rejected_logps.mean().item()) + accumulated_metrics["actor/reference_chosen_logps_batch"].append(micro_ref_chosen_logps.mean().item()) + accumulated_metrics["actor/reference_rejected_logps_batch"].append( + micro_ref_rejected_logps.mean().item() + ) + + # --- Backward Pass (outside autocast) --- + # Check if loss requires grad before backward + if scaled_loss.requires_grad: + scaled_loss.backward() + else: + print(f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward.") + + # --- End Micro-batch Loop --- + + # --- Optimizer Step (after accumulating gradients for all micro-batches) --- + grad_norm = self._optimizer_step() + + # --- Populate Final Metrics --- + if num_micro_batches > 0 and bsz > 0: # Check if any processing happened + metrics["actor/dpo_loss"] = total_loss / num_micro_batches + metrics["actor/grad_norm"] = ( + grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float("inf") + ) + # Average other accumulated metrics + for key, val_list in accumulated_metrics.items(): + if val_list: + metrics[key.replace("_batch", "")] = np.mean(val_list) + + # Calculate accuracy / rewards / margins based on averaged logprobs if desired + if ( + "actor/policy_chosen_logps" in metrics + and "actor/policy_rejected_logps" in metrics + and "actor/reference_chosen_logps" in metrics + and "actor/reference_rejected_logps" in metrics + ): + policy_ratio_mean = metrics["actor/policy_chosen_logps"] - metrics["actor/policy_rejected_logps"] + ref_ratio_mean = metrics["actor/reference_chosen_logps"] - metrics["actor/reference_rejected_logps"] + logits_mean = policy_ratio_mean - ref_ratio_mean + metrics["actor/rewards_chosen"] = beta * ( + metrics["actor/policy_chosen_logps"] - metrics["actor/reference_chosen_logps"] + ) + metrics["actor/rewards_rejected"] = beta * ( + metrics["actor/policy_rejected_logps"] - metrics["actor/reference_rejected_logps"] + ) + metrics["actor/rewards_accuracies"] = float(logits_mean > 0) # Mean accuracy proxy + metrics["actor/rewards_margins"] = metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"] + + else: # Handle case where no micro-batches were run (e.g., bsz=0) + metrics["actor/dpo_loss"] = 0.0 + metrics["actor/grad_norm"] = 0.0 + # Initialize other metrics to 0 or NaN as appropriate + for key in accumulated_metrics.keys(): + metrics[key.replace("_batch", "")] = 0.0 + metrics["actor/rewards_chosen"] = 0.0 + metrics["actor/rewards_rejected"] = 0.0 + metrics["actor/rewards_accuracies"] = 0.0 + metrics["actor/rewards_margins"] = 0.0 + + return metrics # Return aggregated metrics diff --git a/ICL/DAPO/verl-recipe/spin/fsdp_workers.py b/ICL/DAPO/verl-recipe/spin/fsdp_workers.py new file mode 100644 index 0000000000000000000000000000000000000000..ccee0cb76f7c1022d6f695131778a30e3d33880e --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/fsdp_workers.py @@ -0,0 +1,598 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import warnings + +import numpy as np +import psutil +import torch +import torch.distributed +from codetiming import Timer +from omegaconf import OmegaConf, open_dict +from torch.distributed.device_mesh import init_device_mesh + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register +from verl.utils import hf_tokenizer +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fs import copy_to_local +from verl.utils.fsdp_utils import ( + get_fsdp_wrap_policy, + get_init_weight_context_manager, + init_fn, + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) +from verl.utils.import_utils import import_external_libs +from verl.utils.model import compute_position_id_with_mask +from verl.utils.profiler import log_gpu_memory_usage +from verl.workers.fsdp_workers import ActorRolloutRefWorker +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) + + +def create_device_mesh(world_size, fsdp_size): + if fsdp_size < 0 or fsdp_size >= world_size: + device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + else: + device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + ) + return device_mesh + + +def get_sharding_strategy(device_mesh): + from torch.distributed.fsdp import ShardingStrategy + + if device_mesh.ndim == 1: + sharding_strategy = ShardingStrategy.FULL_SHARD + elif device_mesh.ndim == 2: + sharding_strategy = ShardingStrategy.HYBRID_SHARD + else: + raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + return sharding_strategy + + +class SPINRolloutRefWorker(ActorRolloutRefWorker): + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + use_remove_padding = self.config.model.get("use_remove_padding", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) + + if self._is_actor or self._is_rollout or self._is_ref: + # we need the model for actor and rollout + if self._is_actor or self._is_ref: + optim_config = self.config.actor.optim + fsdp_config = self.config.actor.fsdp_config + else: + optim_config = None + fsdp_config = OmegaConf.create() + self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( + self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) + ) + + # get the original unwrapped module + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + # load from checkpoint + if self._is_actor or self._is_ref: + OmegaConf.set_struct(self.config.actor, True) + with open_dict(self.config.actor): + self.config.actor.use_remove_padding = use_remove_padding + self.config.actor.use_fused_kernels = use_fused_kernels + self.actor = DataParallelPPOActor( + config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) + + if self._is_rollout: + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + + if self._is_ref: + self.ref_module_fsdp = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=self.config.ref.fsdp_config, + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + )[0] + OmegaConf.set_struct(self.config.ref, True) + with open_dict(self.config.ref): + self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels + self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.actor.checkpoint, + ) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.actor.checkpoint, + ) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + def compute_ref_log_prob(self, data: DataProto): + assert self._is_ref + + # Support all hardwares + data = data.to(get_device_id()) + + micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu + data.meta_info["micro_batch_size"] = micro_batch_size + data.meta_info["temperature"] = self.config.rollout.temperature + data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz + with self.ulysses_sharding_manager: + output = self.ref_policy.compute_log_prob(data=data) + output = DataProto.from_dict(tensors={"ref_log_prob": output}) + + output = output.to("cpu") + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1: + self.ref_policy.actor_module._handle.reshard(True) + + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + def compute_log_prob(self, data: DataProto): + assert self._is_actor + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + + # Support all hardwares + data = data.to(get_device_id()) + # we should always recompute old_log_probs when it is HybridEngine + data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info["temperature"] = self.config.rollout.temperature + # perform recompute log_prob + with self.ulysses_sharding_manager: + output = self.actor.compute_log_prob(data=data) + output = DataProto.from_dict( + tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature} + ) + + output = output.to("cpu") + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1: + self.actor.actor_module._handle.reshard(True) + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + + log_gpu_memory_usage("After compute_log_prob", logger=logger) + return output + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) + def update_actor_dpo(self, data: DataProto): + """ + Wrapper for actor update step. Handles FSDP state management. + Calls self.actor.update_policy which now contains DPO logic based + on pre-calculated log probabilities. + """ + # Support all hardwares + data = data.to(get_device_id()) + + assert self._is_actor # Make sure this worker has the actor role + if self.actor is None: + raise RuntimeError("Actor instance (self.actor) not initialized in worker.") + + # --- FSDP State Management --- + if self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + if self._is_offload_optimizer: + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) + + log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger) + + # --- Ulysses Sharding (if used) --- + with self.ulysses_sharding_manager: + # --- Call the core update method (now containing DPO logic) --- + with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name + # Calls the modified update_policy method + metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION + delta_time = timer.last + + # --- Add Performance Metrics --- + # MFU calculation might be less accurate/meaningful here for DPO + metrics["perf/approx_tokens_processed"] = torch.sum( + data.batch.get("attention_mask", torch.tensor(0)) + ).item() # Approx tokens + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) + metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) + global_num_tokens = data.meta_info["global_token_num"] + estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) + metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size + + # --- LR Scheduler Step --- + lr = self.actor_lr_scheduler.get_last_lr()[0] + metrics["actor/lr"] = lr + self.actor_lr_scheduler.step() + + log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger) + + # --- Prepare Output --- + output = DataProto(meta_info={"metrics": metrics}) + output = output.to("cpu") + + # --- FSDP State Management (Offload) --- + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + + return output + + +# TODO(sgm): we may need to extract it to dp_reward_model.py +class RewardModelWorker(Worker): + """ + Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. + """ + + def __init__(self, config): + super().__init__() + import torch.distributed + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=get_nccl_backend()) + self.config = config + + # build device mesh for Ulysses Sequence Parallel + world_size = torch.distributed.get_world_size() + from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + + self.ulysses_device_mesh = None + self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + dp = world_size // self.ulysses_sequence_parallel_size + if self.ulysses_sequence_parallel_size > 1: + self.ulysses_device_mesh = init_device_mesh( + get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + ) + + if self.ulysses_device_mesh is not None: + is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 + self._register_dispatch_collect_info( + "reward", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect + ) + else: + self._register_dispatch_collect_info("reward", dp_rank=self.rank, is_collect=True) + + self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + + self.use_remove_padding = self.config.model.get("use_remove_padding", False) + + # normalize config + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= torch.distributed.get_world_size() + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size + + def _build_model(self, config): + # the following line is necessary + from torch.distributed.fsdp import CPUOffload + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from transformers import AutoConfig, AutoModelForTokenClassification + + # download the checkpoint from hdfs + local_path = copy_to_local(config.model.path) + + if self.config.model.input_tokenizer is None: + self._do_switch_chat_template = False + else: + self._do_switch_chat_template = True + input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) + self.input_tokenizer = hf_tokenizer( + input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) + ) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) + + trust_remote_code = config.model.get("trust_remote_code", False) + model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) + model_config.num_labels = 1 + + # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect + init_context = get_init_weight_context_manager( + use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh + ) + + with init_context(), warnings.catch_warnings(): + warnings.simplefilter("ignore") + model_config.classifier_dropout = 0.0 + reward_module = AutoModelForTokenClassification.from_pretrained( + pretrained_model_name_or_path=local_path, + config=model_config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=trust_remote_code, + ) + + if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + + apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) + + reward_module.to(torch.bfloat16) + + auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) + + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + reward_module = FSDP( + reward_module, + param_init_fn=init_fn, + use_orig_params=False, + auto_wrap_policy=auto_wrap_policy, + device_id=get_device_id(), + sharding_strategy=sharding_strategy, # zero3 + sync_module_states=True, + cpu_offload=CPUOffload(offload_params=True), + forward_prefetch=False, + device_mesh=self.device_mesh, + ) + + return reward_module + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + self.reward_module = self._build_model(config=self.config) + + def _forward_micro_batch(self, micro_batch): + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + + from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs + + with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + + if self.use_remove_padding: + input_ids_rmpad, indices, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # unpad the position_ids to align the rotary + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices + ).transpose(0, 1) + + # pad and slice the inputs if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size + ) + + # only pass input_ids and position_ids to enable flash_attn_varlen + output = self.reward_module( + input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False + ) # prevent model thinks we are generating + reward_rmpad = output.logits + reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) + + # gather output if sp > 1 + if self.ulysses_sequence_parallel_size > 1: + reward_rmpad = gather_outputs_and_unpad( + reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + + # pad it back + rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) + else: + output = self.reward_module( + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False + ) + rm_score = output.logits # (batch_size, seq_len, 1) + rm_score = rm_score.squeeze(-1) + + # extract the result of the last valid token + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] + return rm_score + + def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): + batch_size = data.batch.batch_size[0] + # expand as token_level_reward + attention_mask = data.batch["attention_mask"] + position_ids = data.batch["position_ids"] + response_length = data.batch["responses"].shape[-1] + eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) + token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) + token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores + + # select the response part + token_level_scores = token_level_scores[:, -response_length:] + + return token_level_scores + + def _switch_chat_template(self, data: DataProto): + src_max_length = data.batch["attention_mask"].shape[-1] + + src_tokenizer = self.input_tokenizer + target_tokenizer = self.tokenizer + + rm_input_ids = [] + rm_attention_mask = [] + + for i in range(data.batch.batch_size[0]): + if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray): + raise TypeError( + f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}" + ) + + # extract raw prompt + chat: list = list(data.non_tensor_batch["raw_prompt"][i]) + + # extract response + response_ids = data.batch["responses"][i] + response_length = response_ids.shape[-1] + valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + response = src_tokenizer.decode(valid_response_ids) + # remove bos and eos + response = response.replace(src_tokenizer.eos_token, "") + + chat.append({"role": "assistant", "content": response}) + + prompt_with_chat_template = target_tokenizer.apply_chat_template( + chat, add_generation_prompt=False, tokenize=False + ) + if self.rank == 0 and i == 0: + # for debugging purpose + print(f"Switch template. chat: {prompt_with_chat_template}") + + # the maximum length is actually determined by the reward model itself + max_length = self.config.get("max_length", src_max_length) + if max_length is None: + max_length = src_max_length + + model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) + input_ids, attention_mask = verl_F.postprocess_data( + input_ids=model_inputs["input_ids"], + attention_mask=model_inputs["attention_mask"], + max_length=max_length, + pad_token_id=target_tokenizer.pad_token_id, + left_pad=False, # right padding + truncation=self.config.get("truncation", "right"), + ) # truncate from the right + + rm_input_ids.append(input_ids) + rm_attention_mask.append(attention_mask) + + rm_input_ids = torch.cat(rm_input_ids, dim=0) + rm_attention_mask = torch.cat(rm_attention_mask, dim=0) + + rm_position_ids = compute_position_id_with_mask(rm_attention_mask) + + rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} + + return DataProto.from_dict(rm_inputs) + + @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="reward")) + def compute_rm_score(self, data: DataProto): + import itertools + + from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches + + # Support all hardwares + data = data.to(get_device_id()) + if self._do_switch_chat_template: + rm_data = self._switch_chat_template(data) + else: + rm_input_ids = data.batch["input_ids"] + rm_attention_mask = data.batch["attention_mask"] + rm_position_ids = data.batch["position_ids"] + rm_inputs = { + "input_ids": rm_input_ids, + "attention_mask": rm_attention_mask, + "position_ids": rm_position_ids, + } + rm_data = DataProto.from_dict(rm_inputs) + + # Support all hardwares + rm_data.batch = rm_data.batch.to(get_device_id()) + + # perform forward computation + with self.ulysses_sharding_manager: + rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data) + data = self.ulysses_sharding_manager.preprocess_data(data=data) + + use_dynamic_bsz = self.config.use_dynamic_bsz + if use_dynamic_bsz: + max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) + else: + micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) + output = [] + for micro_batch in micro_batches: + rm_score = self._forward_micro_batch(micro_batch) + output.append(rm_score) + scores = torch.cat(output, dim=0) # (batch_size) + + if use_dynamic_bsz: + indices = list(itertools.chain.from_iterable(indices)) + assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" + revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) + scores = scores[revert_indices] + + token_level_scores = self._expand_to_token_level(data, scores) + # Note that this is only the scores, may not be the final rewards used to train RL + output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) + output = self.ulysses_sharding_manager.postprocess_data(data=output) + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + self.reward_module._handle.reshard(True) + + output = output.to("cpu") + return output diff --git a/ICL/DAPO/verl-recipe/spin/run_spin.sh b/ICL/DAPO/verl-recipe/spin/run_spin.sh new file mode 100644 index 0000000000000000000000000000000000000000..798dedabed0fae0c601899d83bd38f5adde909ea --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/run_spin.sh @@ -0,0 +1,29 @@ +set -e +set -x +VISIBLE_DEVICES="4,5,6,7" +export HYDRA_FULL_ERROR=1 + +CUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size=8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=64 \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.logger=console \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=1 \ + +trainer.log_freq=1 \ + trainer.ref_update_freq=1 \ + trainer.total_epochs=1000 2>&1 | tee verl_demo.log \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/spin/utils.py b/ICL/DAPO/verl-recipe/spin/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e3855f645410d3cf7a25a2b1147158faa139f6b9 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spin/utils.py @@ -0,0 +1,164 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from omegaconf import DictConfig + + +def validate_config( + config: DictConfig, + use_reference_policy: bool, + use_critic: bool, +) -> None: + """ + Validate an OmegaConf DictConfig + + Args: + config: The OmegaConf DictConfig to validate. + use_reference_policy (bool): is ref policy needed + use_critic (bool): is critic needed + """ + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + + # 1. Check total batch size for data correctness + real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + assert real_train_batch_size % n_gpus == 0, ( + f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + ) + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + settings = { + "actor_rollout_ref.actor": "micro_batch_size", + "critic": "micro_batch_size", + "reward_model": "micro_batch_size", + "actor_rollout_ref.ref": "log_prob_micro_batch_size", + "actor_rollout_ref.rollout": "log_prob_micro_batch_size", + } + + if name in settings: + param = settings[name] + param_per_gpu = f"{param}_per_gpu" + + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError( + f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. " + f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported " + f"(the former is deprecated)." + ) + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "actor_rollout_ref.actor", + ) + + if use_reference_policy: + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref", + ) + + # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive( + config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout", + ) + + if use_critic and not config.critic.use_dynamic_bsz: + # Check for critic micro-batch size conflicts + check_mutually_exclusive( + config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" + ) + + # Check for reward model micro-batch size conflicts + if ( + config.reward_model.enable + and not config.reward_model.use_dynamic_bsz + and not config.reward_model.use_reward_loop + ): + check_mutually_exclusive( + config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" + ) + + # Actor + # check if train_batch_size is larger than ppo_mini_batch_size + # if NOT dynamic_bsz, we must ensure: + # ppo_mini_batch_size is divisible by ppo_micro_batch_size + # ppo_micro_batch_size * sequence_parallel_size >= n_gpus + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size + sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) + if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: + assert ( + config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size + == 0 + ) + assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + + assert config.actor_rollout_ref.actor.loss_agg_mode in [ + "token-mean", + "seq-mean-token-sum", + "seq-mean-token-mean", + ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" + + if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: + print("NOTICE: You have both enabled in-reward kl and kl loss.") + + # critic + if use_critic and not config.critic.use_dynamic_bsz: + assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size + sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) + if config.critic.ppo_micro_batch_size is not None: + assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus + + # Check if use_remove_padding is enabled when using sequence parallelism for fsdp + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + if ( + config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 + or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 + ): + assert config.actor_rollout_ref.model.use_remove_padding, ( + "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." + ) + + if use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: + if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: + assert config.critic.model.use_remove_padding, ( + "When using sequence parallelism for critic, you must enable `use_remove_padding`." + ) + + if config.data.get("val_batch_size", None) is not None: + print( + "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines " + "as a whole batch, which will schedule the memory themselves." + ) + + # check eval config + if config.actor_rollout_ref.rollout.val_kwargs.do_sample: + assert config.actor_rollout_ref.rollout.temperature > 0, ( + "validation gen temperature should be greater than 0 when enabling do_sample" + ) + + print("[validate_config] All configuration checks passed successfully!") diff --git a/ICL/DAPO/verl-recipe/spo/README.md b/ICL/DAPO/verl-recipe/spo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bef9edd52a08fee0338325d734fdf5f3a5eb7a95 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/README.md @@ -0,0 +1,353 @@ +# Single-stream Policy Optimization (SPO) + +[![arXiv](https://img.shields.io/badge/arXiv-2509.13232-b31b1b.svg)](https://arxiv.org/abs/2509.13232) +[![Python](https://img.shields.io/badge/python-3.12-blue.svg)](https://www.python.org/downloads/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +## Citation + +``` +@article{xu2025single, + title={Single-stream policy optimization}, + author={Xu, Zhongwen and Ding, Zihan}, + journal={arXiv preprint arXiv:2509.13232}, + year={2025} +} +``` + +## Installation + +### Prerequisites + +- Python 3.12 +- CUDA 12.8 compatible GPU +- Conda or Mamba package manager + +### Setup Instructions + +1. **Clone the VERL repository at the specific commit:** + +```bash +git clone https://github.com/volcengine/verl.git +cd verl +git checkout d7944c01e63e9eb639c8357648b7958550591158 +``` + +2. **Create and activate a new conda environment:** + +```bash +conda create -n spo python=3.12 -y +conda activate spo +``` + +3. **Install dependencies:** + +```bash +# Install vLLM with CUDA 12.8 support +pip install vllm==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu128 + +# Install Flash Attention +pip install --no-cache-dir --no-build-isolation flash_attn==2.7.4.post1 + +# Install verl +pip install -e . +``` + +### Environment Reference + +For a complete list of dependencies and package versions, see [`environment.yml`](./environment.yml). This file contains the full conda environment export and can be used as a reference for troubleshooting dependency issues. + +### Sandbox Runtime + +For instructions on setting up and serving the Sandbox runtime environment, see the [verl-reTool recipe documentation](https://www.notion.so/verl-reTool-recipe-2398b5b7feba80a58156fa936f9f8de6). + +## Offline Value Estimation + +Offline value estimation is a crucial preprocessing step in SPO that estimates the quality of responses in your training dataset using a pretrained model. This process helps initialize the value function for more efficient policy optimization. + +### Step 1: Preprocess Training Data + +First, split your training dataset into manageable subsets using the preprocessing script: + +```bash +python recipe/spo/estimate_offline_values/split_dapo_into_subsets.py \ + --dataset open-r1/DAPO-Math-17k-Processed \ + --output_dir DAPO-Math-17k-Processed_Splits \ + --num_subsets 5 +``` + +**Parameters:** +- `--dataset`: HuggingFace dataset identifier or local path (default: `open-r1/DAPO-Math-17k-Processed`) +- `--output_dir`: Directory where subset parquet files will be saved (required) +- `--num_subsets`: Number of subsets to split the dataset into (default: 5) + +This script will generate multiple subset `.parquet` files under the specified `output_dir`. For example: +- `DAPO-Math-17k-Processed_Splits/subset_0.parquet` +- `DAPO-Math-17k-Processed_Splits/subset_1.parquet` +- `DAPO-Math-17k-Processed_Splits/subset_2.parquet` +- `DAPO-Math-17k-Processed_Splits/subset_3.parquet` +- `DAPO-Math-17k-Processed_Splits/subset_4.parquet` + +### Step 2: Generate Offline Value Estimates + +Run the evaluation script to generate offline value estimates using a pretrained model. You'll need to process each subset individually: + +```bash +OUTPUT_DIR=spo_verl_pr \ +DATA_FILE=DAPO-Math-17k-Processed_Splits/subset_0.parquet \ +MODEL_PATH=Qwen/Qwen3-8B \ +EXP_NAME=offline_value_estimation_subset_0 \ +sh recipe/spo/estimate_offline_values/eval.sh +``` + +**Parameters:** +- `OUTPUT_DIR`: Directory where results will be saved +- `DATA_FILE`: Path to the subset parquet file to process +- `MODEL_PATH`: HuggingFace model identifier or local path to the pretrained model +- `EXP_NAME`: Experiment name for tracking and organizing results + +**Batch Processing:** + +To process all subsets, you can loop through them: + +```bash +for i in {0..N}; do + OUTPUT_DIR=spo_verl_pr \ + DATA_FILE=DAPO-Math-17k-Processed_Splits/subset_${i}.parquet \ + MODEL_PATH=Qwen/Qwen3-8B \ + EXP_NAME=offline_value_estimation_subset_${i} \ + sh recipe/spo/estimate_offline_values/eval.sh +done +``` + +Replace `N` with the actual number of subsets generated in Step 1. + +**Output Directory Structure** + +All subset outputs are saved in the `trainer.validation_data_dir` directory. The directory structure will look like this: + +``` +offline_value_estimation/ +├── offline_value_estimation_subset_0 +│ └── validation_data +│ └── 0.jsonl +├── offline_value_estimation_subset_1 +│ └── validation_data +│ └── 0.jsonl +├── offline_value_estimation_subset_2 +│ └── validation_data +│ └── 0.jsonl +├── offline_value_estimation_subset_3 +│ └── validation_data +│ └── 0.jsonl +└── offline_value_estimation_subset_4 + └── validation_data + └── 0.jsonl +``` + +Each subset directory contains: +- A `validation_data` subdirectory with the estimated values stored in JSONL format +- The `0.jsonl` file contains the offline value estimates for each response in the corresponding subset + +### Step 3: Merge Offline Value Estimates + +After generating offline value estimates for all subsets, merge them into a single file for downstream training: + +```bash +python recipe/spo/estimate_offline_values/merge_offline_values.py \ + --input_dir offline_value_estimation \ + --output_file offline_values.json +``` + +**Parameters:** +- `--input_dir`: Directory containing all subset outputs (the `trainer.validation_data_dir` from Step 2) +- `--output_file`: Path where the merged offline values JSON file will be saved +- `--pattern`: (Optional) Custom glob pattern to match subset result files (default: `offline_value_estimation_subset_*/validation_data/0.jsonl`) +- `--max_scores_per_prompt`: (Optional) Maximum number of scores to keep per prompt. If a prompt has more scores, they will be randomly subsampled (default: 8) + +**Output Format:** + +The merged file contains a dictionary mapping prompts to their corresponding offline value scores: + +```json +{ + "prompt_1": [score_1, score_2, ...], + "prompt_2": [score_1, score_2, ...], + ... +} +``` + +**Example:** + +```bash +# Merge all subsets from the default output directory +python recipe/spo/estimate_offline_values/merge_offline_values.py \ + --input_dir spo_verl_pr/offline_value_estimation \ + --output_file DAPO-Math-17k-Processed_Splits/offline_values.json + +# With custom pattern +python recipe/spo/estimate_offline_values/merge_offline_values.py \ + --input_dir /path/to/validation_data_dir \ + --output_file /path/to/output/offline_values.json \ + --pattern "custom_subset_*/validation_data/0.jsonl" + +# With custom max scores per prompt +python recipe/spo/estimate_offline_values/merge_offline_values.py \ + --input_dir spo_verl_pr/offline_value_estimation \ + --output_file DAPO-Math-17k-Processed_Splits/offline_values.json \ + --max_scores_per_prompt 16 +``` + +The script will: +- Automatically discover all subset result files matching the pattern +- Use concurrent processing to efficiently load data from multiple files +- Merge scores by prompt/question +- Display statistics about the merged data +- Save the final merged results to the specified output file + +## Training + +SPO provides two training methods: **GRPO** (Group Relative Policy Optimization) and **SPO** (Single-stream Policy Optimization). Both methods use the same training script but with different configurations. + +### Prerequisites + +Before training, ensure you have: +1. Preprocessed training data split into subsets (from [Step 1](#step-1-preprocess-training-data)) +2. For SPO method: Merged offline value estimates (from [Step 3](#step-3-merge-offline-value-estimates)) + +### Training with GRPO + +GRPO is a group-based policy optimization method that generates multiple responses per prompt during training. This is the simpler baseline method that doesn't require offline value estimation. + +```bash +OUTPUT_DIR=spo_verl_pr \ +TRAIN_DATA_DIR=DAPO-Math-17k-Processed_Splits \ +MODEL_PATH=Qwen/Qwen3-8B \ +EXP_NAME=grpo_training \ +METHOD=GRPO \ +sh recipe/spo/train.sh +``` + +**GRPO Configuration:** +- Generates **8 responses** per prompt during training +- Training batch size: 96 +- PPO mini-batch size: 12 +- Generation batch size: 96 (matches training batch size) + +### Training with SPO + +SPO is the single-stream policy optimization method that uses offline value estimates for more efficient training. This method generates only one response per prompt and uses Thompson Sampling to select prompts based on their offline value estimates. + +```bash +OUTPUT_DIR=spo_verl_pr \ +TRAIN_DATA_DIR=DAPO-Math-17k-Processed_Splits \ +MODEL_PATH=Qwen/Qwen3-8B \ +EXP_NAME=spo_training \ +METHOD=SPO \ +OFFLINE_VALUES=DAPO-Math-17k-Processed_Splits/offline_values.json \ +sh recipe/spo/train.sh +``` + +**SPO Configuration:** +- Generates **1 response** per prompt during training +- Training batch size: 768 (8x larger than GRPO) +- PPO mini-batch size: 96 (8x larger than GRPO) +- Generation batch size: 14,000 (for efficient batched generation) +- Requires offline values JSON file from preprocessing + +### Training Parameters + +All parameters are configured via environment variables: + +**Required Parameters:** +- `OUTPUT_DIR`: Directory where results, checkpoints, and logs will be saved +- `TRAIN_DATA_DIR`: Directory containing training data subset parquet files (subset_0.parquet through subset_4.parquet) +- `MODEL_PATH`: HuggingFace model identifier or local path to the pretrained model +- `EXP_NAME`: Experiment name for tracking and organizing results +- `METHOD`: Training method, either `GRPO` or `SPO` + +**SPO-Specific Parameters:** +- `OFFLINE_VALUES`: Path to the merged offline values JSON file (required when METHOD=SPO) + +**Optional Parameters:** +- `RESPONSE_LENGTH`: Maximum response length in tokens (default: 8192) +- `N_TRAIN`: Number of responses per prompt for training with GRPO (default: 8, overridden to 1 for SPO) +- `N_VAL`: Number of responses per prompt for validation (default: 16) +- `DEBUG`: Enable debug mode with smaller batch sizes (default: False) +- `VAL_BEFORE_TRAIN`: Run validation before starting training (default: False) + +### Output Directory Structure + +Training outputs are organized in the following structure: + +``` +/ +└── spo/ + └── / + ├── checkpoints/ + │ ├── epoch_0/ + │ ├── epoch_20/ + │ ├── epoch_40/ + │ └── ... + ├── validation_data/ + │ ├── 0.jsonl + │ ├── 10.jsonl + │ ├── 20.jsonl + │ └── ... + └── tensorboard/ + └── events.out.tfevents.* +``` + +**Directory Contents:** +- `checkpoints/`: Model checkpoints saved every 20 epochs +- `validation_data/`: Validation results in JSONL format, saved every 10 epochs +- `tensorboard/`: TensorBoard logs for monitoring training progress + +### Monitoring Training + +View training progress in real-time using TensorBoard: + +```bash +tensorboard --logdir /spo//tensorboard +``` + +Key metrics to monitor: +- `reward/mean`: Average reward across training samples +- `actor/loss`: Actor model loss +- `actor/lr`: Learning rate +- `validation/accuracy`: Validation accuracy on AIME 2024 and 2025 datasets + +### Example: Complete SPO Training Pipeline + +Here's a complete example combining all preprocessing and training steps: + +```bash +# Step 1: Split dataset into subsets +python recipe/spo/estimate_offline_values/split_dapo_into_subsets.py \ + --dataset open-r1/DAPO-Math-17k-Processed \ + --output_dir DAPO-Math-17k-Processed_Splits \ + --num_subsets 5 + +# Step 2: Generate offline value estimates for each subset +for i in {0..4}; do + OUTPUT_DIR=spo_verl_pr \ + DATA_FILE=DAPO-Math-17k-Processed_Splits/subset_${i}.parquet \ + MODEL_PATH=Qwen/Qwen3-8B \ + EXP_NAME=offline_value_estimation_subset_${i} \ + sh recipe/spo/estimate_offline_values/eval.sh +done + +# Step 3: Merge offline value estimates +python recipe/spo/estimate_offline_values/merge_offline_values.py \ + --input_dir spo_verl_pr/offline_value_estimation \ + --output_file DAPO-Math-17k-Processed_Splits/offline_values.json + +# Step 4: Train with SPO +OUTPUT_DIR=spo_verl_pr \ +TRAIN_DATA_DIR=DAPO-Math-17k-Processed_Splits \ +MODEL_PATH=Qwen/Qwen3-8B \ +EXP_NAME=spo_training \ +METHOD=SPO \ +OFFLINE_VALUES=DAPO-Math-17k-Processed_Splits/offline_values.json \ +sh recipe/spo/train.sh +``` \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/spo/environment.yml b/ICL/DAPO/verl-recipe/spo/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..90f13cd30dfa4ce849712421bdfcf6e513358175 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/environment.yml @@ -0,0 +1,232 @@ +name: verl +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2025.9.9=h06a4308_0 + - expat=2.7.1=h6a678d5_0 + - ld_impl_linux-64=2.44=h153f514_2 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libuuid=1.41.5=h5eee18b_0 + - libxcb=1.17.0=h9b100fa_0 + - libzlib=1.3.1=hb25bd0a_0 + - ncurses=6.5=h7934f7d_0 + - openssl=3.0.18=hd6dcaed_0 + - pip=25.2=pyhc872135_1 + - pthread-stubs=0.3=h0ce48e5_1 + - python=3.12.0=h996f2a0_0 + - readline=8.3=hc2a1206_0 + - sqlite=3.50.2=hb25bd0a_1 + - tk=8.6.15=h54e0aa7_0 + - wheel=0.45.1=py312h06a4308_0 + - xorg-libx11=1.8.12=h9b100fa_1 + - xorg-libxau=1.0.12=h9b100fa_0 + - xorg-libxdmcp=1.1.5=h9b100fa_0 + - xorg-xorgproto=2024.1=h5eee18b_1 + - xz=5.6.4=h5eee18b_1 + - zlib=1.3.1=hb25bd0a_0 + - pip: + - absl-py==2.3.1 + - accelerate==1.11.0 + - aiohappyeyeballs==2.6.1 + - aiohttp==3.13.2 + - aiohttp-cors==0.8.1 + - aiosignal==1.4.0 + - annotated-doc==0.0.3 + - annotated-types==0.7.0 + - antlr4-python3-runtime==4.9.3 + - anyio==4.11.0 + - astor==0.8.1 + - attrs==25.4.0 + - blake3==1.0.8 + - cachetools==6.2.1 + - cbor2==5.7.1 + - certifi==2025.10.5 + - cffi==2.0.0 + - charset-normalizer==3.4.4 + - click==8.2.1 + - cloudpickle==3.1.1 + - codetiming==1.4.0 + - colorful==0.5.7 + - compressed-tensors==0.11.0 + - cupy-cuda12x==13.6.0 + - datasets==4.3.0 + - depyf==0.19.0 + - dill==0.4.0 + - diskcache==5.6.3 + - distlib==0.4.0 + - distro==1.9.0 + - dnspython==2.8.0 + - einops==0.8.1 + - email-validator==2.3.0 + - fastapi==0.120.1 + - fastapi-cli==0.0.14 + - fastapi-cloud-cli==0.3.1 + - fastrlock==0.8.3 + - filelock==3.20.0 + - flash-attn==2.7.4.post1 + - frozendict==2.4.6 + - frozenlist==1.8.0 + - fsspec==2025.9.0 + - gguf==0.17.1 + - gitdb==4.0.12 + - gitpython==3.1.45 + - google-api-core==2.28.1 + - google-auth==2.42.0 + - googleapis-common-protos==1.71.0 + - grpcio==1.76.0 + - h11==0.16.0 + - hf-xet==1.2.0 + - httpcore==1.0.9 + - httptools==0.7.1 + - httpx==0.28.1 + - huggingface-hub==0.36.0 + - hydra-core==1.3.2 + - idna==3.11 + - importlib-metadata==8.7.0 + - interegular==0.3.3 + - jinja2==3.1.6 + - jiter==0.11.1 + - jsonschema==4.25.1 + - jsonschema-specifications==2025.9.1 + - lark==1.2.2 + - llguidance==0.7.30 + - llvmlite==0.44.0 + - lm-format-enforcer==0.11.3 + - markdown==3.9 + - markdown-it-py==4.0.0 + - markupsafe==3.0.3 + - mdurl==0.1.2 + - mistral-common==1.8.5 + - mpmath==1.3.0 + - msgpack==1.1.2 + - msgspec==0.19.0 + - multidict==6.7.0 + - multiprocess==0.70.16 + - networkx==3.5 + - ninja==1.13.0 + - numba==0.61.2 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.8.4.1 + - nvidia-cuda-cupti-cu12==12.8.90 + - nvidia-cuda-nvrtc-cu12==12.8.93 + - nvidia-cuda-runtime-cu12==12.8.90 + - nvidia-cudnn-cu12==9.10.2.21 + - nvidia-cufft-cu12==11.3.3.83 + - nvidia-cufile-cu12==1.13.1.3 + - nvidia-curand-cu12==10.3.9.90 + - nvidia-cusolver-cu12==11.7.3.90 + - nvidia-cusparse-cu12==12.5.8.93 + - nvidia-cusparselt-cu12==0.7.1 + - nvidia-nccl-cu12==2.27.3 + - nvidia-nvjitlink-cu12==12.8.93 + - nvidia-nvtx-cu12==12.8.90 + - omegaconf==2.3.0 + - openai==2.6.1 + - openai-harmony==0.0.4 + - opencensus==0.11.4 + - opencensus-context==0.1.3 + - opencv-python-headless==4.12.0.88 + - opentelemetry-api==1.38.0 + - opentelemetry-exporter-prometheus==0.59b0 + - opentelemetry-proto==1.38.0 + - opentelemetry-sdk==1.38.0 + - opentelemetry-semantic-conventions==0.59b0 + - orjson==3.11.4 + - outlines-core==0.2.11 + - packaging==25.0 + - pandas==2.3.3 + - partial-json-parser==0.2.1.1.post6 + - peft==0.17.1 + - pillow==12.0.0 + - platformdirs==4.5.0 + - prometheus-client==0.23.1 + - prometheus-fastapi-instrumentator==7.1.0 + - propcache==0.4.1 + - proto-plus==1.26.1 + - protobuf==6.33.0 + - psutil==7.1.2 + - py-cpuinfo==9.0.0 + - py-spy==0.4.1 + - pyarrow==22.0.0 + - pyasn1==0.6.1 + - pyasn1-modules==0.4.2 + - pybase64==1.4.2 + - pybind11==3.0.1 + - pycountry==24.6.1 + - pycparser==2.23 + - pydantic==2.12.3 + - pydantic-core==2.41.4 + - pydantic-extra-types==2.10.6 + - pygments==2.19.2 + - pylatexenc==2.10 + - python-dateutil==2.9.0.post0 + - python-dotenv==1.2.1 + - python-json-logger==4.0.0 + - python-multipart==0.0.20 + - pytz==2025.2 + - pyvers==0.1.0 + - pyyaml==6.0.3 + - pyzmq==27.1.0 + - ray==2.51.0 + - referencing==0.37.0 + - regex==2025.10.23 + - requests==2.32.5 + - rich==14.2.0 + - rich-toolkit==0.15.1 + - rignore==0.7.1 + - rpds-py==0.28.0 + - rsa==4.9.1 + - safetensors==0.6.2 + - scipy==1.16.3 + - sentencepiece==0.2.1 + - sentry-sdk==2.42.1 + - setproctitle==1.3.7 + - setuptools==79.0.1 + - shellingham==1.5.4 + - six==1.17.0 + - smart-open==7.4.1 + - smmap==5.0.2 + - sniffio==1.3.1 + - soundfile==0.13.1 + - soxr==1.0.0 + - starlette==0.49.1 + - sympy==1.14.0 + - tensorboard==2.20.0 + - tensorboard-data-server==0.7.2 + - tensordict==0.10.0 + - tiktoken==0.12.0 + - tokenizers==0.22.1 + - torch==2.8.0+cu128 + - torchaudio==2.8.0+cu128 + - torchdata==0.11.0 + - torchvision==0.23.0+cu128 + - tqdm==4.67.1 + - transformers==4.57.1 + - triton==3.4.0 + - typer==0.20.0 + - typing-extensions==4.15.0 + - typing-inspection==0.4.2 + - tzdata==2025.2 + - urllib3==2.5.0 + - uvicorn==0.38.0 + - uvloop==0.22.1 + - verl==0.7.0.dev0 + - virtualenv==20.35.4 + - vllm==0.11.0 + - wandb==0.22.3 + - watchfiles==1.1.1 + - websockets==15.0.1 + - werkzeug==3.1.3 + - wrapt==2.0.0 + - xformers==0.0.32.post1 + - xgrammar==0.1.25 + - xxhash==3.6.0 + - yarl==1.22.0 + - zipp==3.23.0 +prefix: /opt/conda/envs/spo diff --git a/ICL/DAPO/verl-recipe/spo/spo_main_ppo.py b/ICL/DAPO/verl-recipe/spo/spo_main_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..5d75de8e3b41e5e4e92a5d8bd7146bd9232545f4 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/spo_main_ppo.py @@ -0,0 +1,368 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Modifications Copyright 2025 SPO authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SPO main PPO training entry point. +This module extends the base PPO trainer with SPO-specific logic. +""" + +import os +import socket + +import hydra +import ray +from omegaconf import OmegaConf +from recipe.spo.spo_ray_trainer import RayPPOTrainer + +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from verl.trainer.main_ppo import TaskRunner as BaseTaskRunner +from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config +from verl.utils.device import is_cuda_available + + +@hydra.main(config_path="config", config_name="spo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + run_ppo(config) + + +# Define a function to run the PPO-like training process +def run_ppo(config, task_runner_class=None) -> None: + """Initialize Ray cluster and run distributed PPO training process. + + Args: + config: Training configuration object containing all necessary parameters + for distributed PPO training including Ray initialization settings, + model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. + """ + # Check if Ray is not initialized + if not ray.is_initialized(): + # Initialize Ray with a local cluster configuration + # Set environment variables in the runtime environment to control tokenizer parallelism, + # NCCL debug level, VLLM logging level, and allow runtime LoRA updating + # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration + default_runtime_env = get_ppo_ray_runtime_env() + # SPO-specific debug logic: Enable Ray debug mode for legacy debugging + if config.trainer.debug: + default_runtime_env["env_vars"]["RAY_DEBUG"] = "legacy" + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + if config.transfer_queue.enable: + # Add runtime environment variables for transfer queue + runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) + runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" + runtime_env_kwargs["env_vars"] = runtime_env_vars + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head + + # Create a remote instance of the TaskRunner class, and + # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete + if ( + is_cuda_available + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 + ): + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() + else: + runner = task_runner_class.remote() + ray.get(runner.run.remote(config)) + + # [Optional] get the path of the timeline trace file from the configuration, default to None + # This file is used for performance analysis + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) + if timeline_json_file: + ray.timeline(filename=timeline_json_file) + + +class TaskRunner(BaseTaskRunner): + """SPO-specific TaskRunner extending base implementation. + + Inherits most functionality from base TaskRunner and only overrides methods + that need to import from recipe.spo.spo_ray_trainer instead of verl.trainer.ppo.ray_trainer. + """ + + def add_actor_rollout_worker(self, config): + """Override: Add actor rollout worker with SPO Role import. + + SPO modification: Imports Role from recipe.spo.spo_ray_trainer + instead of verl.trainer.ppo.ray_trainer. + """ + from verl.single_controller.ray import RayWorkerGroup + + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + ) + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + from verl.workers.megatron_workers import ( + ActorRolloutRefWorker, + AsyncActorRolloutRefWorker, + ) + + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from recipe.spo.spo_ray_trainer import Role + + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + + return actor_rollout_cls, ray_worker_group_cls + + def add_critic_worker(self, config): + """Override: Add critic worker with SPO Role import. + + SPO modification: Imports Role from recipe.spo.spo_ray_trainer + instead of verl.trainer.ppo.ray_trainer. + """ + if config.critic.strategy in {"fsdp", "fsdp2"}: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + from verl.workers.megatron_workers import CriticWorker + + else: + raise NotImplementedError + + from recipe.spo.spo_ray_trainer import Role + + self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) + + def init_resource_pool_mgr(self, config): + """Override: Initialize resource pool manager with SPO imports. + + SPO modification: Imports Role and ResourcePoolManager from + recipe.spo.spo_ray_trainer instead of verl.trainer.ppo.ray_trainer. + """ + from recipe.spo.spo_ray_trainer import Role + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + + self.mapping[Role.ActorRollout] = global_pool_id + self.mapping[Role.Critic] = global_pool_id + from recipe.spo.spo_ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_worker(self, config): + """Override: Add reward model worker with SPO Role import. + + SPO modification: Imports Role from recipe.spo.spo_ray_trainer + instead of verl.trainer.ppo.ray_trainer. + """ + from recipe.spo.spo_ray_trainer import Role + + if config.reward_model.enable: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + elif use_legacy_worker_impl == "disable": + from verl.workers.roles import RewardModelWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + if config.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Override: Add reference policy worker with SPO Role import. + + SPO modification: Imports Role from recipe.spo.spo_ray_trainer + instead of verl.trainer.ppo.ray_trainer. + """ + from recipe.spo.spo_ray_trainer import Role + + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) + self.mapping[Role.RefPolicy] = "global_pool" + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) + + # We should adopt a multi-source reward function here: + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # finally, we combine all the rewards together + # The reward type depends on the tag of the data + self.add_reward_model_worker(config) + + # Add a reference policy worker if KL loss or KL reward is used. + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + # Load the reward manager for training and validation. + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + resource_pool_manager = self.init_resource_pool_mgr(config) + + from verl.utils.dataset.rl_dataset import collate_fn + + # Create training and validation datasets. + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) + train_sampler = create_rl_sampler(config.data, train_dataset) + + # Initialize the PPO trainer. + trainer = RayPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=self.role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + train_dataset=train_dataset, + val_dataset=val_dataset, + collate_fn=collate_fn, + train_sampler=train_sampler, + ) + # Initialize the workers of the trainer. + trainer.init_workers() + + # Start the training process. + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/spo/spo_ray_trainer.py b/ICL/DAPO/verl-recipe/spo/spo_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..be8e362df467a11420863bc2bb34a8fd9a1dc7ff --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/spo_ray_trainer.py @@ -0,0 +1,782 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# Modifications Copyright 2025 SPO authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +SPO Trainer extending PPO Trainer with Self-Play Optimization. +This trainer inherits from the base PPO trainer and adds SPO-specific logic. +""" + +import json +import os +import random +import uuid +from collections import defaultdict +from copy import deepcopy +from pprint import pprint + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.single_controller.ray import RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from verl.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, +) +from verl.trainer.ppo.ray_trainer import RayPPOTrainer as BaseRayPPOTrainer +from verl.trainer.ppo.ray_trainer import ( + ResourcePoolManager, + apply_kl_penalty, + compute_advantage, + compute_response_mask, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role +from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.metric import reduce_metrics +from verl.utils.rollout_skip import RolloutSkip +from verl.utils.torch_functional import masked_mean + +# Re-export for backward compatibility +__all__ = [ + "RayPPOTrainer", + "ResourcePoolManager", + "Role", + "apply_kl_penalty", + "compute_advantage", + "compute_response_mask", +] + + +class RayPPOTrainer(BaseRayPPOTrainer): + """SPO-specific PPO trainer that extends the base trainer with Self-Play Optimization logic. + + This trainer inherits most functionality from the base RayPPOTrainer and adds: + - SPO-specific weighted sampling based on Thompson sampling + - SPO advantage calculation using Bayesian framework + - Alpha/beta updates with KL-based rho smoothing + """ + + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): + """Dump rollout/validation samples as JSONL.""" + reward_extra_infos_dict.pop("acc", None) + os.makedirs(dump_path, exist_ok=True) + filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") + + n = len(inputs) + base_data = { + "input": inputs, + "output": outputs, + "gts": gts, + "score": scores, + "step": [self.global_steps] * n, + } + + for k, v in reward_extra_infos_dict.items(): + if len(v) == n: + base_data[k] = v + + lines = [] + for i in range(n): + entry = {k: v[i] for k, v in base_data.items()} + lines.append(json.dumps(entry, ensure_ascii=False)) + + with open(filename, "w") as f: + f.write("\n".join(lines) + "\n") + + print(f"Dumped generations to {filename}") + + def _get_gen_batch(self, batch: DataProto) -> DataProto: + """Override: Get generation batch with SPO-specific keys. + + SPO modification: Includes "raw_prompt" in reward_model_keys. + """ + reward_model_keys = ( + set({"data_source", "reward_model", "extra_info", "uid", "raw_prompt"}) & batch.non_tensor_batch.keys() + ) + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + if self.async_rollout_mode: + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + if self.hybrid_engine: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[Role.ActorRollout], + config=self.config.actor_rollout_ref, + role=str(Role.ActorRollout), + ) + self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls + else: + raise NotImplementedError + + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + + if self.use_critic: + self.critic_wg = all_wg[str(Role.Critic)] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + + self.rm_wg = None + # initalization of rm_wg will be deprecated in the future + if self.use_rm: + self.rm_wg = all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = all_wg[str(Role.ActorRollout)] + self.actor_rollout_wg.init_model() + + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + from recipe.spo.agent_loop import SPOAgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = SPOAgentLoopManager( + config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg + ) + + def _get_spo_rho( + self, + prompt2protodata: dict[str, DataProto], + prompt2log_probs: dict[str, torch.Tensor], + prompt2D: dict[str, torch.Tensor], + micro_prompts: list[str], + spo_log_prob_batch_backup: DataProto, + ) -> torch.Tensor: + """Calculate rho for alpha and beta updating.""" + rho_metrics = {} + + if self.config.trainer.spo.rho.type == "constant": + # Repeat a constant to len(micro_prompts) as a torch.Tensor + rho = torch.full((len(micro_prompts),), self.config.trainer.spo.rho.value, dtype=torch.float32) + D = torch.full((len(micro_prompts),), 0.0, dtype=torch.float32) + rho_metrics["spo/rho"] = rho.mean().item() + rho_metrics["spo/D"] = D.mean().item() + return rho, prompt2protodata, prompt2log_probs, prompt2D, rho_metrics + elif self.config.trainer.spo.rho.type == "kl": + # Extract past dataprotos of micro_prompts + past_dataprotos = [] + first_sampled_number = 0 + for pid_, p_ in enumerate(micro_prompts): + if p_ in prompt2protodata.keys(): + proto = prompt2protodata[p_] + else: + first_sampled_number += 1 + proto = spo_log_prob_batch_backup.select_idxs([pid_]) + # Remove per-sample meta_info fields to avoid conflicts during concat + proto.meta_info.pop("global_token_num", None) + past_dataprotos.append(proto) + past_dataprotos = DataProto.concat(past_dataprotos) + response_mask = compute_response_mask(past_dataprotos) + first_sampled_ratio = first_sampled_number / len(micro_prompts) + rho_metrics["spo/first_sampled_ratio"] = first_sampled_ratio + + cur_log_probs = self.actor_rollout_wg.compute_log_prob(past_dataprotos) + cur_log_probs = cur_log_probs.batch["old_log_probs"] + old_log_probs = [] + for pid_, p_ in enumerate(micro_prompts): + if p_ in prompt2log_probs.keys(): + old_log_probs.append(prompt2log_probs[p_]) + else: + old_log_probs.append(cur_log_probs[pid_].unsqueeze(0)) + old_log_probs = torch.cat(old_log_probs, dim=0) # (M, seq_len) + + kl = (old_log_probs - cur_log_probs).abs() + D = masked_mean(kl, response_mask, axis=-1) # (M,) + rho_metrics["spo/D"] = D.mean().item() + D_half = torch.as_tensor(0.06, dtype=D.dtype, device=D.device) + rho = torch.pow(2.0, -D / D_half) + rho_metrics["spo/rho"] = rho.mean().item() + rho_clipped = torch.clamp(rho, min=self.config.trainer.spo.rho.clip_lower, max=0.96) + rho_metrics["spo/rho_clipped"] = rho_clipped.mean().item() + rho_metrics["spo/rho_clip_ratio"] = (rho_clipped != rho).type(torch.float).mean().item() + + # Update prompt2protodata and prompt2log_probs + new_log_probs = self.actor_rollout_wg.compute_log_prob(spo_log_prob_batch_backup) + for pid_, p_ in enumerate(micro_prompts): + prompt2protodata[p_] = spo_log_prob_batch_backup.select_idxs([pid_]) + prompt2log_probs[p_] = new_log_probs.batch["old_log_probs"][pid_].unsqueeze(0) + prompt2D[p_] = D[pid_].item() + + return rho_clipped, prompt2protodata, prompt2log_probs, prompt2D, rho_metrics + else: + raise ValueError(f"Unknown rho type: {self.config.trainer.spo.rho.type}") + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + if self.config.trainer.spo.enable: + prompt2scores = json.load(open(self.config.trainer.spo.offline_values)) + Neff = self.config.trainer.spo.offline_N + prompt2scores = {k: [int(_ > 0) for _ in v] for k, v in prompt2scores.items()} + print(f"[DEBUG] Select {Neff} samples for each prompt to calculate offline values.") + full_prompts = list(prompt2scores.keys()) + if Neff == 0: + prompt2alpha = {k: 0.5 for k in full_prompts} + prompt2beta = {k: 0.5 for k in full_prompts} + else: + for k, v in prompt2scores.items(): + if len(v) > Neff: + prompt2scores[k] = random.sample(v, Neff) + N_init = 1 / (1 - self.config.trainer.spo.rho.clip_lower) + print(f"[DEBUG] N_init: {N_init}") + prompt2alpha = {k: N_init * (sum(prompt2scores[k]) + 0.5) / (Neff + 1) for k in full_prompts} + prompt2beta = {k: N_init * (Neff - sum(prompt2scores[k]) + 0.5) / (Neff + 1) for k in full_prompts} + prompt2protodata = {} + prompt2log_probs = {} + prompt2D = {} + prompt2sampled_number = defaultdict(int) + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + if self.config.trainer.spo.enable: + EXPLORATION_EPSILON = 0.05 + + prompt2phat = { + k: float(prompt2alpha[k]) / float(prompt2alpha[k] + prompt2beta[k]) for k in full_prompts + } + prompt2weight = { + k: ((prompt2phat[k] * (1.0 - prompt2phat[k])) ** 0.5) + EXPLORATION_EPSILON + for k in full_prompts + } + + items = [] + weights = [] + for i, p in enumerate(batch_dict["raw_prompt"]): + p_str = p[0]["content"].strip() + w = float(prompt2weight.get(p_str, 0.0)) + items.append(i) + weights.append(w) + + M = len(items) + if M > 0: + weights_np = np.asarray(weights, dtype=np.float64) + wsum = float(weights_np.sum()) + + if wsum > 0.0: + probs = weights_np / wsum + else: + probs = np.full(M, 1.0 / M, dtype=np.float64) + + probs = probs / probs.sum() + + target_bs = int(self.config.data.train_batch_size) + replace = target_bs > M + + selected_pos = np.random.choice(M, size=target_bs, replace=replace, p=probs) + keep_idx = [items[j] for j in selected_pos.tolist()] + + if keep_idx: + sampled_batch_dict = {} + for k, v in batch_dict.items(): + try: + sampled_batch_dict[k] = v[keep_idx] + continue + except Exception: + pass + + if isinstance(v, list | tuple): + sampled_batch_dict[k] = type(v)(v[i] for i in keep_idx) + else: + sampled_batch_dict[k] = v + + batch_dict = sampled_batch_dict + print(f"[DEBUG] Final size of keep_idx: {len(keep_idx)}") + + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(batch) + batch = batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote( + data=batch, config=self.config, tokenizer=self.tokenizer + ) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + micro_prompts = batch.non_tensor_batch.get("raw_prompt", None) + micro_prompts = [_[0]["content"].strip() for _ in micro_prompts] + if self.config.trainer.spo.enable: + alpha = [prompt2alpha[_] for _ in micro_prompts] + beta = [prompt2beta[_] for _ in micro_prompts] + sum_reward_tensor = reward_tensor.sum(dim=-1) + + spo_metrics = {} + r = sum_reward_tensor + alpha = torch.tensor(alpha, dtype=torch.float).to(r) + beta = torch.tensor(beta, dtype=torch.float).to(r) + spo_metrics["spo/reward"] = r.mean().detach().item() + spo_metrics["spo/alpha"] = alpha.mean().detach().item() + spo_metrics["spo/beta"] = beta.mean().detach().item() + Neff = alpha + beta + spo_metrics["spo/Neff"] = Neff.mean().detach().item() + p_hats = alpha / Neff + spo_metrics["spo/p_hats"] = p_hats.mean().detach().item() + + # Recalculate advantages + advantages = r - p_hats + spo_metrics["spo/adv_before_norm"] = advantages.mean().detach().item() + + response_mask = compute_response_mask(batch) + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + quantiles = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9], device=advantages.device) + q_vals = torch.quantile(advantages, quantiles) + spo_metrics["spo/adv_after_norm/p10"] = q_vals[0].item() + spo_metrics["spo/adv_after_norm/p30"] = q_vals[1].item() + spo_metrics["spo/adv_after_norm/p50"] = q_vals[2].item() + spo_metrics["spo/adv_after_norm/p70"] = q_vals[3].item() + spo_metrics["spo/adv_after_norm/p90"] = q_vals[4].item() + advantages = advantages.unsqueeze(-1) * response_mask + batch.batch["advantages"] = advantages + batch.batch["returns"] = advantages + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + if self.config.trainer.spo.enable: + spo_log_prob_batch_backup = batch.select(deepcopy=True) + + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout importance sampling weights centrally (once per batch) + # This corrects for mismatch between rollout policy and training policy + # Also computes mismatch metrics (KL, PPL, etc.) + batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch) + # IS and mismatch metrics already have mismatch/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + if "advantages" not in batch.batch: + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + if self.config.trainer.spo.enable: + rho, prompt2protodata, prompt2log_probs, prompt2D, rho_metrics = self._get_spo_rho( + prompt2protodata, prompt2log_probs, prompt2D, micro_prompts, spo_log_prob_batch_backup + ) + spo_metrics.update(rho_metrics) + + # if you want exact Beta intervals, maintain alpha/beta as well: + alpha = rho * alpha + r + beta = rho * beta + (1 - r) + + cur_sampled_numbers = [] + for i in range(len(alpha)): + prompt2alpha[micro_prompts[i]] = alpha[i].item() + prompt2beta[micro_prompts[i]] = beta[i].item() + prompt2sampled_number[micro_prompts[i]] += 1 + cur_sampled_numbers.append(prompt2sampled_number[micro_prompts[i]]) + + cur_sampled_numbers = np.array(cur_sampled_numbers, dtype=np.int32) + spo_metrics["spo/cur_sampled_number/min"] = cur_sampled_numbers.min() + spo_metrics["spo/cur_sampled_number/max"] = cur_sampled_numbers.max() + spo_metrics["spo/cur_sampled_number/mean"] = cur_sampled_numbers.mean() + + metrics.update(spo_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation + + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/ICL/DAPO/verl-recipe/spo/spo_retool.py b/ICL/DAPO/verl-recipe/spo/spo_retool.py new file mode 100644 index 0000000000000000000000000000000000000000..7f481c26c84f39ae17bd932518490170ed4aae02 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/spo_retool.py @@ -0,0 +1,150 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Modifications Copyright 2025 SPO authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import datasets + +from verl.tools.base_tool import OpenAIFunctionToolSchema +from verl.tools.sandbox_fusion_tools import SandboxFusionTool +from verl.utils.dataset import RLHFDataset +from verl.utils.reward_score import math_dapo +from verl.utils.rollout_trace import rollout_trace_op + +logger = logging.getLogger(__name__) + + +class CustomSandboxFusionTool(SandboxFusionTool): + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + super().__init__(config, tool_schema) + + @rollout_trace_op + async def execute(self, instance_id: str, code: str, **kwargs) -> tuple[str, float, dict]: + # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script + lines = code.split("\n") + for i, line in reversed(list(enumerate(lines))): + if line == "": + continue + if not lines[i].startswith("print"): + lines[i] = f"print({line})" + break + code = "\n".join(lines) + + timeout = self.default_timeout + language = self.default_language + if not isinstance(code, str): + code = str(code) + + result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) + # sandbox has no score or metrics, use Nones + return result, None, None + + +answer_format = """\nThe answer format must be: \\boxed{'The final answer goes here.'}""" + + +class CustomRLHFDataset(RLHFDataset): + """Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets.""" + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + if ".parquet" in parquet_file: + dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + elif "open-r1/DAPO-Math-17k-Processed" in parquet_file: + dataframe = datasets.load_dataset(parquet_file, "all")["train"] + elif "ByteDance-Seed/BeyondAIME" in parquet_file: + dataframe = datasets.load_dataset(parquet_file)["test"] + elif "Polaris-Dataset-Hard" in parquet_file: + dataframe = datasets.load_from_disk(parquet_file) + else: + dataframe = datasets.load_dataset(parquet_file)["train"] + data_source = "/".join(parquet_file.split("/")[-2:]) + if data_source in [ + "Maxwell-Jia/AIME_2024", + "yentinglin/aime_2025", + "ByteDance-Seed/BeyondAIME", + "MathArena/brumo_2025", + "MathArena/hmmt_feb_2025", + ]: + dataframe = dataframe.map( + self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names + ) + elif "Polaris-Dataset-Hard" in data_source: + dataframe = dataframe.map( + self.map_fn, + fn_kwargs={"data_source": "dataset/Polaris-Dataset-Hard"}, + remove_columns=dataframe.column_names, + ) + else: + dataframe = dataframe.map(self.map_fn2, num_proc=16) + dataframes.append(dataframe) + self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + def map_fn(self, row: dict, *, data_source: str = None): + if data_source == "Maxwell-Jia/AIME_2024": + problem, answer = row["Problem"], row["Answer"] + elif data_source in [ + "yentinglin/aime_2025", + "ByteDance-Seed/BeyondAIME", + "MathArena/brumo_2025", + "MathArena/hmmt_feb_2025", + ]: + problem, answer = row["problem"], row["answer"] + elif data_source == "dataset/Polaris-Dataset-Hard": + problem, answer = row["problem"], row["answer"] + + prompt = problem + answer_format + data = { + "data_source": data_source.split("/")[1].lower(), # aime_2024, aime_2025, polaris-dataset-hard + "prompt": [{"role": "user", "content": prompt}], + "ability": "MATH", + "reward_model": {"ground_truth": str(answer)}, + "agent_name": "spo_tool_agent", + } + return data + + def map_fn2(self, row: dict): + content = row["prompt"] + row["prompt"] = [{"role": "user", "content": content + answer_format}] + row["agent_name"] = "spo_tool_agent" + return row + + +def compute_score(data_source, solution_str, ground_truth, extra_info, **kwargs): + # Check format: if more than one "" tag, score should be zero + if solution_str.count("") != 1: + return {"score": 0, "acc": False, "pred": ""} + + # Check if there are or blocks after + think_end_pos = solution_str.find("") + if think_end_pos != -1: + after_think = solution_str[think_end_pos + len("") :] + if "" in after_think or "" in after_think: + return {"score": 0, "acc": False, "pred": ""} + + # use \\boxed{...} answer + result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True) + + # Modify to 0, +1 reward + if result["score"] < 0: + result["score"] = 0 + + if result["pred"] is None: + result["pred"] = "" + + return result diff --git a/ICL/DAPO/verl-recipe/spo/spo_tool_config.yaml b/ICL/DAPO/verl-recipe/spo/spo_tool_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d5b2dac00536ac6c62534731931eb63ef462918 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/spo_tool_config.yaml @@ -0,0 +1,24 @@ +tools: + - class_name: "recipe.spo.spo_retool.CustomSandboxFusionTool" + config: + sandbox_fusion_url: "http://localhost:8080/run_code" + num_workers: 128 + enable_global_rate_limit: true + rate_limit: 128 + default_timeout: 30 + default_language: "python" + memory_limit_mb: 1024 + type: native + + tool_schema: + type: "function" + function: + name: "code_interpreter" + description: "A tool for executing code." + parameters: + type: "object" + properties: + code: + type: "string" + description: "The code to execute." + required: ["code"] diff --git a/ICL/DAPO/verl-recipe/spo/train.sh b/ICL/DAPO/verl-recipe/spo/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..a9436670236b03c609ad542abfbcacd57cf32eb8 --- /dev/null +++ b/ICL/DAPO/verl-recipe/spo/train.sh @@ -0,0 +1,146 @@ +set -x + +export VLLM_USE_V1=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export VLLM_ALLREDUCE_USE_SYMM_MEM=0 + +# ================= data/model/tool ================= +OUTPUT_DIR=${OUTPUT_DIR:-"."} +TRAIN_DATA_DIR=${TRAIN_DATA_DIR:-""} +EXP_NAME=${EXP_NAME:-""} +MODEL_PATH=${MODEL_PATH:-""} +RESPONSE_LENGTH=${RESPONSE_LENGTH:-8192} +N_TRAIN=${N_TRAIN:-8} +N_VAL=${N_VAL:-16} +METHOD=${METHOD:-"GRPO"} +DEBUG=${DEBUG:-"False"} +OFFLINE_VALUES=${OFFLINE_VALUES:-""} +VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-"False"} + +train_files="['${TRAIN_DATA_DIR}/subset_0.parquet', '$TRAIN_DATA_DIR/subset_1.parquet', '$TRAIN_DATA_DIR/subset_2.parquet', '$TRAIN_DATA_DIR/subset_3.parquet', '$TRAIN_DATA_DIR/subset_4.parquet']" +val_files="['Maxwell-Jia/AIME_2024', 'yentinglin/aime_2025']" + +# tool +tool_config_path=recipe/spo/spo_tool_config.yaml + +# agent loop +agent_loop_config_path=recipe/spo/config/spo_agent.yaml +default_agent_loop=spo_tool_agent + +# wandb +project_name=spo +experiment_name=$EXP_NAME +default_local_dir=$OUTPUT_DIR/$project_name/$experiment_name/checkpoints +validation_data_dir=$OUTPUT_DIR/$project_name/$experiment_name/validation_data +rollout_data_dir=$OUTPUT_DIR/$project_name/$experiment_name/rollout_data + +# ================= algorithm ================= +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_turns=8 +max_prompt_length=2048 +max_response_length=$RESPONSE_LENGTH +actor_lr=1e-6 + +n_resp_per_prompt=$N_TRAIN +n_resp_per_prompt_val=$N_VAL +if [ "$METHOD" = "GRPO" ]; then + train_batch_size=128 + ppo_mini_batch_size=16 + val_batch_size=96 + gen_batch_size=$train_batch_size + spo_enable=False +elif [ "$METHOD" = "SPO" ]; then + train_batch_size=1024 + ppo_mini_batch_size=128 + val_batch_size=96 + n_resp_per_prompt=1 + gen_batch_size=14000 # For DAPO en subsets + spo_enable=True +else + echo "Error: METHOD must be either 'GRPO' or 'SPO' when DEBUG is not True" + exit 1 +fi + +# ================= perfomance ================= +infer_tp=4 # vllm +train_sp=8 # train +offload=True + +actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 1 )) +log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 4 )) + +TENSORBOARD_DIR=$OUTPUT_DIR/${project_name}/${experiment_name}/tensorboard \ +python3 -m recipe.spo.spo_main_ppo \ + algorithm.adv_estimator=$adv_estimator \ + algorithm.use_kl_in_reward=$use_kl_in_reward \ + algorithm.kl_ctrl.kl_coef=$kl_coef \ + data.train_files="$train_files" \ + data.val_files="$val_files" \ + data.return_raw_chat=True \ + data.train_batch_size=$train_batch_size \ + data.val_batch_size=$val_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.custom_cls.path=recipe/spo/spo_retool.py \ + data.custom_cls.name=CustomRLHFDataset \ + +data.gen_batch_size=$gen_batch_size \ + custom_reward_function.path=recipe/spo/spo_retool.py \ + custom_reward_function.name=compute_score \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ + actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ + actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ + actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.optim.lr=$actor_lr \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ + actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=async \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ + actor_rollout_ref.rollout.multi_turn.enable=True \ + actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ + actor_rollout_ref.rollout.multi_turn.tool_config_path=$tool_config_path \ + actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \ + actor_rollout_ref.rollout.agent.default_agent_loop=$default_agent_loop \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ + actor_rollout_ref.rollout.n=$n_resp_per_prompt \ + actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \ + actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \ + actor_rollout_ref.rollout.val_kwargs.top_k=20 \ + actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name=$project_name \ + trainer.experiment_name=$experiment_name \ + trainer.n_gpus_per_node=8 \ + trainer.val_before_train=$VAL_BEFORE_TRAIN \ + trainer.log_val_generations=20 \ + trainer.nnodes=1 \ + trainer.save_freq=20 \ + trainer.default_local_dir=$default_local_dir \ + trainer.validation_data_dir=$validation_data_dir \ + trainer.test_freq=10 \ + trainer.total_epochs=500 \ + trainer.spo.enable=$spo_enable \ + trainer.spo.offline_values=$OFFLINE_VALUES \ + trainer.debug=$DEBUG \ + trainer.rollout_data_dir=$rollout_data_dir diff --git a/ICL/DAPO/verl-recipe/sppo/README.md b/ICL/DAPO/verl-recipe/sppo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f87efa853b87857d7fd19de4e4159275619edec3 --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/README.md @@ -0,0 +1,50 @@ +# SPPO: Self-Play Preference Optimization for Language Model Alignment + +This repository hosts the community implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets. + +Paper Authors: [Yue Wu](https://yuewu.us/)\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) + +verl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20) + +[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)] + +## Reproduce the Experiment + +We evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework. + +``` +git clone git@github.com:volcengine/verl.git +cd verl +python3 -m uv pip install -e ".[sglang]" + +export WANDB_API_KEY= + +python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math +huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct + +export CUDA_VISIBLE_DEVICES=0,1,2,3 +bash recipe/sppo/run_qwen2.5-7b_rm.sh +``` + +Note that the installation would occasionally fail to install flash-attn. If this happens, you can install it manually by running: + +```bash +python3 -m uv pip install wheel +python3 -m uv pip install packaging +python3 -m uv pip install flash-attn --no-build-isolation --no-deps +``` + +## Acknowledgement + +We sincerely thank the contribution and guidance from: + +- [Yue Wu](https://yuewu.us/) +- [Chendong Wang](https://cdwang96.github.io/) +- [Yifan Zhang](https://github.com/yifanzhang-pro) +- [Yongan Xiang](https://github.com/BearBiscuit05) +- [Junrong Lin](https://github.com/ocss884) +- [Yuxuan Tong](https://github.com/tongyx361) +- [Guangming Shen](https://github.com/PeterSH6) +- [Biao He](https://www.linkedin.com/in/biao-he/) +- [Qingquan Song](https://qingquansong.github.io/) +- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) diff --git a/ICL/DAPO/verl-recipe/sppo/__init__.py b/ICL/DAPO/verl-recipe/sppo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc88468e3aa17ae3dd07e0492b253c60c0d71d03 --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ICL/DAPO/verl-recipe/sppo/config.py b/ICL/DAPO/verl-recipe/sppo/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6894e1d7cf234db441a500013e85d5aeb6c3cb6b --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/config.py @@ -0,0 +1,22 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from verl.workers.config import FSDPActorConfig + + +@dataclass +class SPPOActorConfig(FSDPActorConfig): + sppo_eta: float = 1.0 diff --git a/ICL/DAPO/verl-recipe/sppo/dp_actor.py b/ICL/DAPO/verl-recipe/sppo/dp_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..e1b5a2eb5f9d91a35af33a24280e5befdd1cca2a --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/dp_actor.py @@ -0,0 +1,195 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch + +import verl.utils.torch_functional as verl_F +from verl import DataProto +from verl.trainer.ppo.core_algos import agg_loss, kl_penalty +from verl.utils.device import get_device_id +from verl.utils.profiler import GPUMemoryLogger +from verl.utils.py_functional import append_to_dict +from verl.utils.seqlen_balancing import rearrange_micro_batches +from verl.workers.actor.dp_actor import DataParallelPPOActor + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def compute_sppo_loss( + old_log_prob: torch.Tensor, # (bs, seq_len) + log_prob: torch.Tensor, # (bs, seq_len) + rewards: torch.Tensor, # (bs,) + response_mask: torch.Tensor, # (bs, seq_len) + eta: float = 1.0, + loss_agg_mode: str = "token-mean", +): + """ + SPPO Loss computation. + """ + # Compute log-ratios over masked tokens + log_prob_sum = (log_prob * response_mask).sum(dim=1) # (bs,) + old_log_prob_sum = (old_log_prob * response_mask).sum(dim=1) # (bs,) + log_ratios = log_prob_sum - old_log_prob_sum # (bs,) + + scaled_rewards = eta * (rewards) + loss_vec = (log_ratios - scaled_rewards) ** 2 # (bs,) + + if loss_agg_mode == "token-mean": + sample_mask = response_mask.any(dim=1).float() # (bs,) + loss = verl_F.masked_mean(loss_vec, sample_mask) + + return loss, log_ratios, scaled_rewards + + +class DataParallelSPPOActor(DataParallelPPOActor): + @GPUMemoryLogger(role="dp actor", logger=logger) + def update_policy(self, data: DataProto): + # make sure we are in training mode + self.actor_module.train() + + temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error + multi_turn = data.meta_info.get("multi_turn", False) + + select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "seq_level_rewards"] + if multi_turn: + select_keys.append("loss_mask") + if self.config.use_kl_loss: + select_keys.append("ref_log_prob") + batch = data.select(batch_keys=select_keys).batch + has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() + + # Split to make minibatch iterator for updating the actor + # See PPO paper for details. https://arxiv.org/abs/1707.06347 + if has_multi_modal_inputs: + num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size + non_tensor_select_keys = ["multi_modal_inputs"] + dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) + else: + dataloader = batch.split(self.config.ppo_mini_batch_size) + + metrics = {} + for epoch in range(self.config.ppo_epochs): + for batch_idx, data in enumerate(dataloader): + # split batch into micro_batches + mini_batch = data + if has_multi_modal_inputs: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu + micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) + elif self.config.use_dynamic_bsz: + max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size + micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) + else: + self.gradient_accumulation = ( + self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu + ) + # split batch into micro_batches + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + + self.actor_optimizer.zero_grad() + + for data in micro_batches: + # Support all hardwares + if isinstance(data, DataProto): + data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} + else: + data = data.to(get_device_id()) # actor device is cpu when using offload + responses = data["responses"] + response_length = responses.size(1) + attention_mask = data["attention_mask"] + if multi_turn: + response_mask = data["loss_mask"][:, -response_length:] + else: + response_mask = attention_mask[:, -response_length:] + + old_log_prob = data["old_log_probs"] + rewards = data["seq_level_rewards"] + + entropy_coeff = self.config.entropy_coeff + loss_agg_mode = self.config.loss_agg_mode + eta = self.config.get("sppo_eta", 1.0) + + # all return: (bsz, response_length) + calculate_entropy = False + if entropy_coeff != 0: + calculate_entropy = True + entropy, log_prob = self._forward_micro_batch( + micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy + ) + + pg_loss, log_ratios, preference = compute_sppo_loss( + old_log_prob=old_log_prob, + log_prob=log_prob, + rewards=rewards, + response_mask=response_mask, + eta=eta, + loss_agg_mode=loss_agg_mode, + ) + + if entropy_coeff != 0: + entropy_loss = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=loss_agg_mode, + loss_scale_factor=self.config.loss_scale_factor, + ) + + # compute policy loss + policy_loss = pg_loss - entropy_loss * entropy_coeff + else: + policy_loss = pg_loss + + if self.config.use_kl_loss: + ref_log_prob = data["ref_log_prob"] + # compute kl loss + kld = kl_penalty( + logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type + ) + kl_loss = agg_loss( + loss_mat=kld, + loss_mask=response_mask, + loss_agg_mode=self.config.loss_agg_mode, + loss_scale_factor=self.config.loss_scale_factor, + ) + + policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef + metrics["actor/kl_loss"] = kl_loss.detach().item() + metrics["actor/kl_coef"] = self.config.kl_loss_coef + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) + else: + loss = policy_loss / self.gradient_accumulation + loss.backward() + + data = { + "actor/loss": loss.detach().item(), + "actor/log_ratio_mean": log_ratios.mean().detach().item(), + "actor/preference_mean": preference.mean().detach().item(), + } + append_to_dict(metrics, data) + + grad_norm = self._optimizer_step() + data = {"actor/grad_norm": grad_norm.detach().item()} + append_to_dict(metrics, data) + self.actor_optimizer.zero_grad() + return metrics diff --git a/ICL/DAPO/verl-recipe/sppo/main_sppo.py b/ICL/DAPO/verl-recipe/sppo/main_sppo.py new file mode 100644 index 0000000000000000000000000000000000000000..7f5a9e2c9ad63316364eef146299e2ed1c12d419 --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/main_sppo.py @@ -0,0 +1,166 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +import os + +import hydra +import ray +from omegaconf import OmegaConf + +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_reference_policy +from verl.utils.config import validate_config + +from .sppo_ray_trainer import RaySPPOTrainer + + +@hydra.main(config_path="config", config_name="sppo_trainer", version_base=None) +def main(config): + run_ppo(config) + + +def run_ppo(config) -> None: + # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices + # isolation, will solve in the future + os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if not ray.is_initialized(): + # this is for local ray cluster + default_runtime_env = { + "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} + } + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + runner = TaskRunner.remote() + ray.get(runner.run.remote(config)) + + +@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head +class TaskRunner: + def run(self, config): + # print initial config + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # define worker classes + if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: + assert config.critic.strategy in {"fsdp", "fsdp2"} + from verl.single_controller.ray import RayWorkerGroup + + from .sppo_worker import SPPOActorRolloutRefWorker # , CriticWorker + + actor_rollout_cls = SPPOActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.single_controller.ray import RayWorkerGroup + from verl.workers.megatron_workers import ActorRolloutRefWorker + + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + # sppo does not use critic + role_worker_mapping = { + Role.ActorRollout: ray.remote(actor_rollout_cls), + } + + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + mapping[Role.RewardModel] = global_pool_id + + # use reference model + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker) + mapping[Role.RefPolicy] = global_pool_id + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(role_worker_mapping), + use_critic=False, + ) + + # download the checkpoint from hdfs + local_path = copy_to_local(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none + + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1) + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + trainer = RaySPPOTrainer( + config=config, + tokenizer=tokenizer, + processor=processor, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn, + ) + trainer.init_workers() + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/ICL/DAPO/verl-recipe/sppo/run_qwen2.5-7b_rm.sh b/ICL/DAPO/verl-recipe/sppo/run_qwen2.5-7b_rm.sh new file mode 100644 index 0000000000000000000000000000000000000000..cc614d02511f5d3c97a90d97f7dc5d8420ff1bdc --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/run_qwen2.5-7b_rm.sh @@ -0,0 +1,56 @@ +# Discliamer: the model used in the script is only for academic purpose. +set -x + +# Data preparation scripts are available in ``examples/data_preprocess``. +# Example usage: +# +# python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math +# python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k + +gsm8k_train_path=$HOME/data/math/train.parquet +gsm8k_test_path=$HOME/data/math/test.parquet + +train_files="['$gsm8k_train_path']" +test_files="['$gsm8k_test_path']" + +# prepare model ckpt +huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct & +# huggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 & +wait + +python3 -m recipe.sppo.main_sppo \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.train_batch_size=1024 \ + data.max_prompt_length=1024 \ + data.max_response_length=512 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path="$HOME/models/Qwen2.5-7B-Instruct" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=sglang \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='sppo-sglang' \ + trainer.val_before_train=True \ + trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=1 \ + trainer.total_epochs=1000 $@ + # Note that we set lr_warmup_steps = 15 in config/sppo_trainer.yaml + # The experiment will converge to 0.656 on MATH dataset after 20 epochs \ No newline at end of file diff --git a/ICL/DAPO/verl-recipe/sppo/sppo_ray_trainer.py b/ICL/DAPO/verl-recipe/sppo/sppo_ray_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2f9aeb77c2def2d05b230df0eb0cc6135707709b --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/sppo_ray_trainer.py @@ -0,0 +1,363 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FSDP PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import uuid +from copy import deepcopy +from pprint import pprint +from typing import Optional + +import numpy as np +import ray +import torch +from torch.utils.data import Dataset, Sampler +from tqdm import tqdm + +from verl import DataProto +from verl.single_controller.ray import RayWorkerGroup +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.core_algos import agg_loss +from verl.trainer.ppo.ray_trainer import ( + AdvantageEstimator, + RayPPOTrainer, + ResourcePoolManager, + apply_kl_penalty, + compute_response_mask, +) +from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model +from verl.utils.metric import reduce_metrics +from verl.utils.profiler.performance import simple_timer +from verl.utils.tracking import ValidationGenerationsLogger + + +def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor: + """ + Compute SoftMean_β(x) = (1/β) * log( (1/n) * Σ exp(β * x_i) ) + Falls back to arithmetic mean when β=0. + """ + if beta == 0.0: + return x.mean(dim=dim, keepdim=keepdim) + + # cast beta to tensor on same device/dtype + beta_t = x.new_tensor(beta) + # numerically-stable logsumexp(β x) + lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim) + n = x.size(dim) + log_n = x.new_tensor(n).log() + + return (lse - log_n) / beta_t + + +def compute_advantage(data: DataProto, beta=1.0): + rewards = data.batch["token_level_rewards"].sum(axis=-1) # (bs, ) + s_mean = softmean(rewards, beta, keepdim=True) # (bs, ) + rewards = rewards - s_mean # (bs, ) + data.batch["seq_level_rewards"] = rewards # (bs, ) + return data + + +class RaySPPOTrainer(RayPPOTrainer): + """ + Note that this trainer runs on the driver process on a single CPU/GPU node. + """ + + # TODO: support each role have individual ray_worker_group_cls, + # i.e., support different backend of different role + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + collate_fn=None, + train_sampler: Optional[Sampler] = None, + device_name=None, + ): + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert self.hybrid_engine, "Currently, only support hybrid engine" + + if self.hybrid_engine: + assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(role_worker_mapping) + self.use_rm = need_reward_model(role_worker_mapping) + self.use_critic = False + self.ray_worker_group_cls = ray_worker_group_cls + self.validation_generations_logger = ValidationGenerationsLogger() + self.device_name = device_name if device_name else self.config.trainer.device + + # define in-reward KL control + # kl loss control currently not supported + if config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the + worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_data" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("multi_modal_data") + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) + + is_last_step = self.global_steps >= self.total_training_steps + + with simple_timer("step", timing_raw): + # generate a batch + with simple_timer("gen", timing_raw): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + with simple_timer("gen_max", timing_raw): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + + batch = batch.union(gen_baseline_output) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(batch) + batch = batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del rm_scores, gen_baseline_batch, gen_baseline_output + + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object + ) + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + with simple_timer("reward", timing_raw): + # compute reward model score + if self.use_rm and "rm_scores" not in batch.batch.keys(): + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + with simple_timer("old_log_prob", timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if self.use_reference_policy: + # compute reference log_prob + with simple_timer("ref", timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with simple_timer("values", timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with simple_timer("adv", timing_raw): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + batch.batch["seq_level_rewards"] = batch.batch["token_level_scores"] + + beta = self.config.algorithm.sppo_eta + batch = compute_advantage(batch, beta=beta) + + # update critic + if self.use_critic: + with simple_timer("update_critic", timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with simple_timer("update_actor", timing_raw): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with simple_timer("testing", timing_raw): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 + ): + with simple_timer("save_checkpoint", timing_raw): + self._save_checkpoint() + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + progress_bar.update(1) + self.global_steps += 1 diff --git a/ICL/DAPO/verl-recipe/sppo/sppo_worker.py b/ICL/DAPO/verl-recipe/sppo/sppo_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..3353159b1820055e17b3605820fea6006054fc96 --- /dev/null +++ b/ICL/DAPO/verl-recipe/sppo/sppo_worker.py @@ -0,0 +1,122 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +from omegaconf import OmegaConf, open_dict + +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager +from verl.utils.flops_counter import FlopsCounter +from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer +from verl.utils.import_utils import import_external_libs +from verl.utils.profiler import log_gpu_memory_usage +from verl.workers.fsdp_workers import ActorRolloutRefWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) + + +class SPPOActorRolloutRefWorker(ActorRolloutRefWorker): + """ + This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy + or a hybrid engine based on the config.rollout + """ + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def init_model(self): + from .dp_actor import DataParallelSPPOActor + + # This is used to import external_lib into the huggingface systems + import_external_libs(self.config.model.get("external_lib", None)) + + override_model_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) + use_remove_padding = self.config.model.get("use_remove_padding", False) + use_fused_kernels = self.config.model.get("use_fused_kernels", False) + + if self._is_actor or self._is_rollout: + # we need the model for actor and rollout + if self._is_actor: + optim_config = self.config.actor.optim + fsdp_config = self.config.actor.fsdp_config + else: + optim_config = None + fsdp_config = OmegaConf.create() + self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( + self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=fsdp_config, + optim_config=optim_config, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="actor", + ) + ) + + # get the original unwrapped module + self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module + + if self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + log_gpu_memory_usage("After offload actor model during init", logger=logger) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) + # load from checkpoint + if self._is_actor: + OmegaConf.set_struct(self.config.actor, True) + with open_dict(self.config.actor): + self.config.actor.use_remove_padding = use_remove_padding + self.config.actor.use_fused_kernels = use_fused_kernels + self.actor = DataParallelSPPOActor( + config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer + ) + + if self._is_rollout: + self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False)) + + if self._is_ref: + self.ref_module_fsdp = self._build_model_optimizer( + model_path=self.config.model.path, + fsdp_config=self.config.ref.fsdp_config, + optim_config=None, + override_model_config=override_model_config, + use_remove_padding=use_remove_padding, + use_fused_kernels=use_fused_kernels, + trust_remote_code=self.config.model.get("trust_remote_code", False), + use_liger=self.config.model.get("use_liger", False), + role="ref", + )[0] + OmegaConf.set_struct(self.config.ref, True) + with open_dict(self.config.ref): + self.config.ref.use_remove_padding = use_remove_padding + self.config.ref.use_fused_kernels = use_fused_kernels + self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) + + if self._is_actor: + self.flops_counter = FlopsCounter(self.actor_model_config) + self.checkpoint_manager = FSDPCheckpointManager( + model=self.actor_module_fsdp, + optimizer=self.actor.actor_optimizer, + lr_scheduler=self.actor_lr_scheduler, + processing_class=self.processor if self.processor is not None else self.tokenizer, + checkpoint_config=self.config.actor.checkpoint, + ) diff --git a/ICL/LV/code/SFT/__pycache__/build_icl_eval_sharegpt.cpython-313.pyc b/ICL/LV/code/SFT/__pycache__/build_icl_eval_sharegpt.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bcd25b958f6aa1b6f4569a717c1f79f576215d Binary files /dev/null and b/ICL/LV/code/SFT/__pycache__/build_icl_eval_sharegpt.cpython-313.pyc differ diff --git a/ICL/LV/code/SFT/__pycache__/config.cpython-310.pyc b/ICL/LV/code/SFT/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f85eafa43bda2d97f32a65acde651a9ab293034d Binary files /dev/null and b/ICL/LV/code/SFT/__pycache__/config.cpython-310.pyc differ diff --git a/ICL/LV/code/SFT/__pycache__/config.cpython-311.pyc b/ICL/LV/code/SFT/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddc1a760d25bc7d8d12ca2843e3fb082bc7b618d Binary files /dev/null and b/ICL/LV/code/SFT/__pycache__/config.cpython-311.pyc differ diff --git a/ICL/LV/code/__pycache__/run_attn_map_shot_sweep_qwen3vl.cpython-313.pyc b/ICL/LV/code/__pycache__/run_attn_map_shot_sweep_qwen3vl.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..100e70a8a29ab9ba284f93b1f460001841e6279f Binary files /dev/null and b/ICL/LV/code/__pycache__/run_attn_map_shot_sweep_qwen3vl.cpython-313.pyc differ diff --git a/ICL/LV/code/adapters/__init__.py b/ICL/LV/code/adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/adapters/__pycache__/__init__.cpython-313.pyc b/ICL/LV/code/adapters/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c639bbfabc6382722d31a9751bd2a3a07d11335 Binary files /dev/null and b/ICL/LV/code/adapters/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/LV/code/adapters/__pycache__/qwen3vl_adapter.cpython-313.pyc b/ICL/LV/code/adapters/__pycache__/qwen3vl_adapter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cfaf18e02e527f05186c11318325a1abff20cf2 Binary files /dev/null and b/ICL/LV/code/adapters/__pycache__/qwen3vl_adapter.cpython-313.pyc differ diff --git a/ICL/LV/code/adapters/_runners/__init__.py b/ICL/LV/code/adapters/_runners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/adapters/_runners/__pycache__/__init__.cpython-313.pyc b/ICL/LV/code/adapters/_runners/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c7dc83db3c10b5c1939e030ee19ce3be516a3b Binary files /dev/null and b/ICL/LV/code/adapters/_runners/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/LV/code/adapters/_runners/__pycache__/gemma3_infer.cpython-313.pyc b/ICL/LV/code/adapters/_runners/__pycache__/gemma3_infer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50f985ca43eba11d413b0d80613d056922d501be Binary files /dev/null and b/ICL/LV/code/adapters/_runners/__pycache__/gemma3_infer.cpython-313.pyc differ diff --git a/ICL/LV/code/adapters/_runners/__pycache__/qwen3_vl_infer.cpython-313.pyc b/ICL/LV/code/adapters/_runners/__pycache__/qwen3_vl_infer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49012a6631bcf77ccd7e439aadbff175a9ac9fb5 Binary files /dev/null and b/ICL/LV/code/adapters/_runners/__pycache__/qwen3_vl_infer.cpython-313.pyc differ diff --git a/ICL/LV/code/adapters/_runners/gemma3_infer.py b/ICL/LV/code/adapters/_runners/gemma3_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..100e9d50578ae0d844fba0a02b1d0fe6a34c8de0 --- /dev/null +++ b/ICL/LV/code/adapters/_runners/gemma3_infer.py @@ -0,0 +1,348 @@ +""" +Gemma 3 inference wrapper used by the evaluation scripts. + +This mirrors the structure we used for IDEFICS2/Qwen wrappers and adds a +compatibility layer to run prompts built with the Qwen-style list-format +([{'image': ...}, {'text': ...}, ...]). + +Key points from Gemma 3 (transformers >= 4.50.0): +- Vision inputs are normalized to 896x896 internally by the processor and + encoded into fixed "image tokens"; you provide images via chat messages. +- Use AutoProcessor.apply_chat_template with tokenize=True to get a ready + batch (no separate image argument required for Gemma 3 when using messages). + +Typical usage: + runner = Gemma3Runner('/z_data/pretrained/syxin/gemma-3-4b-it/') + text = runner.generate('/path/to/img.jpg', 'Describe this image.') + +For few-shot experiments that already build Qwen-like segments, use: + text = runner.generate_from_qwen_segs(segs) +""" + +from __future__ import annotations + +import os +from typing import Dict, List, Optional + +import torch +from PIL import Image +from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + +class Gemma3Runner: + def __init__( + self, + model_path: str, + device: str = 'cuda', + dtype: Optional[str] = 'bf16', + device_map: Optional[str] = None, # e.g., 'auto' for multi-GPU + ) -> None: + self.model_path = model_path + + # Resolve dtype + torch_dtype = None + if dtype: + d = dtype.lower() + if d in ('bf16', 'bfloat16'): + torch_dtype = torch.bfloat16 + elif d in ('fp16', 'float16'): + torch_dtype = torch.float16 + elif d in ('fp32', 'float32'): + torch_dtype = torch.float32 + + # Processor first; prefer fast image processor to avoid warnings + self.processor = AutoProcessor.from_pretrained(model_path, use_fast=True, padding_side="left") + # Some processors expose tokenizer separately; ensure left padding for generation + try: + tok = getattr(self.processor, "tokenizer", None) + if tok is not None and hasattr(tok, "padding_side"): + tok.padding_side = "left" + except Exception: + pass + + # Load model; either place fully on a single device or let accelerate shard + load_kwargs = {} + if torch_dtype is not None: + # transformers deprecates torch_dtype in favor of dtype + load_kwargs['dtype'] = torch_dtype + if device_map == 'auto' and (device.startswith('cuda') and torch.cuda.is_available()): + load_kwargs['device_map'] = 'auto' + + self.model = Gemma3ForConditionalGeneration.from_pretrained( + model_path, + **load_kwargs, + ).eval() + + # If not sharded, move to the requested device + if load_kwargs.get('device_map') is None: + try: + self.model.to(device) + except Exception: + if device.startswith('cuda') and torch.cuda.is_available(): + self.model.to('cuda:0') + + # Cache: preferred generation dtype for inputs when we create tensors + self._infer_dtype = torch_dtype or (torch.bfloat16 if device.startswith('cuda') else None) + + @staticmethod + def _load_image(path: str) -> Image.Image: + img = Image.open(path) + if img.mode != 'RGB': + img = img.convert('RGB') + return img + + def _apply_and_generate( + self, + messages: List[Dict], + max_new_tokens: int, + temperature: float, + top_p: float, + ) -> str: + # Build chat text via template and batch images explicitly (supports multi-image per sample). + msgs = list(messages) + prompt = None + try: + prompt = self.processor.apply_chat_template( + msgs, + add_generation_prompt=True, + tokenize=False, + ) + except Exception: + # Fallback: if the chat template rejects roles (e.g., no 'system' allowed), + # collapse the leading system into the first user text block. + sys_txt = '' + if msgs and isinstance(msgs[0], dict) and msgs[0].get('role') == 'system': + parts = msgs[0].get('content') or [] + for p in parts: + if isinstance(p, dict) and p.get('type') == 'text' and isinstance(p.get('text'), str): + sys_txt = (sys_txt + '\n\n' + p['text'].strip()).strip() if sys_txt else p['text'].strip() + msgs = msgs[1:] + # Prepend system text to first user turn as a text block to preserve guidance + if sys_txt and msgs: + for i, m in enumerate(msgs): + if m.get('role') == 'user': + content = m.get('content') or [] + content = [{'type': 'text', 'text': sys_txt}] + content + msgs[i] = {'role': 'user', 'content': content} + break + prompt = self.processor.apply_chat_template( + msgs, + add_generation_prompt=True, + tokenize=False, + ) + # Collect images in order; accept PIL, tensors, or file/URL strings. + images = [] + try: + for m in msgs: + for p in (m.get('content') or []): + if isinstance(p, dict) and p.get('type') == 'image': + v = p.get('image', None) + if v is None: + v = p.get('url') or p.get('path') + if isinstance(v, str): + if v.startswith('http://') or v.startswith('https://'): + images.append(v) + else: + try: + images.append(self._load_image(v)) + except Exception: + pass + elif v is not None: + images.append(v) + except Exception: + images = [] + pan_env = (os.getenv('GEMMA3_PAN_AND_SCAN') or '').strip().lower() + pan_flag = pan_env in ('1', 'true', 'yes', 'on') + proc_kwargs = {'do_pan_and_scan': True} if pan_flag else {} + inputs = self.processor( + text=prompt, + images=(images or None), + return_tensors='pt', + return_dict=True, + **proc_kwargs, + ) + # Move to model device (and preferred dtype) when model isn't sharded + try: + dev = next(self.model.parameters()).device + except StopIteration: + dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if getattr(self.model, 'device_map', None) is None and hasattr(inputs, 'to'): + # inputs is a BatchEncoding (Mapping), not a module; use .to on mapping + inputs = inputs.to(dev, dtype=self._infer_dtype) # type: ignore[attr-defined] + + input_len = inputs['input_ids'].shape[-1] + do_sample = temperature > 0.0 + with torch.inference_mode(): + out = self.model.generate( + **inputs, + do_sample=do_sample, + temperature=(temperature if do_sample else None), + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + new_tokens = out[:, input_len:] + text = self.processor.batch_decode(new_tokens, skip_special_tokens=True)[0] + return text.strip() + + def generate( + self, + image_path: str, + prompt: str, + temperature: float = 0.0, + top_p: float = 1.0, + max_new_tokens: int = 32, + ) -> str: + """Single image + prompt. Builds one user turn with [image, text] content.""" + img = self._load_image(image_path) + messages = [ + { + 'role': 'user', + 'content': [ + {'type': 'image', 'image': img}, + {'type': 'text', 'text': prompt}, + ], + }, + ] + return self._apply_and_generate(messages, max_new_tokens, temperature, top_p) + + def generate_from_qwen_segs( + self, + segs: List[Dict[str, str]], + temperature: float = 0.0, + top_p: float = 1.0, + max_new_tokens: int = 32, + ) -> str: + """Compatibility with Qwen-style list segments, mapped to real chat turns. + + Desired behavior (matches our IDEFICS2 wrapper and paper setup): + - Instruction becomes the first system message. + - Each demo becomes: user([images] + question) → assistant(answer). + - Final query becomes: user([images] + question), with no assistant turn. + + This implementation supports the three modal orders built by + modal_order.py, including cases where images appear before or after + the [REQUEST]/[RESPONSE] text, and the split text-image-text order + where [REQUEST] and [RESPONSE] are separate text segments. + """ + # Helper to parse a text segment into optional request/response parts + def _parse_rr(txt: str) -> Dict[str, Optional[str]]: + s = (txt or '').strip() + has_req = '[REQUEST]' in s + has_resp = '[RESPONSE]' in s + req_txt: Optional[str] = None + resp_txt: Optional[str] = None + if not has_req and not has_resp: + return {'has_req': False, 'has_resp': False, 'req': None, 'resp': None} + # Normalize by splitting on the tags when present + try: + after_req = s.split('[REQUEST]', 1)[1] if has_req else s + except Exception: + after_req = s + if has_resp: + parts = after_req.split('[RESPONSE]', 1) + if has_req: + req_txt = parts[0].strip() + # response can be empty (query stub) + resp_txt = (parts[1] if len(parts) > 1 else '').strip() + else: + # only request present + if has_req: + req_txt = after_req.strip() + return {'has_req': has_req, 'has_resp': has_resp, 'req': req_txt, 'resp': resp_txt} + + # Accumulators + instruction: str = '' + # Current block under construction (one demo or the final query) + # We keep exact user content order by storing interleaved items. + cur_items: List[Dict] = [] # [{'type': 'image', 'image': PIL}, {'type': 'text', 'text': req}, ...] + cur_req: Optional[str] = None + cur_resp: Optional[str] = None + blocks: List[Dict] = [] # each: {'user_items': [...], 'resp': Optional[str]} + + def _push_block(force: bool = False) -> None: + nonlocal blocks, cur_items, cur_req, cur_resp + if cur_req is None: + return + # Only push when we have a request. Response may be empty/None for the query. + if not force and cur_req is None: + return + blocks.append({'user_items': list(cur_items), 'resp': cur_resp}) + cur_items.clear(); cur_req = None; cur_resp = None + + # Iterate through segments and assemble logical blocks + for seg in segs: + if 'image' in seg: + try: + img = self._load_image(seg['image']) + except Exception: + continue + cur_items.append({'type': 'image', 'image': img}) + continue + if 'text' in seg: + parsed = _parse_rr(seg.get('text') or '') + has_req = bool(parsed['has_req']) + has_resp = bool(parsed['has_resp']) + req_txt = parsed['req'] + resp_txt = parsed['resp'] + + if not has_req and not has_resp: + # Treat as instruction (accumulate only before any blocks) + if blocks or cur_req is not None or cur_items: + # If stray instruction appears mid-stream, append to existing instruction + instruction = (instruction + '\n\n' + (seg.get('text') or '').strip()).strip() if instruction else (seg.get('text') or '').strip() + else: + instruction = (instruction + '\n\n' + (seg.get('text') or '').strip()).strip() if instruction else (seg.get('text') or '').strip() + continue + + # If a new [REQUEST] starts and we already hold a complete block, push it first + if has_req and cur_req is not None and cur_resp is not None: + _push_block() + # If a new request starts but previous block had no response + # (e.g., stray query before demos), still push it to keep alignment. + elif has_req and cur_req is not None and cur_resp is None: + _push_block() + + # Update current block fields + if has_req: + cur_req = (req_txt or '').strip() + # Insert the request text at the exact position it appears + cur_items.append({'type': 'text', 'text': cur_req}) + if has_resp: + cur_resp = (resp_txt or '').strip() + + # For the common case where both appear in one segment, we still + # delay pushing until the next [REQUEST] or end-of-input so that + # trailing images (text-image order) attach to this block. + continue + + # Push the last block (query or final demo) + _push_block(force=True) + + # Build chat messages: demos (u→a) → final user + # To avoid strict system constraints in Gemma3 chat_template, we fold + # the instruction into the first user turn instead of emitting 'system'. + if instruction.strip() and blocks: + first_items = blocks[0].get('user_items') or [] + blocks[0]['user_items'] = [{'type': 'text', 'text': instruction.strip()}] + list(first_items) + messages: List[Dict] = [] + for i, blk in enumerate(blocks): + u_content = blk.get('user_items') or [] + resp = blk.get('resp') + # user turn preserves intra-order (images/text) exactly as in segs + messages.append({'role': 'user', 'content': u_content}) + # Ensure strict role alternation for Gemma3's chat template: every demo must have an assistant turn. + if i != len(blocks) - 1: + resp_text = (resp.strip() if isinstance(resp, str) else '') + # Insert a minimal placeholder if the demo response is empty/missing + if not resp_text: + resp_text = ' ' + messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': resp_text}]}) + + # Safety: if nothing was built, fall back to a single user prompt + if not messages: + messages = [{ + 'role': 'user', + 'content': [{'type': 'text', 'text': instruction or 'Please answer the question.'}], + }] + + return self._apply_and_generate(messages, max_new_tokens, temperature, top_p) diff --git a/ICL/LV/code/adapters/_runners/idefics2_infer.py b/ICL/LV/code/adapters/_runners/idefics2_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..3e692f43330700f2059b1145bfba587713354a9b --- /dev/null +++ b/ICL/LV/code/adapters/_runners/idefics2_infer.py @@ -0,0 +1,481 @@ +import os +import warnings +from typing import List, Optional, Tuple + +import torch +from PIL import Image +from transformers import AutoProcessor +try: + from transformers import AutoModelForImageTextToText as _AutoI2T +except Exception: # fallback for older transformers + try: + from transformers import AutoModelForVision2Seq as _AutoI2T + except Exception: + _AutoI2T = None + +# Drop-in replacement of the original Qwen-VL runner, but backed by Idefics2. +# We preserve the same class name and public methods so other scripts remain unchanged. +class Idefics2Runner: + def __init__( + self, + model_path: str, + device: str = 'cuda', + dtype: Optional[str] = None, + compile: bool = False, + force_single_device: bool = True, + ) -> None: + self.model_path = model_path + # Tidy: reduce noise for known chat_template warning; template content is used via apply_chat_template + try: + warnings.filterwarnings( + 'once', + message="Chat templates should be in a 'chat_template.jinja' file", + module='transformers', + category=UserWarning, + ) + except Exception: + pass + # Allow env override of placement policy + try: + env_force = os.environ.get('IDEFICS2_FORCE_SINGLE_DEVICE') + if env_force is not None: + force_single_device = str(env_force).strip().lower() not in ('0','false','off','no') + except Exception: + pass + + # Default dtype: prefer BF16 on CUDA to reduce VRAM; allow override via arg or env. + # Supported: bf16|fp16|fp32. Set IDEFICS2_DTYPE=fp16/fp32 to override. + if dtype is None: + try: + env_dtype = (os.environ.get('IDEFICS2_DTYPE') or '').strip() + dtype = env_dtype or None + except Exception: + dtype = None + if dtype is None and str(device).startswith('cuda'): + dtype = 'bf16' + if dtype is not None and str(dtype).strip().lower() in ('none', 'auto', ''): + dtype = None + + # Map dtype string to torch dtype + torch_dtype = None + if dtype: + d = dtype.lower() + if d in ('bf16', 'bfloat16'): + torch_dtype = torch.bfloat16 + elif d in ('fp16', 'float16'): + torch_dtype = torch.float16 + elif d in ('fp32', 'float32'): + torch_dtype = torch.float32 + + # Minimal "tokenizer" shim exposing from_list_format so existing callers work unchanged. + class _ShimTokenizer: + @staticmethod + def from_list_format(segs: List[dict]): + # For Idefics2 we pass segments through and parse them inside chat() + return segs + + self.tokenizer = _ShimTokenizer() + + # Load Idefics2 processor and model + use_device_map_auto = (str(device).startswith('cuda') and not force_single_device) + self.use_device_map_auto = use_device_map_auto + self.processor = AutoProcessor.from_pretrained(model_path) + if _AutoI2T is None: + raise RuntimeError('Neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available in this transformers version.') + load_kwargs = { + 'device_map': ('auto' if use_device_map_auto else None), + } + if torch_dtype is not None: + # transformers versions vary: some prefer torch_dtype, others prefer dtype. + load_kwargs['torch_dtype'] = torch_dtype + try: + self.model = _AutoI2T.from_pretrained(model_path, **load_kwargs).eval() + except TypeError: + if torch_dtype is None: + raise + load_kwargs.pop('torch_dtype', None) + load_kwargs['dtype'] = torch_dtype + self.model = _AutoI2T.from_pretrained(model_path, **load_kwargs).eval() + + # Hotfix: some versions of HF Idefics2 use a CPU view() on patch_attention_mask + # inside the vision embeddings, which clashes with CUDA tensors for position_ids. + # To avoid the CPU/GPU indexing mismatch, we patch the embeddings.forward to + # ignore any incoming patch_attention_mask (equivalent to treating all patches + # as valid). This mirrors the behavior we want when we explicitly drop + # pixel/patch attention masks from processor outputs below. + _patched_vis_emb = False + try: + import types as _types + core = getattr(self.model, 'model', None) + vm = getattr(core, 'vision_model', None) if core is not None else None + emb = getattr(vm, 'embeddings', None) if vm is not None else None + if emb is not None and hasattr(emb.__class__, 'forward'): + _orig_forward_unbound = emb.__class__.forward + + def _patched_vis_emb_forward(self_module, pixel_values=None, patch_attention_mask=None, *args, **kwargs): + """Wrapper around HF vision embeddings to avoid CPU/GPU mask-indexing crashes. + Strategy: + 1) Try the original forward as-is. + 2) If it raises a device-mismatch/TypeError related to mask iteration/indexing, + fallback to run embeddings on CPU (move module + inputs to CPU), then move + the result back to the original device. + """ + # Remember original device of inputs and module + dev = None + try: + if isinstance(pixel_values, torch.Tensor): + dev = pixel_values.device + except Exception: + pass + # Optional: force embeddings on CPU via env to avoid per-call exceptions + if os.environ.get('IDEFICS2_EMB_CPU'): + try: + self_module.to('cpu') + px_cpu = pixel_values.cpu() if isinstance(pixel_values, torch.Tensor) else pixel_values + pam_cpu = None + if patch_attention_mask is None: + pam_cpu = None + elif isinstance(patch_attention_mask, torch.Tensor): + pam_cpu = patch_attention_mask.cpu() + elif isinstance(patch_attention_mask, (list, tuple)): + tmp = [] + for itm in patch_attention_mask: + if isinstance(itm, torch.Tensor): + tmp.append(itm.cpu()) + else: + tmp.append(itm) + pam_cpu = type(patch_attention_mask)(tmp) + else: + pam_cpu = patch_attention_mask + out = _orig_forward_unbound(self_module, pixel_values=px_cpu, patch_attention_mask=pam_cpu, *args, **kwargs) + if dev is not None: + try: + self_module.to(dev) + except Exception: + pass + if isinstance(out, torch.Tensor): + try: + out = out.to(dev) + except Exception: + pass + return out + except Exception: + # Fall through to normal path + pass + try: + return _orig_forward_unbound(self_module, pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, *args, **kwargs) + except Exception as e: + msg = str(e) + # Only trigger fallback for the known classes of failures + if ('Expected all tensors to be on the same device' in msg) or ('NoneType' in msg and 'iterable' in msg): + try: + # Move module to CPU + self_module.to('cpu') + # Prepare CPU inputs + px_cpu = pixel_values.cpu() if isinstance(pixel_values, torch.Tensor) else pixel_values + pam_cpu = None + if patch_attention_mask is None: + pam_cpu = None + elif isinstance(patch_attention_mask, torch.Tensor): + pam_cpu = patch_attention_mask.cpu() + elif isinstance(patch_attention_mask, (list, tuple)): + tmp = [] + for itm in patch_attention_mask: + if isinstance(itm, torch.Tensor): + tmp.append(itm.cpu()) + else: + tmp.append(itm) + pam_cpu = type(patch_attention_mask)(tmp) + else: + pam_cpu = patch_attention_mask + # Run on CPU + out = _orig_forward_unbound(self_module, pixel_values=px_cpu, patch_attention_mask=pam_cpu, *args, **kwargs) + # Move module back to original device if known + try: + if dev is not None: + self_module.to(dev) + except Exception: + pass + # Move output back to original device + try: + if isinstance(out, torch.Tensor) and dev is not None: + out = out.to(dev) + except Exception: + pass + return out + except Exception: + # Restore module device on best-effort + try: + if dev is not None: + self_module.to(dev) + except Exception: + pass + raise + raise + + emb.forward = _types.MethodType(_patched_vis_emb_forward, emb) + _patched_vis_emb = True + except Exception: + # Best-effort patch; continue if anything fails + pass + if os.environ.get('DEBUG_IDEFICS2'): + try: + td, vd = None, None + core = getattr(self.model, 'model', None) + if core is not None: + tm = getattr(core, 'text_model', None) + if tm is not None: + emb = tm.get_input_embeddings() + if hasattr(emb, 'weight'): + td = str(emb.weight.device) + vm = getattr(core, 'vision_model', None) + if vm is not None: + for p in vm.parameters(): + vd = str(p.device); break + print(f'[Idefics2Runner] patched_vis_emb={_patched_vis_emb} text_device={td} vision_device={vd} use_device_map_auto={self.use_device_map_auto}') + except Exception: + pass + + # Work around a known HF Idefics2 issue where torch.bucketize receives tensors on + # different devices (e.g., fractional_coords on CUDA and boundaries on CPU). We + # patch torch.bucketize to move 'boundaries' to the same device as the first arg + # when both are tensors and devices differ. + try: + _orig_bucketize = torch.bucketize + def _safe_bucketize(x, boundaries, *a, **kw): + try: + if isinstance(x, torch.Tensor) and isinstance(boundaries, torch.Tensor): + if x.device != boundaries.device: + boundaries = boundaries.to(x.device) + except Exception: + pass + return _orig_bucketize(x, boundaries, *a, **kw) + torch.bucketize = _safe_bucketize # type: ignore[attr-defined] + except Exception: + pass + + # If not using auto device map, move the whole model + if not use_device_map_auto: + dev = device or 'cuda' + try: + self.model.to(dev) + except Exception: + if dev.startswith('cuda') and torch.cuda.is_available(): + self.model.to('cuda:0') + + # Note: do not change global default device to avoid side-effects in other modules + + # Optional compile + if compile and hasattr(torch, 'compile'): + try: + self.model = torch.compile(self.model) + except Exception: + pass + + # Infer per-modality devices when using device_map='auto' + def _infer_modal_devices(): + text_device = None + vision_device = None + try: + core = getattr(self.model, 'model', None) + if core is not None: + tm = getattr(core, 'text_model', None) + if tm is not None: + emb = tm.get_input_embeddings() + if hasattr(emb, 'weight'): + text_device = emb.weight.device + vm = getattr(core, 'vision_model', None) + if vm is not None: + for p in vm.parameters(): + vision_device = p.device + break + except Exception: + pass + return text_device, vision_device + + # Attach a Qwen-like chat() to the model so upstream code can call model.chat(...) + def _chat(_tok_shim, query, history=None, temperature: float = 0.0, top_p: float = 1.0, max_new_tokens: int = 32) -> Tuple[str, list]: + """Qwen-style chat shim that converts list-format segments into + Idefics2 multi-turn messages. + + Critical behavior: demonstrations are represented as proper + user/assistant turns so that few-shot signals are preserved and + modal-order effects remain visible. The final query is appended as + the last user turn and `add_generation_prompt=True` is used to have + the model generate the answer. + """ + import re as _re + + # Helper: add text to a content list, skipping empty strings + def _append_text(lst: List[dict], s: str): + s = str(s) + if s.strip(): + lst.append({'type': 'text', 'text': s}) + + # Parse the Qwen-style list-format into Idefics2 messages + images + images: List[Image.Image] = [] + + # We accumulate blocks separated by [REQUEST] markers. Each block + # becomes one user turn (with any number/order of images and text), + # followed by an assistant turn if we observed a [RESPONSE] with + # non-empty answer. + blocks: List[Tuple[List[dict], str]] = [] # (user_content, assistant_text) + cur_user: List[dict] = [] + cur_ans: List[str] = [] + seen_request = False + mode = 'pre' # 'pre'|'request'|'response' + + def _flush_block(final: bool = False): + nonlocal cur_user, cur_ans + # Nothing to flush + if not cur_user and not any(t.strip() for t in cur_ans): + return + ans_text = ''.join(cur_ans).strip() + blocks.append((cur_user if cur_user else [{'type': 'text', 'text': ''}], ans_text)) + # Start next block unless we're finalizing + if not final: + cur_user = [] + cur_ans = [] + + for item in (query or []): + if 'image' in item and item['image']: + # Image always belongs to the current block's user content + try: + img = Image.open(item['image']).convert('RGB') + images.append(img) + except Exception: + # Keep placeholder even if image is broken to preserve ordering + pass + cur_user.append({'type': 'image'}) + continue + if 'text' not in item: + continue + s = str(item['text']) + # Split by markers and route pieces into user (request) or assistant (response) + parts = [p for p in _re.split(r'(\[REQUEST\]|\[RESPONSE\])', s) if p is not None and p != ''] + for p in parts: + if p == '[REQUEST]': + # Starting a new block; if this is the first [REQUEST] and we already + # accumulated preamble content (instruction text/images), flush it as a + # standalone user turn so it won't be merged into the first demo/query. + if seen_request: + _flush_block(final=False) + else: + if cur_user: # preamble existed + _flush_block(final=False) + seen_request = True + mode = 'request' + continue + if p == '[RESPONSE]': + mode = 'response' + continue + # Plain text + if not seen_request: + # Preamble before the first [REQUEST] goes to the first user turn + _append_text(cur_user, p) + elif mode == 'response': + cur_ans.append(p) + else: # 'request' or unexpected + _append_text(cur_user, p) + + # Flush the last block (final query or demo) + _flush_block(final=True) + + # Convert blocks into messages: demos get user+assistant; final block (last) + # is the query with only a user turn. + messages: List[dict] = [] + for i, (ucont, ans) in enumerate(blocks): + is_last = (i == len(blocks) - 1) + messages.append({'role': 'user', 'content': ucont}) + if not is_last and ans: + messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': ans}]}) + + if not messages: + # Fallback to a single empty user turn to avoid crashing HF template + messages = [{'role': 'user', 'content': [{'type': 'text', 'text': ''}]}] + + prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = self.processor(text=prompt, images=images or None, return_tensors='pt') + # Keep attention masks; patched vision embeddings will handle device pitfalls or fallback to CPU + # Place inputs carefully to avoid CPU/GPU mismatch + if self.use_device_map_auto: + tdev, vdev = _infer_modal_devices() + placed = {} + for k, v in inputs.items(): + if not hasattr(v, 'to'): + placed[k] = v; continue + try: + if k in ('input_ids', 'attention_mask', 'position_ids') and tdev is not None: + placed[k] = v.to(tdev) + elif k.startswith('pixel_') and vdev is not None: + placed[k] = v.to(vdev) + else: + placed[k] = v + except Exception: + placed[k] = v + inputs = placed + else: + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + + gen_kwargs = dict( + do_sample=(temperature is not None and float(temperature) > 0.0), + temperature=(float(temperature) if temperature else 1.0), + top_p=float(top_p) if top_p is not None else 1.0, + max_new_tokens=int(max_new_tokens) if max_new_tokens is not None else 32, + pad_token_id=getattr(getattr(self.processor, 'tokenizer', None), 'eos_token_id', None), + eos_token_id=getattr(getattr(self.processor, 'tokenizer', None), 'eos_token_id', None), + ) + try: + out = self.model.generate(**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}) + except RuntimeError as re: + msg = str(re) + # Retry once by forcing text inputs onto the text embedding device if detected + if 'index is on cpu' in msg or 'Expected all tensors to be on the same device' in msg: + if self.use_device_map_auto: + tdev, vdev = _infer_modal_devices() + rescue = {} + for k, v in inputs.items(): + if not hasattr(v, 'to'): + rescue[k] = v; continue + try: + if k in ('input_ids', 'attention_mask', 'position_ids') and tdev is not None: + rescue[k] = v.to(tdev) + elif k.startswith('pixel_') and vdev is not None: + rescue[k] = v.to(vdev) + else: + rescue[k] = v + except Exception: + rescue[k] = v + out = self.model.generate(**rescue, **{k: v for k, v in gen_kwargs.items() if v is not None}) + else: + raise + else: + raise + # Decode only newly generated tokens + try: + start = inputs['input_ids'].shape[1] + new_tokens = out[:, start:] + text = self.processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip() + except Exception: + text = self.processor.batch_decode(out, skip_special_tokens=True)[0].strip() + return text, (history if isinstance(history, list) else []) + + # Bind the shim chat onto the model instance + setattr(self.model, 'chat', _chat) + + def generate( + self, + image_path: str, + prompt: str, + temperature: float = 0.0, + top_p: float = 1.0, + max_new_tokens: int = 32, + ) -> str: + # Build a single-turn, single-image query and call the shim chat() + query = [ + {'image': image_path}, + {'text': prompt}, + ] + response, _ = self.model.chat(self.tokenizer, query=query, history=None, + temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens) + return response diff --git a/ICL/LV/code/adapters/_runners/qwen3_vl_infer.py b/ICL/LV/code/adapters/_runners/qwen3_vl_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e2530f0d8b3192d8a0c8bf43b81651742ba06e --- /dev/null +++ b/ICL/LV/code/adapters/_runners/qwen3_vl_infer.py @@ -0,0 +1,274 @@ +""" +Qwen3-VL inference wrapper used by the evaluation scripts. + +This adapts the original Qwen-VL runner to the official Qwen3-VL +Transformers interface (AutoModelForImageTextToText + AutoProcessor). + +Usage: + runner = Qwen3VLRunner('/workspace/Qwen3-VL-8B-Instruct/') + text = runner.generate('/path/to/image.jpg', 'Describe this image.') + +Notes: + - Requires transformers >= 4.57.0 + - Optional speedups: flash-attn 2 (set attn_implementation='flash_attention_2') + - For video inputs you may additionally need qwen-vl-utils and decord +""" + +from __future__ import annotations + +import os +from typing import Optional, Dict, Any, List + +import torch +from transformers import AutoModelForImageTextToText, AutoProcessor + + +class Qwen3VLRunner: + def __init__( + self, + model_path: str, + device: Optional[str] = None, + dtype: Optional[str] = None, + attn_implementation: Optional[str] = None, + device_map: Optional[str] = None, + ) -> None: + """Initialize Qwen3-VL model and processor. + + - model_path: local HF-style directory of Qwen3-VL checkpoint. + - device: 'cuda', 'cuda:0', 'cpu' or None for automatic. + - dtype: one of {'auto','bf16','fp16','fp32'}; None -> 'auto'. + - attn_implementation: e.g., 'flash_attention_2' if flash-attn is installed. + - device_map: None or 'auto'. If None, we place the whole model to the explicit device. + """ + self.model_path = model_path + + # Resolve dtype + torch_dtype = None + if dtype: + d = dtype.lower() + if d in ("bf16", "bfloat16"): + torch_dtype = torch.bfloat16 + elif d in ("fp16", "float16"): + torch_dtype = torch.float16 + elif d in ("fp32", "float32"): + torch_dtype = torch.float32 + + # Default device: prefer CUDA if available + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + kwargs: Dict[str, Any] = dict( + dtype="auto" if torch_dtype is None else torch_dtype, + ) + if attn_implementation: + kwargs["attn_implementation"] = attn_implementation + if device.startswith("cuda") and device_map == "auto": + kwargs["device_map"] = "auto" + + # Load model/processor with the official Qwen3-VL API + self.model = AutoModelForImageTextToText.from_pretrained( + model_path, + **kwargs, + ) + self.processor = AutoProcessor.from_pretrained(model_path) + + # Place the whole model on a specific device when not using accelerate's device_map + if kwargs.get("device_map") is None: + try: + self.model.to(device) + except Exception: + # Last resort: to the first CUDA device + if device.startswith("cuda") and torch.cuda.is_available(): + self.model.to("cuda:0") + + def generate( + self, + image_path: str, + prompt: str, + temperature: float = 0.0, + top_p: float = 1.0, + max_new_tokens: int = 32, + ) -> str: + """Single-image inference compatible with existing evaluators.""" + # Build Qwen3-VL chat messages: [{'type':'image',...}, {'type':'text',...}] + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": prompt}, + ], + } + ] + + # Build chat text via template and pack images explicitly to ensure vision inputs are used + prompt = self.processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + images = [] + try: + # Extract images from message content + for m in messages: + for c in (m.get('content') or []): + if isinstance(c, dict) and c.get('type') == 'image' and c.get('image'): + images.append(c['image']) + except Exception: + pass + inputs = self.processor(text=prompt, images=(images or None), return_tensors="pt") + inputs = inputs.to(self.model.device) + + do_sample = temperature > 0.0 + gen_ids = self.model.generate( + **inputs, + do_sample=do_sample, + temperature=(temperature if do_sample else None), + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + # Trim the input part from outputs to keep only generated text + trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, gen_ids)] + text: List[str] = self.processor.batch_decode( + trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + return text[0] if text else "" + + def generate_from_segments( + self, + segments: List[Dict[str, str]], + *, + temperature: float = 0.0, + top_p: float = 1.0, + max_new_tokens: int = 32, + ) -> str: + """Map Qwen-style segments to real chat turns with ICL enabled. + + Behavior: + - Instruction text (without [REQUEST]/[RESPONSE]) -> first system message. + - Each demo -> user(content preserves intra-order of images/text, using only the [REQUEST] text) then assistant([RESPONSE]). + - Final query -> user only (images + [REQUEST] text), no assistant turn. + """ + # Helper to parse request/response from a text segment + def _parse_rr(txt: str): + s = (txt or '').strip() + has_req = '[REQUEST]' in s + has_resp = '[RESPONSE]' in s + req_txt = None + resp_txt = None + if not has_req and not has_resp: + return {'has_req': False, 'has_resp': False, 'req': None, 'resp': None} + after_req = s.split('[REQUEST]', 1)[1] if has_req else s + if has_resp: + parts = after_req.split('[RESPONSE]', 1) + if has_req: + req_txt = parts[0].strip() + resp_txt = (parts[1] if len(parts) > 1 else '').strip() + else: + if has_req: + req_txt = after_req.strip() + return {'has_req': has_req, 'has_resp': has_resp, 'req': req_txt, 'resp': resp_txt} + + instruction = '' + cur_items: List[Dict[str, str]] = [] # user-side content in exact seg order + cur_req: Optional[str] = None + cur_resp: Optional[str] = None + blocks: List[Dict[str, Any]] = [] # {'user_items': [...], 'resp': Optional[str]} + + def _push_block(force: bool = False): + nonlocal blocks, cur_items, cur_req, cur_resp + if cur_req is None: + return + blocks.append({'user_items': list(cur_items), 'resp': cur_resp}) + cur_items.clear(); cur_req = None; cur_resp = None + + for seg in segments: + if not isinstance(seg, dict): + continue + if 'image' in seg and isinstance(seg['image'], str) and seg['image']: + cur_items.append({"type": "image", "image": seg['image']}) + continue + if 'text' in seg and isinstance(seg['text'], str): + parsed = _parse_rr(seg['text']) + has_req, has_resp = bool(parsed['has_req']), bool(parsed['has_resp']) + req_txt, resp_txt = parsed['req'], parsed['resp'] + + if not has_req and not has_resp: + t = seg['text'].strip() + if t: + instruction = (instruction + '\n\n' + t).strip() if instruction else t + continue + + # New request begins while an unfinished block exists -> push previous block + if has_req and cur_req is not None: + _push_block() + + if has_req: + cur_req = (req_txt or '').strip() + # Insert request text at the current position to preserve intra-order + cur_items.append({"type": "text", "text": cur_req}) + if has_resp: + cur_resp = (resp_txt or '').strip() + continue + + _push_block(force=True) + + # Build messages: system -> demos (u→a) -> final user + messages: List[Dict[str, Any]] = [] + if instruction.strip(): + messages.append({ + 'role': 'system', + 'content': [{'type': 'text', 'text': instruction.strip()}], + }) + for i, blk in enumerate(blocks): + u_content = blk.get('user_items') or [] + resp = blk.get('resp') + messages.append({'role': 'user', 'content': u_content}) + # Only demos (not the final query) include an assistant turn + if i != len(blocks) - 1: + resp_text = (resp.strip() if isinstance(resp, str) else '') + if not resp_text: + resp_text = ' ' + messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': resp_text}]}) + + if not messages: + messages = [{ + 'role': 'user', + 'content': [{'type': 'text', 'text': instruction or 'Please answer the question.'}], + }] + + # Tokenize with chat template and pack images explicitly + prompt = self.processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + # Collect images in exact order + images = [] + try: + for m in messages: + for c in (m.get('content') or []): + if isinstance(c, dict) and c.get('type') == 'image' and c.get('image'): + images.append(c['image']) + except Exception: + pass + inputs = self.processor(text=prompt, images=(images or None), return_tensors="pt") + inputs = inputs.to(self.model.device) + do_sample = temperature > 0.0 + gen_ids = self.model.generate( + **inputs, + do_sample=do_sample, + temperature=(temperature if do_sample else None), + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, gen_ids)] + text: List[str] = self.processor.batch_decode( + trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + return text[0] if text else "" diff --git a/ICL/LV/code/adapters/_runners/qwen_vl_infer.py b/ICL/LV/code/adapters/_runners/qwen_vl_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..bacb51cb105e88dfd8bb1b1878a2ccda6674c88f --- /dev/null +++ b/ICL/LV/code/adapters/_runners/qwen_vl_infer.py @@ -0,0 +1,129 @@ +import os +import re +from typing import Dict, List, Tuple, Optional + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig + +# Minimal wrapper for Qwen-VL(-Chat) inference. +class QwenVLRunner: + def __init__(self, + model_path: str, + device: str = 'cuda', + dtype: Optional[str] = None, + compile: bool = False, + force_single_device: bool = True, + ) -> None: + self.model_path = model_path + # Auto mixed precision selection + torch_dtype = None + if dtype: + if dtype.lower() in ('bf16', 'bfloat16'): + torch_dtype = torch.bfloat16 + elif dtype.lower() in ('fp16', 'float16'): + torch_dtype = torch.float16 + elif dtype.lower() in ('fp32', 'float32'): + torch_dtype = torch.float32 + # Load tokenizer/model with trust_remote_code per Qwen-VL official usage + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + # Allow env override to keep the old behavior if desired + env_force = os.environ.get('QWENVL_FORCE_SINGLE_DEVICE') + if env_force is not None: + try: + force_single_device = str(env_force).strip().lower() not in ('0','false','off','no') + except Exception: + pass + use_device_map_auto = (device.startswith('cuda') and not force_single_device) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map=('auto' if use_device_map_auto else None), + torch_dtype=torch_dtype, + trust_remote_code=True, + ).eval() + # If not using accelerate's auto device map, place the whole model on the target device + if not use_device_map_auto: + # Accept forms: 'cuda', 'cuda:0' + dev = device if device else 'cuda' + try: + self.model.to(dev) + except Exception: + # Last resort: to the first CUDA device + if dev.startswith('cuda') and torch.cuda.is_available(): + self.model.to('cuda:0') + try: + self.model.generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True) + except Exception: + pass + if compile and hasattr(torch, 'compile'): + try: + self.model = torch.compile(self.model) + except Exception: + pass + # Optional debug of placement + if os.environ.get('DEBUG_DEVICES'): + try: + devs = {n: (p.device if hasattr(p, 'device') else None) for n, p in self.model.named_parameters()} + first = next(iter(devs.values())) + print(f'[QwenVLRunner] model device example: {first}; force_single_device={force_single_device}') + except Exception: + pass + + @staticmethod + def _strip_grounding_markup(text: str) -> str: + # Remove ...... patterns used by Qwen-VL for grounding + return re.sub(r'(.*?)(?:.*?)*(?:.*?)*', r'\1', text).strip() + + def generate(self, + image_path: str, + prompt: str, + temperature: float = 0.0, + top_p: float = 1.0, + max_new_tokens: int = 32, + ) -> str: + # Qwen-VL expects a list-format input: [{'image': path}, {'text': prompt}] + query = self.tokenizer.from_list_format([ + {'image': image_path}, + {'text': prompt}, + ]) + try: + response, _ = self.model.chat( + self.tokenizer, + query=query, + history=None, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + except Exception: + # Retry with empty history (some Qwen-VL builds expect a list instead of None) + try: + response, _ = self.model.chat( + self.tokenizer, + query=query, + history=[], + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + except Exception: + # Last-resort fallback to plain generate() API (may drop images) + inputs = self.tokenizer(query, return_tensors='pt') + if hasattr(inputs, 'to'): + inputs = inputs.to(self.model.device) + gen_kwargs = dict( + do_sample=(temperature > 0.0), + temperature=temperature or 1.0, + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + # Ensure pad_token_id is set to avoid attention_mask inference issues + try: + if getattr(self.model.generation_config, 'pad_token_id', None) is None: + self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id + except Exception: + pass + out_ids = self.model.generate(**inputs, **gen_kwargs) + response = self.tokenizer.decode(out_ids[0], skip_special_tokens=True) + # Heuristic: keep the tail after the prompt + response = response.split(prompt)[-1].strip() + return self._strip_grounding_markup(response) diff --git a/ICL/LV/code/adapters/idefics2_adapter.py b/ICL/LV/code/adapters/idefics2_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..232e49801f674106d71d41e1074cbab1a64a082c --- /dev/null +++ b/ICL/LV/code/adapters/idefics2_adapter.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import List, Dict, Tuple + +# Reuse the runner we already vetted; import locally so this adapter is self-contained +from adapters._runners.idefics2_infer import Idefics2Runner + + +class Adapter: + def __init__(self, model_path: str): + self.runner = Idefics2Runner(model_path) + + def generate_from_segments(self, segs: List[Dict[str, str]], *, + temperature: float, top_p: float, max_new_tokens: int) -> str: + q = self.runner.tokenizer.from_list_format(segs) + try: + resp, _ = self.runner.model.chat( + self.runner.tokenizer, query=q, history=None, + temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, + ) + return resp + except Exception: + # Do NOT silently degrade, as it destroys modal-order effects + # Fail fast so the issue can be fixed instead of producing biased scores + raise + + +def create(model_path: str) -> Adapter: + return Adapter(model_path) diff --git a/ICL/LV/code/adapters/qwen_vl_adapter.py b/ICL/LV/code/adapters/qwen_vl_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..79d39b17b376eb915f606cc2b9e408557a2c89c0 --- /dev/null +++ b/ICL/LV/code/adapters/qwen_vl_adapter.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import List, Dict + +# Thin wrapper that relies on the QWENVL repo's runner if available +try: + from adapters._runners.qwen_vl_infer import QwenVLRunner +except Exception: # allow running outside that repo; user can add it to PYTHONPATH + QwenVLRunner = None # type: ignore + + +class Adapter: + def __init__(self, model_path: str): + if QwenVLRunner is None: + raise RuntimeError('QwenVLRunner unavailable. Ensure QWENVL-code is on PYTHONPATH or install its runner.') + self.runner = QwenVLRunner(model_path) + + def generate_from_segments(self, segs: List[Dict[str, str]], *, + temperature: float, top_p: float, max_new_tokens: int) -> str: + # Optional debug: dump segments for ICL analysis + import os, json + dump_dir = os.environ.get('DUMP_SEGS_DIR', '').strip() + dump_ok = bool(dump_dir) + dump_path = None + if dump_ok: + try: + from pathlib import Path + p = Path(dump_dir); p.mkdir(parents=True, exist_ok=True) + # Try to find the last image path as query image for naming + imgs = [s.get('image') for s in segs if isinstance(s, dict) and 'image' in s] + base = Path(imgs[-1]).name if imgs else 'segs' + dump_path = p / f'{base}.segs.json' + dump_path.write_text(json.dumps({'segs': segs}, ensure_ascii=False, indent=2), encoding='utf-8') + except Exception: + dump_ok = False + q = self.runner.tokenizer.from_list_format(segs) + try: + resp, _ = self.runner.model.chat( + self.runner.tokenizer, query=q, history=None, + temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, + ) + return resp + except Exception: + # Do NOT silently degrade (would drop demo images and mask order effects) + raise + + +def create(model_path: str) -> Adapter: + return Adapter(model_path) diff --git a/ICL/LV/code/attn map/__pycache__/select_samples_for_attn_map.cpython-313.pyc b/ICL/LV/code/attn map/__pycache__/select_samples_for_attn_map.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a1563c8d0cc8e745f3ce83ca2313fd4d8f9990 Binary files /dev/null and b/ICL/LV/code/attn map/__pycache__/select_samples_for_attn_map.cpython-313.pyc differ diff --git a/ICL/LV/code/attn map/attn map/attn map/__pycache__/run_attn_map_shot_sweep_qwen3vl.cpython-313.pyc b/ICL/LV/code/attn map/attn map/attn map/__pycache__/run_attn_map_shot_sweep_qwen3vl.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24d2bb08a8eaecfbf10c24e3dbdf8f90561f41a Binary files /dev/null and b/ICL/LV/code/attn map/attn map/attn map/__pycache__/run_attn_map_shot_sweep_qwen3vl.cpython-313.pyc differ diff --git a/ICL/LV/code/attn map/attn map/attn map/avg_output_token_attention.py b/ICL/LV/code/attn map/attn map/attn map/avg_output_token_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd9e5ff0c18fa1193f37b4653d03245c195dab2 --- /dev/null +++ b/ICL/LV/code/attn map/attn map/attn map/avg_output_token_attention.py @@ -0,0 +1,326 @@ +""" +Generate a single heatmap that averages attention from all output tokens to visual tokens. + +This is adapted from the Qwen2.5-VL attention notebook: it still averages heads/layers, +but instead of visualizing per-token, it collapses all generated tokens (until EOS) into +one map. +""" +from __future__ import annotations + +import argparse +import base64 +import os +from io import BytesIO + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import requests +import seaborn as sns +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer + +try: + from transformers.models.qwen2_vl.image_processing_qwen2_vl_fast import smart_resize +except Exception: + smart_resize = None + + +def convert_pil_image_to_base64(image: Image.Image) -> str: + buffer = BytesIO() + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode() + + +def get_qwen2_5vl_prompt_msg(image: Image.Image, instruction: str, screen_width: int, screen_height: int): + return [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "min_pixels": 3136, + "max_pixels": 12845056, + "image_url": { + "url": "data:image/png;base64," + convert_pil_image_to_base64(image) + }, + }, + {"type": "text", "text": instruction}, + ], + } + ] + + +def load_image(image_path_or_url: str) -> Image.Image: + if image_path_or_url.startswith("http://") or image_path_or_url.startswith("https://"): + response = requests.get(image_path_or_url) + response.raise_for_status() + return Image.open(BytesIO(response.content)).convert("RGB") + return Image.open(image_path_or_url).convert("RGB") + + +def aggregate_llm_attention(attn: list[torch.Tensor]) -> torch.Tensor: + """Average attention vector across layers/heads for one generated token.""" + averaged = [] + for layer in attn: + layer_attns = layer.squeeze(0) + attns_per_head = layer_attns.mean(dim=0) + vec = torch.concat( + ( + torch.tensor([0.0]), + attns_per_head[-1][1:].cpu(), + torch.tensor([0.0]), + ) + ) + averaged.append(vec / vec.sum()) + return torch.stack(averaged).mean(dim=0) + + +def heterogenous_stack(vecs: list[torch.Tensor]) -> torch.Tensor: + """Pad vectors with zeros then stack them to the same length.""" + max_length = max(v.shape[0] for v in vecs) + return torch.stack([torch.concat((v, torch.zeros(max_length - v.shape[0]))) for v in vecs]) + + +def compute_token_grid(num_tokens: int, width: int, height: int) -> tuple[int, int]: + """ + Infer a (rows, cols) grid close to image aspect ratio that matches the token count, + padding if needed. + """ + if num_tokens <= 0: + raise ValueError("num_tokens must be positive") + ratio = width / height if height else 1.0 + cols = max(1, round((num_tokens * ratio) ** 0.5)) + rows = max(1, int(np.ceil(num_tokens / cols))) + if rows * cols < num_tokens: + cols = int(np.ceil(num_tokens / rows)) + return rows, cols + + +def show_mask_on_image(img: np.ndarray, mask: np.ndarray, alpha: float = 0.3) -> tuple[np.ndarray, np.ndarray]: + # Normalize image to 0-1 + img = np.float32(img) / 255.0 + blue_to_red_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + blue_to_red_map[..., 0] = (255 * mask).astype(np.uint8) + blue_to_red_map[..., 2] = (255 * (1 - mask)).astype(np.uint8) + heatmap = np.float32(blue_to_red_map) / 255.0 + overlay = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0) + return np.uint8(255 * overlay), np.uint8(255 * heatmap) + + +def build_llm_attention_matrix(outputs: dict) -> torch.Tensor: + """Rebuild the LLM attention matrix, zeroing BOS attention like the notebook.""" + attentions = outputs.get("attentions") + if attentions is None: + raise ValueError("No attentions returned. Ensure model.set_attn_implementation('eager') and output_attentions=True.") + + aggregated_prompt_attention = [] + for layer in attentions[0]: # step 0 (prompt) + layer_attns = layer.squeeze(0) # (num_heads, seq_len_step, seq_len_prev) + attns_per_head = layer_attns.mean(dim=0) + cur = attns_per_head.cpu().clone() + cur[1:, 0] = 0.0 # zero except for the first generated token + cur[1:] = cur[1:] / cur[1:].sum(-1, keepdim=True) + aggregated_prompt_attention.append(cur) + aggregated_prompt_attention = torch.stack(aggregated_prompt_attention).mean(dim=0) + + return heterogenous_stack( + [torch.tensor([1.0])] + + list(aggregated_prompt_attention) + + list(map(aggregate_llm_attention, outputs["attentions"])) + ) + + +def average_output_to_image_attention( + llm_attn_matrix: torch.Tensor, + vision_token_start: int, + vision_token_end: int, + output_token_start: int, + output_token_end: int, +) -> torch.Tensor: + """Average attention from all output tokens onto visual tokens.""" + attn_rows = [] + for row in llm_attn_matrix[output_token_start:output_token_end]: + vec = row[vision_token_start:vision_token_end] + denom = vec.sum().clamp(min=1e-8) + attn_rows.append(vec / denom) + if not attn_rows: + raise ValueError("No output tokens found to average.") + return torch.stack(attn_rows).mean(dim=0) + + +def render_average_heatmap( + image: Image.Image, + averaged_attn: torch.Tensor, + token_img_shape: tuple[int, int], + save_path: str, + overlay_with_image: bool = True, + heatmap_mode: str = "cv2", +) -> str: + """Resize attention to the image size and save the heatmap.""" + rows, cols = token_img_shape + total = rows * cols + if averaged_attn.numel() < total: + pad = torch.zeros(total - averaged_attn.numel()) + averaged_attn = torch.cat([averaged_attn, pad]) + attn_weights = averaged_attn[:total].reshape(1, 1, rows, cols) + attn_weights = attn_weights / attn_weights.max().clamp(min=1e-8) + attn_over_image = ( + F.interpolate( + attn_weights, + size=image.size[::-1], # (H, W) + mode="nearest", + ) + .squeeze() + .cpu() + .numpy() + ) + + if heatmap_mode == "sns": + plt.figure(figsize=(10, 10)) + ax = plt.gca() + img_np = np.array(image) + img_np = np.clip(img_np.astype(np.float32) * 0.9, 0, 255).astype(np.uint8) # lightly darken base image + ax.imshow(img_np, alpha=0.9) + cmap = sns.color_palette("RdYlBu_r", n_colors=30, desat=0.9)[4:30] + sns.heatmap( + attn_over_image, + cmap=cmap, + alpha=0.7, + vmin=0, + vmax=1, + cbar=False, + ax=ax, + xticklabels=False, + yticklabels=False, + ) + plt.axis("off") + plt.savefig(save_path, bbox_inches="tight", pad_inches=0) + plt.close() + return save_path + + np_img = np.array(image)[:, :, ::-1] # RGB -> BGR for OpenCV + img_with_attn, heatmap = show_mask_on_image(np_img, attn_over_image) + final = img_with_attn if overlay_with_image else heatmap + cv2.imwrite(save_path, final) + return save_path + + +def main(): + parser = argparse.ArgumentParser(description="Average output-token attention over visual tokens.") + parser.add_argument( + "--model", + default="/workspace/Qwen3-VL-8B-Instruct", + help="HF model id or local path.", + ) + parser.add_argument( + "--image", + default="/z_data/syxin/code/runs/shot_sweep_allmetrics_qwen3-vl/shot0/captioning_bertscore/_image_cache/captioning_coco/COCO_val2014_000000000661.jpg.jpg", + help="Image path or URL.", + ) + parser.add_argument( + "--prompt", + default="解释一下这张图", + help="Prompt text.", + ) + parser.add_argument("--save", default="avg_output_token_attention.png", help="Where to save the heatmap.") + parser.add_argument("--no-overlay", action="store_true", help="Save only the heatmap, without the base image.") + parser.add_argument( + "--heatmap-mode", + choices=["cv2", "sns"], + default="sns", + help="Renderer for heatmap overlay.", + ) + args = parser.parse_args() + + device = "cuda" if torch.cuda.is_available() else "cpu" + torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32 + + processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + model = ( + AutoModelForVision2Seq.from_pretrained( + args.model, + device_map="cuda" if device == "cuda" else None, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + .eval() + ) + if device != "cuda": + model = model.to(device) + if hasattr(model, "set_attn_implementation"): + try: + model.set_attn_implementation("eager") + except Exception as exc: # pragma: no cover - defensive + print(f"Warning: failed to set attn implementation to eager: {exc}") + elif hasattr(model.config, "attn_implementation"): + model.config.attn_implementation = "eager" + + image = load_image(args.image) + patch_size = getattr(processor.image_processor, "patch_size", 28) + merge_size = getattr(processor.image_processor, "merge_size", 1) + patch_factor = patch_size[0] * merge_size if isinstance(patch_size, (list, tuple)) else patch_size * merge_size + + min_pixels = getattr(processor.image_processor, "min_pixels", None) + max_pixels = getattr(processor.image_processor, "max_pixels", None) + + if smart_resize is not None and max_pixels is not None: + resized_h, resized_w = smart_resize( + image.height, + image.width, + factor=patch_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + else: + resized_h, resized_w = image.height, image.width + messages = get_qwen2_5vl_prompt_msg(image, args.prompt, resized_w, resized_h) + text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text_input], images=[image], padding=True, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=512, + return_dict_in_generate=True, + output_attentions=True, + ) + decoded_text = processor.batch_decode( + outputs["sequences"], skip_special_tokens=False, clean_up_tokenization_spaces=False + )[0] + print(f"Generated text:\n{decoded_text}\n") + + llm_attn_matrix = build_llm_attention_matrix(outputs) + len_prompt_tokens = len(inputs.input_ids[0]) + len_all_tokens = len(outputs["sequences"][0]) + output_token_start = len_prompt_tokens + output_token_end = len_all_tokens + vision_token_start = len( + tokenizer(decoded_text.split("<|vision_start|>")[0], return_tensors="pt")["input_ids"][0] + ) + 1 + vision_token_end = len(tokenizer(decoded_text.split("<|vision_end|>")[0], return_tensors="pt")["input_ids"][0]) + + averaged_attn = average_output_to_image_attention( + llm_attn_matrix, vision_token_start, vision_token_end, output_token_start, output_token_end + ) + num_vis_tokens = vision_token_end - vision_token_start + rows, cols = compute_token_grid(num_vis_tokens, resized_w, resized_h) + token_shape = (rows, cols) + + save_path = os.path.abspath(args.save) + final_path = render_average_heatmap( + image, + averaged_attn, + token_shape, + save_path, + overlay_with_image=not args.no_overlay, + heatmap_mode=args.heatmap_mode, + ) + print(f"Saved averaged attention heatmap to: {final_path}") + + +if __name__ == "__main__": + main() diff --git a/ICL/LV/code/attn map/attn map/attn map/run_attn_map_shot_sweep_qwen3vl.py b/ICL/LV/code/attn map/attn map/attn map/run_attn_map_shot_sweep_qwen3vl.py new file mode 100644 index 0000000000000000000000000000000000000000..31ff2eaebce64f5e6e3e624e97b0ec753d8c4f72 --- /dev/null +++ b/ICL/LV/code/attn map/attn map/attn map/run_attn_map_shot_sweep_qwen3vl.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +""" +Generate Qwen3-VL attention overlays for a tiny 0-7 shot sweep (default 10 samples). + +This mirrors the evaluation prompt format (orders, [REQUEST]/[RESPONSE]) and +stores outputs per-sample: + / + _image_cache/ # dataset image cache (scoped to this run) + _feat_cache/ # demo pool features (limited by --demo-pool-limit) + samples//// + image/ + attn_maps/shot{k}.png + shot{k}/prompt.txt, messages.json, output.txt, demos.json + +Defaults follow run_all_metrics_shot_sweep_0_7_qwen3vl.sh. +""" +from __future__ import annotations + +import argparse +import json +import os +import gc +import random +import re +import shutil +import sys +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import cv2 +import numpy as np +import torch +from PIL import Image +from transformers import AutoModel, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer + +def _maybe_add_repo_root() -> None: + """Search upwards for a 'core' folder and add its parent to sys.path.""" + here = Path(__file__).resolve() + for parent in [here.parent] + list(here.parents): + core_dir = parent / "core" + if core_dir.exists() and core_dir.is_dir(): + p = str(parent) + if p not in sys.path: + sys.path.insert(0, p) + return + + +_maybe_add_repo_root() + + +from core.datasets.m3it_reader import iter_m3it_samples, load_instructions +from core.eval._modal_order import build_image_text, build_text_image, build_text_image_text +from core.eval.order_eval_core import ( + _build_prompts, + _detect_tasks, + _encode_pairs, + _extract_inputs, + _extract_uid, + _image_path_sig, + _img_sig_from_rec, + _prefer_demo_splits, + load_pool_items, +) +from token_attention_utils import save_token_attention_artifacts + + +def sanitize_name(name: str) -> str: + return re.sub(r"[^A-Za-z0-9_.-]+", "_", name)[:128] or "sample" + + +def quota(total: int, n: int) -> List[int]: + base = total // max(1, n) + rem = total % max(1, n) + return [base + (1 if i < rem else 0) for i in range(max(1, n))] + + +def segments_to_messages(segments: Sequence[Dict[str, str]]) -> List[Dict[str, Any]]: + """Match Qwen3VLRunner.generate_from_segments message construction.""" + + def _parse_rr(txt: str): + s = (txt or "").strip() + has_req = "[REQUEST]" in s + has_resp = "[RESPONSE]" in s + req_txt = None + resp_txt = None + if not has_req and not has_resp: + return {"has_req": False, "has_resp": False, "req": None, "resp": None} + after_req = s.split("[REQUEST]", 1)[1] if has_req else s + if has_resp: + parts = after_req.split("[RESPONSE]", 1) + if has_req: + req_txt = parts[0].strip() + resp_txt = (parts[1] if len(parts) > 1 else "").strip() + else: + if has_req: + req_txt = after_req.strip() + return {"has_req": has_req, "has_resp": has_resp, "req": req_txt, "resp": resp_txt} + + instruction = "" + cur_items: List[Dict[str, str]] = [] + cur_req: Optional[str] = None + cur_resp: Optional[str] = None + blocks: List[Dict[str, Any]] = [] + + def _push_block(): + nonlocal blocks, cur_items, cur_req, cur_resp + if cur_req is None: + return + blocks.append({"user_items": list(cur_items), "resp": cur_resp}) + cur_items.clear() + cur_req = None + cur_resp = None + + for seg in segments: + if not isinstance(seg, dict): + continue + if "image" in seg and isinstance(seg["image"], str) and seg["image"]: + cur_items.append({"type": "image", "image": seg["image"]}) + continue + if "text" in seg and isinstance(seg["text"], str): + parsed = _parse_rr(seg["text"]) + has_req, has_resp = bool(parsed["has_req"]), bool(parsed["has_resp"]) + req_txt, resp_txt = parsed["req"], parsed["resp"] + + if not has_req and not has_resp: + t = seg["text"].strip() + if t: + instruction = (instruction + "\n\n" + t).strip() if instruction else t + continue + + if has_req and cur_req is not None: + _push_block() + + if has_req: + cur_req = (req_txt or "").strip() + cur_items.append({"type": "text", "text": cur_req}) + if has_resp: + cur_resp = (resp_txt or "").strip() + continue + + _push_block() + + messages: List[Dict[str, Any]] = [] + if instruction.strip(): + messages.append({"role": "system", "content": [{"type": "text", "text": instruction.strip()}]}) + for i, blk in enumerate(blocks): + u_content = blk.get("user_items") or [] + resp = blk.get("resp") + messages.append({"role": "user", "content": u_content}) + if i != len(blocks) - 1: + resp_text = (resp.strip() if isinstance(resp, str) else "") or " " + messages.append({"role": "assistant", "content": [{"type": "text", "text": resp_text}]}) + if not messages: + messages = [{"role": "user", "content": [{"type": "text", "text": "Please answer the question."}]}] + return messages + + +def collect_images_from_messages(messages: Sequence[Dict[str, Any]]) -> List[Any]: + images = [] + for m in messages: + for c in (m.get("content") or []): + if isinstance(c, dict) and c.get("type") == "image" and c.get("image") is not None: + images.append(c["image"]) + return images + + +def find_vision_token_ranges(tokenizer: AutoTokenizer, input_ids: torch.Tensor) -> List[Tuple[int, int]]: + ids = input_ids[0].tolist() + start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>") + end_id = tokenizer.convert_tokens_to_ids("<|vision_end|>") + ranges: List[Tuple[int, int]] = [] + i = 0 + while i < len(ids): + if ids[i] == start_id: + try: + j = ids.index(end_id, i + 1) + except ValueError: + break + ranges.append((i + 1, j)) + i = j + 1 + else: + i += 1 + return ranges + + +def reshape_attention_map( + attn_vec: np.ndarray, image_grid_thw: Optional[torch.Tensor], target_index: int = -1, spatial_merge_size: int = 2 +) -> np.ndarray: + num_tokens = len(attn_vec) + if image_grid_thw is not None: + ig = image_grid_thw + if isinstance(ig, torch.Tensor): + ig = ig.cpu() + if hasattr(ig, "__len__") and len(ig) > 0: + idx = target_index if target_index >= 0 else len(ig) - 1 + idx = max(0, min(idx, len(ig) - 1)) + t, h, w = ig[idx].tolist() + merged_h = h // spatial_merge_size + merged_w = w // spatial_merge_size + expected = t * merged_h * merged_w + if num_tokens == expected: + attn_map = attn_vec.reshape(t, merged_h, merged_w) + attn_map = attn_map.mean(axis=0) if t > 1 else attn_map[0] + return attn_map + side = int(np.sqrt(num_tokens)) + if side * side == num_tokens: + return attn_vec.reshape(side, side) + for h in range(side, 0, -1): + if num_tokens % h == 0: + w = num_tokens // h + return attn_vec.reshape(h, w) + return attn_vec.reshape(1, -1) + + +def overlay_heatmap(image: Image.Image, attn_map: np.ndarray, save_path: Path) -> Path: + base = np.array(image) + attn_norm = attn_map - attn_map.min() + attn_norm = attn_norm / attn_norm.max() if attn_norm.max() > 0 else attn_norm + attn_resized = cv2.resize(attn_norm, (base.shape[1], base.shape[0]), interpolation=cv2.INTER_CUBIC) + heatmap = cv2.applyColorMap(np.uint8(255 * attn_resized), cv2.COLORMAP_JET) + overlay = cv2.addWeighted(base[:, :, ::-1], 0.5, heatmap, 0.5, 0) + save_path.parent.mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(save_path), overlay) + return save_path + + +def generate_with_attention( + model, + processor, + tokenizer, + messages: List[Dict[str, Any]], + query_image_path: str, + max_new_tokens: int, + save_path: Path, +) -> Dict[str, Any]: + prompt_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + images = collect_images_from_messages(messages) + inputs = processor(text=prompt_text, images=images or None, padding=True, return_tensors="pt").to(model.device) + + input_ids = inputs["input_ids"] + attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids)) + pixel_values = inputs.get("pixel_values") + image_grid_thw = inputs.get("image_grid_thw") + prompt_len = input_ids.shape[1] + + vision_ranges = find_vision_token_ranges(tokenizer, input_ids) + if not vision_ranges: + raise ValueError("Could not locate vision token range in input_ids.") + target_range = vision_ranges[-1] + + generated_ids = input_ids.clone() + attn_per_token: List[torch.Tensor] = [] + generated_tokens: List[str] = [] + + eos_token_id = tokenizer.eos_token_id + eos_set = set(eos_token_id) if isinstance(eos_token_id, list) else {eos_token_id} + + with torch.no_grad(): + for _ in range(max_new_tokens): + outputs = model( + input_ids=generated_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_attentions=True, + return_dict=True, + ) + attns = outputs.attentions + if attns is None: + raise ValueError("Model did not return attentions. Ensure eager attention is enabled.") + + last_token_attn = [] + start, end = target_range + for layer_attn in attns: + attn_slice = layer_attn[0, :, -1, start:end] + last_token_attn.append(attn_slice.float().cpu()) + avg_attn = torch.stack(last_token_attn, dim=0).mean(dim=(0, 1)) + attn_per_token.append(avg_attn) + + next_token_logits = outputs.logits[:, -1, :] + next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True) + generated_tokens.append(tokenizer.decode(next_token_id[0])) + + if next_token_id.item() in eos_set: + generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) + break + + generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) + attention_mask = torch.cat( + [attention_mask, torch.ones((1, 1), device=attention_mask.device, dtype=attention_mask.dtype)], dim=-1 + ) + + generated_text = tokenizer.decode( + generated_ids[0, prompt_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + if not attn_per_token: + raise ValueError("No attentions collected; generation ended immediately.") + + all_attn = torch.stack(attn_per_token, dim=0).mean(dim=0).numpy() + attn_map = reshape_attention_map(all_attn, image_grid_thw, target_index=len(vision_ranges) - 1) + + token_attn_sums = [float(a.sum().item()) for a in attn_per_token] + image = Image.open(query_image_path).convert("RGB") + final_path = overlay_heatmap(image, attn_map, save_path) + return { + "prompt_text": prompt_text, + "generated_text": generated_text, + "attn_path": str(final_path), + "attn_map_shape": list(attn_map.shape), + "generated_tokens": generated_tokens, + "vision_ranges": vision_ranges, + "token_attn_sums": token_attn_sums, + } + + +def select_demos( + k_shots: int, + smp, + recs: List[Dict], + imgs: List[str], + pool_emb: Optional[torch.Tensor], + retr_proc, + retr_model, + device: torch.device, + seed: int, +) -> Tuple[List[Dict], List[str]]: + if k_shots <= 0 or not recs or pool_emb is None or not isinstance(pool_emb, torch.Tensor) or pool_emb.numel() == 0: + return [], [] + q_emb = _encode_pairs(retr_proc, retr_model, [smp.image_path], [smp.text or ""], device, batch_size=1) + sim = (q_emb @ pool_emb.T).squeeze(0) + q_uid = _extract_uid(smp.raw, "") if isinstance(smp.raw, dict) else "" + q_sig = _img_sig_from_rec(smp.raw) if isinstance(smp.raw, dict) else "" + q_psig = _image_path_sig(smp.image_path) + q_txt = (smp.text or "").strip().lower() + mask = torch.ones(sim.shape[0], dtype=torch.bool, device=sim.device) + for i, r in enumerate(recs): + if q_uid and r.get("uid") == q_uid: + mask[i] = False + continue + if q_sig and r.get("img_sig") and r["img_sig"] == q_sig: + mask[i] = False + continue + if q_txt and (r.get("text_in") or "").strip().lower() == q_txt: + mask[i] = False + continue + mask &= (sim < 0.999).to(dtype=mask.dtype, device=mask.device).bool() + sim[~mask] = -1e4 + pre_k = min(max(k_shots * 50, 500), sim.numel()) + cand = [i for i in torch.topk(sim, k=pre_k).indices.tolist() if mask[i].item()] + if q_psig: + cand = [i for i in cand if _image_path_sig(imgs[i]) != q_psig] + idxs2 = cand[:k_shots] + if len(idxs2) < min(k_shots, sim.numel()): + rest = [i for i in range(sim.numel()) if mask[i].item() and i not in idxs2] + if q_psig: + rest = [i for i in rest if _image_path_sig(imgs[i]) != q_psig] + rng2 = random.Random(seed) + rng2.shuffle(rest) + idxs2.extend(rest[: max(0, k_shots - len(idxs2))]) + demos = [recs[i] for i in idxs2] + demo_imgs = [imgs[i] for i in idxs2] + return demos, demo_imgs + + + + +def cleanup_cuda(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + torch.cuda.ipc_collect() + except Exception: + pass + + +def main(): + ap = argparse.ArgumentParser(description="Qwen3-VL 0-7 shot attention map sweep (small subset).") + ap.add_argument("--adapter", default="qwen3-vl") + ap.add_argument("--model-path", default="/workspace/Qwen3-VL-8B-Instruct") + ap.add_argument("--dataset-root", default="/workspace/M3IT") + ap.add_argument("--retriever-model", default="/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc") + ap.add_argument("--retriever-device", default="cpu", choices=["cpu", "cuda"], help="Place retriever on CPU to save VRAM.") + ap.add_argument("--output-base", default="runs/attn_maps_qwen3-vl") + ap.add_argument("--orders", default="image-text") + ap.add_argument("--categories", default="captioning,vqa,classification,reasoning") + ap.add_argument("--total-samples", type=int, default=10) + ap.add_argument("--max-shots", type=int, default=7) + ap.add_argument("--shots", default="", help="Comma-separated k-shots to run (e.g., 0,1,3). If empty, run 0..max-shots.") + ap.add_argument("--split", default="val") + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--temperature", type=float, default=0.6) + ap.add_argument("--top-p", type=float, default=1.0) + ap.add_argument("--max-new-tokens", type=int, default=128) + ap.add_argument("--demo-pool-limit", type=int, default=256, help="Limit demos per task when building pools.") + ap.add_argument("--eval-pool-limit", type=int, default=400, help="Limit how many eval samples to scan per task.") + args = ap.parse_args() + + torch.manual_seed(args.seed) + random.seed(args.seed) + + root_dir = Path.cwd() + output_base = Path(args.output_base) + cache_dir = output_base / "_image_cache" + feat_cache = output_base / "_feat_cache" + cache_dir.mkdir(parents=True, exist_ok=True) + feat_cache.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"[INFO] device={device}, model={args.model_path}") + processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + model = AutoModelForVision2Seq.from_pretrained( + args.model_path, + device_map="cuda" if device.type == "cuda" else None, + torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32, + trust_remote_code=True, + ).eval() + if device.type != "cuda": + model = model.to(device) + if hasattr(model, "set_attn_implementation"): + try: + model.set_attn_implementation("eager") + except Exception as exc: + print(f"[WARN] failed to set attn implementation to eager: {exc}") + elif hasattr(model.config, "attn_implementation"): + model.config.attn_implementation = "eager" + + retr_proc = AutoProcessor.from_pretrained(args.retriever_model, trust_remote_code=True) + retr_dev = torch.device("cuda" if args.retriever_device == "cuda" and torch.cuda.is_available() else "cpu") + retr_model = AutoModel.from_pretrained(args.retriever_model, trust_remote_code=True).to(retr_dev).eval() + + orders = [o.strip().lower() for o in args.orders.split(",") if o.strip()] + if not orders: + orders = ["image-text"] + order_map = {"image-text": build_image_text, "text-image": build_text_image, "text-image-text": build_text_image_text} + categories = [c.strip().lower() for c in args.categories.split(",") if c.strip()] + + dataset_root = Path(args.dataset_root) + default_tasks = { + "captioning": _detect_tasks(dataset_root, "captioning"), + "vqa": _detect_tasks(dataset_root, "vqa"), + "classification": _detect_tasks(dataset_root, "classification"), + "reasoning": _detect_tasks(dataset_root, "reasoning"), + } + default_tasks.setdefault("vqa", ["vqa/vqav2"]) + + rng = random.Random(args.seed) + + demo_pools: Dict[str, Tuple[List[dict], List[str], Optional[torch.Tensor]]] = {} + if args.max_shots > 0: + for cat in categories: + for subdir in default_tasks.get(cat, []): + demo_prefer = _prefer_demo_splits(args.split) + try: + recs, imgs = load_pool_items( + dataset_root, subdir, cache_dir, prefer=demo_prefer, max_items=args.demo_pool_limit, category=cat + ) + except Exception as exc: + print(f"[WARN] skip demos for {subdir}: {exc}") + demo_pools[subdir] = ([], [], None) + continue + feat_path = feat_cache / f"multimodal_{subdir.replace('/','_')}_pool_{args.demo_pool_limit}.pt" + if feat_path.exists(): + pool_emb = torch.load(feat_path, map_location="cpu") + else: + texts = [r["text_in"] for r in recs] + pool_emb = _encode_pairs(retr_proc, retr_model, imgs, texts, device) + torch.save(pool_emb, feat_path) + demo_pools[subdir] = (recs, imgs, pool_emb) + + eval_plans: Dict[str, List[int]] = {} + eval_samples: Dict[str, List[Any]] = {} + per_cat = quota(args.total_samples, len(categories) or 1) + for cat, q in zip(categories, per_cat): + tasks = default_tasks.get(cat, []) + if not tasks: + continue + per_task = quota(q, len(tasks)) + for subdir, q_task in zip(tasks, per_task): + try: + pool = [ + s + for s in iter_m3it_samples( + args.dataset_root, subdir, split=args.split, cache_dir=str(cache_dir), max_samples=args.eval_pool_limit + ) + ] + except Exception: + continue + if not pool: + continue + k = min(q_task, len(pool)) + idxs = rng.sample(range(len(pool)), k=k) + eval_plans[subdir] = idxs + eval_samples[subdir] = pool + + prompts = _build_prompts() + samples_root = output_base / "samples" + records: List[Dict[str, Any]] = [] + + shot_plan: List[int] + if args.shots.strip(): + shot_plan = [] + for part in args.shots.split(","): + part = part.strip() + if not part: + continue + try: + v = int(part) + except ValueError: + continue + if v < 0: + continue + if args.max_shots is not None: + if v > args.max_shots: + continue + shot_plan.append(v) + if not shot_plan: + shot_plan = list(range(args.max_shots + 1)) + else: + shot_plan = list(range(args.max_shots + 1)) + + for order in orders: + builder = order_map.get(order) + if builder is None: + print(f"[WARN] unknown order {order}, skip") + continue + for cat in categories: + tasks = default_tasks.get(cat, []) + for subdir in tasks: + idxs = eval_plans.get(subdir, []) + if not idxs: + continue + ds_insts = load_instructions(dataset_root, subdir) + ds_text = "" + if isinstance(ds_insts, list) and ds_insts: + ds_text = "\n".join([s for s in ds_insts if isinstance(s, str) and s.strip()]) + inst = prompts.get(cat, "") + if ds_text: + inst = (inst + ("\n" + ds_text)).strip() if inst else ds_text + recs, imgs, pool_emb = demo_pools.get(subdir, ([], [], None)) + for idx in idxs: + smp = eval_samples[subdir][idx] + q_text = smp.text or "" + if cat in ("classification", "reasoning"): + inp = _extract_inputs(smp.raw) + if inp: + q_text = (q_text.rstrip() + "\n" + inp).strip() + if not q_text.strip(): + base_req = (inst or "").strip() or prompts.get(cat, "").strip() + q_text = base_req + + uid = sanitize_name(f"{_extract_uid(smp.raw, f'{idx:05d}')}") + sample_dir = samples_root / cat / subdir.replace("/", "_") / uid + image_dir = sample_dir / "image" + image_dir.mkdir(parents=True, exist_ok=True) + try: + img_name = Path(smp.image_path).name + shutil.copyfile(smp.image_path, image_dir / img_name) + except Exception: + pass + attn_dir = sample_dir / "attn_maps" + attn_dir.mkdir(parents=True, exist_ok=True) + + for kshot in shot_plan: + demos, demo_imgs = select_demos( + kshot, smp, recs, imgs, pool_emb, retr_proc, retr_model, retr_dev, args.seed + ) + segs = builder(inst, demos, demo_imgs, smp.image_path, q_text) + messages = segments_to_messages(segs) + shot_dir = sample_dir / f"shot{kshot}" + shot_dir.mkdir(parents=True, exist_ok=True) + try: + res = generate_with_attention( + model, + processor, + tokenizer, + messages, + smp.image_path, + args.max_new_tokens, + attn_dir / f"shot{kshot}.png", + ) + except Exception as exc: + (shot_dir / "error.txt").write_text(str(exc), encoding="utf-8") + cleanup_cuda() + continue + + (shot_dir / "prompt.txt").write_text(res["prompt_text"], encoding="utf-8") + (shot_dir / "messages.json").write_text(json.dumps(messages, ensure_ascii=False, indent=2), encoding="utf-8") + (shot_dir / "output.txt").write_text(res["generated_text"], encoding="utf-8") + demo_dump = [ + {"text_in": d.get("text_in", ""), "text_out": d.get("text_out", ""), "image": demo_imgs[i]} + for i, d in enumerate(demos) + ] + (shot_dir / "demos.json").write_text(json.dumps(demo_dump, ensure_ascii=False, indent=2), encoding="utf-8") + records.append( + { + "order": order, + "category": cat, + "task": subdir, + "shot": kshot, + "sample_dir": str(sample_dir), + "attn_path": res["attn_path"], + } + ) + try: + save_token_attention_artifacts(res["token_attn_sums"], res["generated_tokens"], shot_dir) + except Exception as exc: + print(f"[WARN] failed to save token attention artifacts for {shot_dir}: {exc}") + cleanup_cuda() + + summary_path = output_base / "summary.json" + summary_path.write_text(json.dumps(records, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"[DONE] Saved attention maps and prompts under {output_base}") + print(f"[INFO] Summary: {summary_path}") + + +if __name__ == "__main__": + main() diff --git a/ICL/LV/code/core/__pycache__/__init__.cpython-313.pyc b/ICL/LV/code/core/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a801fb735725f6e8507e4753c515fe17577f231 Binary files /dev/null and b/ICL/LV/code/core/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/LV/code/core/datasets/__init__.py b/ICL/LV/code/core/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/core/datasets/__pycache__/__init__.cpython-313.pyc b/ICL/LV/code/core/datasets/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd79888ca87943aac65ac43c8b5fc6e828815f7a Binary files /dev/null and b/ICL/LV/code/core/datasets/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/LV/code/core/datasets/__pycache__/m3it_reader.cpython-311.pyc b/ICL/LV/code/core/datasets/__pycache__/m3it_reader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d44cbd982584065687b1639b059eeddd65de84c2 Binary files /dev/null and b/ICL/LV/code/core/datasets/__pycache__/m3it_reader.cpython-311.pyc differ diff --git a/ICL/LV/code/core/datasets/__pycache__/m3it_reader.cpython-313.pyc b/ICL/LV/code/core/datasets/__pycache__/m3it_reader.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0165dc28f4f5bf48f53241d44407fc3da6d61e19 Binary files /dev/null and b/ICL/LV/code/core/datasets/__pycache__/m3it_reader.cpython-313.pyc differ diff --git a/ICL/LV/code/core/datasets/m3it_reader.py b/ICL/LV/code/core/datasets/m3it_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e4b15a168815a9cb1b2e731d3c6c4a0410c6b7 --- /dev/null +++ b/ICL/LV/code/core/datasets/m3it_reader.py @@ -0,0 +1,485 @@ +import base64 +import io +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterator, List, Optional, Set + +from PIL import Image +import os + + +@dataclass +class Sample: + image_path: str + text: str + answers: Optional[List[str]] + raw: Dict + + +def _ensure_dir(p: Path) -> None: + p.mkdir(parents=True, exist_ok=True) + + +def _b64_to_image_path(b64: str, cache_dir: Path, key: str) -> str: + _ensure_dir(cache_dir) + out = cache_dir / f"{key}.jpg" + if not out.exists(): + img = Image.open(io.BytesIO(base64.b64decode(b64))) + if img.mode != "RGB": + img = img.convert("RGB") + img.save(out, format="JPEG", quality=90) + return str(out) + + +def read_jsonl(path: Path) -> Iterator[Dict]: + """Read JSONL by default; if file is a JSON array/object, iterate accordingly.""" + if path.suffix.lower() == '.json': + data = json.loads(path.read_text(encoding='utf-8')) + if isinstance(data, list): + for obj in data: + if isinstance(obj, dict): + yield obj + elif isinstance(data, dict): + # some datasets wrap records under a key + for v in data.values(): + if isinstance(v, list): + for obj in v: + if isinstance(obj, dict): + yield obj + return + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + yield json.loads(line) + + +# ---- Schema alignment (authoritative mapping via dataset_inspect.schemas.json) ---- +_SCHEMA_CACHE: Optional[Dict] = None +_SCHEMA_PATH_DEFAULT = '/mnt/e/xiaobin/dataset_inspect.schemas.json' + + +def _load_schema() -> Optional[Dict]: + global _SCHEMA_CACHE + if _SCHEMA_CACHE is not None: + return _SCHEMA_CACHE + p = os.environ.get('DATASET_SCHEMA_JSON') or _SCHEMA_PATH_DEFAULT + try: + sp = Path(p) + if sp.exists(): + _SCHEMA_CACHE = json.loads(sp.read_text(encoding='utf-8')) + else: + _SCHEMA_CACHE = None + except Exception: + _SCHEMA_CACHE = None + return _SCHEMA_CACHE + + +def _schema_entry(subdir: str, split: str) -> Optional[Dict]: + obj = _load_schema() + if not isinstance(obj, dict): + return None + # Prefer exact split name; accept common aliases json/jsonl; accept 'validation' for 'val' + keys = [] + split_names = [split] + if split == 'val': + split_names = ['val', 'validation'] + for s in split_names: + for ext in ('.jsonl', '.json'): + keys.append(f'data/{subdir}/{s}{ext}') + for k in keys: + if k in obj and isinstance(obj[k], dict): + return obj[k] + return None + + +def _schema_required_fields(subdir: str, split: str) -> Set[str]: + e = _schema_entry(subdir, split) + if isinstance(e, dict): + req = e.get('required') + if isinstance(req, list) and all(isinstance(x, str) for x in req): + return set(req) + return set() + + +def _schema_answer_key_candidates(subdir: str, split: str) -> List[str]: + """Return a prioritized list of dotted keys for answers based on schema.""" + e = _schema_entry(subdir, split) + if not isinstance(e, dict): + return [] + props = e.get('properties') or {} + order = [] + # Prefer explicit label/target/output in top-level properties + for k in ( + 'label', 'target', 'targets', + 'output', 'outputs', + 'answer', 'answers', + # vqav2-like fields + 'paraphrased_answer', 'original_answer', + 'caption', 'captions', 'caption_text'): + if k in props: + order.append(k) + # meta.* common for COCO style references + meta = props.get('meta') if isinstance(props.get('meta'), dict) else None + mprops = meta.get('properties') if isinstance(meta, dict) else None + if isinstance(mprops, dict): + for k in ('targets', 'sent', 'caption', 'captions'): + if k in mprops: + order.append(f'meta.{k}') + # Deduplicate, preserve order + seen = set(); out = [] + for k in order: + if k not in seen: + seen.add(k); out.append(k) + return out + + +def _get_dotted(d: Dict, path: str): + cur = d + for p in path.split('.'): + if not isinstance(cur, dict) or p not in cur: + return None + cur = cur[p] + return cur + + +_IMAGE_KEYS = ( + "image", + "image_str", + "img_str", + "image_base64", + "image_bytes", + "img", + "image_base64_str", + # some M3IT subdirs (e.g., fm-iqa) use this key name + "base64", +) +_IMAGE_PATH_KEYS = ("image_path", "img_path", "image_file", "image_filepath", "file", "path") +# Treat these as primary per-sample input prompts when present +_TEXT_IN_KEYS = ( + "input", + "inputs", + "question", + "query", + "prompt", + "user_prompt", + "instruction", + # some subdirs use a generic "text" key for the user prompt + "text", +) +# Output/reference keys: broaden to cover caption datasets as well +_TEXT_OUT_KEYS = ( + "output", "outputs", "answer", "answers", "target", "targets", "label", + # VQA variants in M3IT (e.g., vqav2) may use these names + "paraphrased_answer", "original_answer", + # Common caption fields + "caption", "captions", "caption_text", +) + + +def _candidate_instruction_files(root: Path, subdir: str) -> List[Path]: + names = ["instructions.json", "instruction.json"] + bases: List[Path] = [] + # Typical layout: /data/ + bases.append(root / "data") + # Caller may pass the data folder directly: / + bases.append(root) + # Handle nested or duplicated root segments + bases.append(root.parent) + bases.append(root.parent / "data") + + cands: List[Path] = [] + for b in bases: + for n in names: + cands.append(b / subdir / n) + # De-duplicate while preserving order + seen = set() + uniq: List[Path] = [] + for p in cands: + if p in seen: + continue + seen.add(p) + uniq.append(p) + return uniq + + +def load_instructions(dataset_root: Path, subdir: str) -> List[str]: + # 1) Try direct instructions.json near the dataset + for p in _candidate_instruction_files(dataset_root, subdir): + if p.exists(): + try: + return json.loads(p.read_text(encoding="utf-8")) + except Exception: + return [] + + # 2) Fallback: probe a consolidated schema dump (dataset_inspect.schemas.json) + # which may contain examples for data//instructions.json + # Typical location in this repo layout: /../dataset_inspect.schemas.json + # Also probe dataset_root parents to be robust across machines. + cand_schema: List[Path] = [] + try: + # Current working directory → parent + cwdp = Path.cwd() + cand_schema.append(cwdp / 'dataset_inspect.schemas.json') + cand_schema.append(cwdp.parent / 'dataset_inspect.schemas.json') + except Exception: + pass + # Dataset root neighborhood + try: + cand_schema.append(dataset_root / 'dataset_inspect.schemas.json') + cand_schema.append(dataset_root.parent / 'dataset_inspect.schemas.json') + cand_schema.append(dataset_root.parent.parent / 'dataset_inspect.schemas.json') + except Exception: + pass + # Optional explicit env + env_path = os.environ.get('DATASET_SCHEMA_JSON') + if env_path: + cand_schema.insert(0, Path(env_path)) + + key = f"data/{subdir}/instructions.json" + seen: set[Path] = set() + for sp in cand_schema: + try: + if sp in seen: + continue + seen.add(sp) + if not sp.exists(): + continue + obj = json.loads(sp.read_text(encoding='utf-8')) + if isinstance(obj, dict) and key in obj and isinstance(obj[key], dict): + ex = obj[key].get('examples') + if isinstance(ex, list) and all(isinstance(x, str) for x in ex): + # Use unique, non-empty strings; keep short to avoid prompt bloat + out = [] + for s in ex: + s = (s or '').strip() + if s and s not in out: + out.append(s) + return out[:16] + except Exception: + continue + return [] + + +def _candidate_split_files(root: Path, subdir: str, split: str) -> List[Path]: + # Try common dataset layouts and split names. + split_names = [ + # prefer requested split + f"{split}.jsonl", f"{split}.json", + # common alternates + "val.jsonl", "val.json", "validation.jsonl", "validation.json", "valid.jsonl", "valid.json", + # test fallback + "test.jsonl", "test.json", + # some datasets provide v2 or dev variants + f"{split}_v2.jsonl", f"{split}_v2.json", + "val_v2.jsonl", "val_v2.json", "validation_v2.jsonl", "validation_v2.json", "test_v2.jsonl", "test_v2.json", + "dev.jsonl", "dev.json", + ] + bases: List[Path] = [] + # Typical: /data/ + bases.append(root / "data") + # Sometimes caller already passes the data folder: / + bases.append(root) + # Handle nested or duplicated root names, e.g., .../M3IT/M3IT + bases.append(root.parent) + bases.append(root.parent / "data") + + cands: List[Path] = [] + for b in bases: + cands.extend([b / subdir / s for s in split_names]) + # De-duplicate while keeping order + seen = set() + uniq: List[Path] = [] + for p in cands: + if p in seen: + continue + seen.add(p) + uniq.append(p) + return uniq + + +def _resolve_image_path(root: Path, subdir: str, p: str) -> Optional[str]: + """Resolve an image path that may be absolute or relative to common roots. + Returns an existing path if found, else None. + """ + if not isinstance(p, str) or not p.strip(): + return None + cand = [] + sp = Path(p) + cand.append(sp) + cand.append(root / p) + cand.append(root / "data" / p) + cand.append(root / "data" / subdir / p) + cand.append(root.parent / p) + cand.append(root.parent / "data" / p) + cand.append(root.parent / "data" / subdir / p) + # No automatic Windows path mapping; rely solely on server paths. + for c in cand: + try: + if c.exists(): + return str(c) + except Exception: + pass + return None + + +def iter_m3it_samples( + dataset_root: str, + subdir: str, + split: str, + cache_dir: str, + max_samples: Optional[int] = None, +) -> Iterator[Sample]: + root = Path(dataset_root) + cache = Path(cache_dir) + + # Build candidate split paths across common layouts + cand_files = _candidate_split_files(root, subdir, split) + file = None + for cf in cand_files: + if cf.exists(): + file = cf + break + if file is None: + raise FileNotFoundError(f"No split file for {subdir} split={split}") + + # Schema required enforcement (strict alignment) + required = _schema_required_fields(subdir, split) + + for idx, rec in enumerate(read_jsonl(file)): + # If schema declares required fields, enforce presence strictly + if required: + missing = [k for k in required if k not in rec] + if missing: + # Skip records not meeting required contract + continue + # base64 image or file path + img_b64 = None + img_path_resolved = None + # 1) direct/base64-like large strings + def _join_if_list(v): + if isinstance(v, list) and v and all(isinstance(x, str) for x in v): + return ''.join(v) + return v + for k in _IMAGE_KEYS: + if k in rec: + val = _join_if_list(rec[k]) + if isinstance(val, str): + if len(val) > 100: + img_b64 = val; break + else: + rp = _resolve_image_path(root, subdir, val) + if rp: + img_path_resolved = rp; break + if "meta" in rec and isinstance(rec["meta"], dict) and k in rec["meta"]: + val = _join_if_list(rec["meta"][k]) + if isinstance(val, str): + if len(val) > 100: + img_b64 = val; break + else: + rp = _resolve_image_path(root, subdir, val) + if rp: + img_path_resolved = rp; break + # 2) explicit path keys + if img_b64 is None and img_path_resolved is None: + for k in _IMAGE_PATH_KEYS: + if k in rec and isinstance(rec[k], str): + rp = _resolve_image_path(root, subdir, rec[k]) + if rp: + img_path_resolved = rp; break + if "meta" in rec and isinstance(rec["meta"], dict) and k in rec["meta"] and isinstance(rec["meta"][k], str): + rp = _resolve_image_path(root, subdir, rec["meta"][k]) + if rp: + img_path_resolved = rp; break + if img_b64 is None and img_path_resolved is None: + # no usable image; skip + continue + + # input text + text_in = "" + for k in _TEXT_IN_KEYS: + if k in rec and isinstance(rec[k], str): + text_in = rec[k] + break + if "meta" in rec and isinstance(rec["meta"], dict) and k in rec["meta"] and isinstance(rec["meta"][k], str): + text_in = rec["meta"][k] + break + + # references(容错:支持 str/int/float/bool 及其列表;其次尝试更多常见键) + def _to_str_list(v) -> Optional[List[str]]: + if v is None: + return None + if isinstance(v, (str, int, float, bool)): + return [str(v)] + if isinstance(v, list): + out = [] + for x in v: + if isinstance(x, (str, int, float, bool)): + out.append(str(x)) + return out if out else None + return None + answers: Optional[List[str]] = None + # Try schema-guided keys first (strict alignment) + cand = _schema_answer_key_candidates(subdir, split) + for dk in cand: + val = _get_dotted(rec, dk) + if val is None: + continue + tmp = _to_str_list(val) + if tmp: + answers = tmp; break + # Fallback: generic keys (still deterministic) + if answers is None: + for k in _TEXT_OUT_KEYS: + if k in rec: + tmp = _to_str_list(rec[k]) + if tmp: + answers = tmp; break + mv = rec.get("meta") if isinstance(rec.get("meta"), dict) else None + if isinstance(mv, dict) and k in mv: + tmp = _to_str_list(mv[k]) + if tmp: + answers = tmp; break + if answers is None: + for k in ("final_answer","ans","answer_text","gt_answer","label_text","short_answer"): + if k in rec: + tmp = _to_str_list(rec[k]) + if tmp: + answers = tmp; break + mv = rec.get("meta") if isinstance(rec.get("meta"), dict) else None + if isinstance(mv, dict) and k in mv: + tmp = _to_str_list(mv[k]) + if tmp: + answers = tmp; break + # Fallback: ITM-style true_idx + candidates -> textual answer + if answers is None: + try: + cand = rec.get('candidates') + ti = rec.get('true_idx') + if isinstance(cand, list) and isinstance(ti, int) and 0 <= ti < len(cand): + val = cand[ti] + answers = _to_str_list(val) + except Exception: + pass + + # materialize image path (decode base64 if needed) + if img_path_resolved is None: + mv = rec.get("meta") if isinstance(rec.get("meta"), dict) else {} + img_key = ( + (mv.get("img_id") if isinstance(mv.get("img_id"), (str,int)) else None) + or rec.get("Flickr30k_image_id") + or rec.get("image_id") + or rec.get("pair_id") + or rec.get("id") + or f"{idx:08d}" + ) + img_path = _b64_to_image_path(img_b64, cache / subdir.replace("/", "_"), str(img_key)) + else: + img_path = img_path_resolved + + yield Sample(image_path=img_path, text=text_in, answers=answers, raw=rec) + + if max_samples is not None and idx + 1 >= max_samples: + break diff --git a/ICL/LV/code/core/eval/__init__.py b/ICL/LV/code/core/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/core/eval/__pycache__/__init__.cpython-313.pyc b/ICL/LV/code/core/eval/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b6cf8cedc54f4fe64889a53c7ff1310a282a331 Binary files /dev/null and b/ICL/LV/code/core/eval/__pycache__/__init__.cpython-313.pyc differ diff --git a/ICL/LV/code/core/eval/__pycache__/_modal_order.cpython-313.pyc b/ICL/LV/code/core/eval/__pycache__/_modal_order.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea04acce688218facf32b43014ff28861e6026a4 Binary files /dev/null and b/ICL/LV/code/core/eval/__pycache__/_modal_order.cpython-313.pyc differ diff --git a/ICL/LV/code/core/eval/__pycache__/eval_order_reasoning_accuracy.cpython-313.pyc b/ICL/LV/code/core/eval/__pycache__/eval_order_reasoning_accuracy.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..077f3e4e6007ccaeb08c2afbc12290d05ce4c3ee Binary files /dev/null and b/ICL/LV/code/core/eval/__pycache__/eval_order_reasoning_accuracy.cpython-313.pyc differ diff --git a/ICL/LV/code/core/eval/__pycache__/order_eval_core.cpython-313.pyc b/ICL/LV/code/core/eval/__pycache__/order_eval_core.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d441be5b8ef2aa3dbc3718f7061a71e6573462c Binary files /dev/null and b/ICL/LV/code/core/eval/__pycache__/order_eval_core.cpython-313.pyc differ diff --git a/ICL/LV/code/core/eval/_modal_order.py b/ICL/LV/code/core/eval/_modal_order.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6be77f8071c920afe73df4591376e0904ccb6a --- /dev/null +++ b/ICL/LV/code/core/eval/_modal_order.py @@ -0,0 +1,94 @@ +""" +Minimal builders for configurable in-context modal order. + +Each builder returns Qwen-style list-format segments: + - image: {'image': } + - text: {'text': '...'} where text may contain [REQUEST]/[RESPONSE] + +We keep the query order fixed as image -> [REQUEST] question -> [RESPONSE] stub +and vary only the per-demo presentation order. +""" + +from typing import Dict, List, Sequence + + +def _demo_segs(order: str, image_path: str, q_text: str, a_text: str) -> List[Dict[str, str]]: + segs: List[Dict[str, str]] = [] + q = (q_text or '').strip() + a = (a_text or '').strip() + for c in order: + if c == 'I': + segs.append({'image': image_path}) + elif c == 'Q': + segs.append({'text': f'[REQUEST]\n{q}'}) + elif c == 'A': + segs.append({'text': f'[RESPONSE]\n{a}'}) + return segs + + +def build_image_text( + instruction: str, + demos: Sequence[Dict], # text_in, text_out + demo_images: Sequence[str], + query_image_path: str, + query_text: str, +) -> List[Dict[str, str]]: + segs: List[Dict[str, str]] = [] + inst = (instruction or '').strip() + if inst: + segs.append({'text': inst}) + for d, img in zip(demos, demo_images): + q = (d.get('text_in', '') or '').strip() + a = (d.get('text_out', '') or '').strip() + segs.append({'image': img}) + segs.append({'text': f'[REQUEST]\n{q}\n[RESPONSE]\n{a}'}) + segs.append({'image': query_image_path}) + qt = (query_text or '').strip() + segs.append({'text': f'[REQUEST]\n{qt}\n[RESPONSE]'}) + return segs + + +def build_text_image( + instruction: str, + demos: Sequence[Dict], # text_in, text_out + demo_images: Sequence[str], + query_image_path: str, + query_text: str, +) -> List[Dict[str, str]]: + segs: List[Dict[str, str]] = [] + inst = (instruction or '').strip() + if inst: + segs.append({'text': inst}) + for d, img in zip(demos, demo_images): + q = (d.get('text_in', '') or '').strip() + a = (d.get('text_out', '') or '').strip() + segs.append({'text': f'[REQUEST]\n{q}\n[RESPONSE]\n{a}'}) + segs.append({'image': img}) + segs.append({'image': query_image_path}) + qt = (query_text or '').strip() + segs.append({'text': f'[REQUEST]\n{qt}\n[RESPONSE]'}) + return segs + + +def build_text_image_text( + instruction: str, + demos: Sequence[Dict], # text_in, text_out + demo_images: Sequence[str], + query_image_path: str, + query_text: str, +) -> List[Dict[str, str]]: + segs: List[Dict[str, str]] = [] + inst = (instruction or '').strip() + if inst: + segs.append({'text': inst}) + for d, img in zip(demos, demo_images): + q = (d.get('text_in', '') or '').strip() + a = (d.get('text_out', '') or '').strip() + segs.append({'text': f'[REQUEST]\n{q}'}) + segs.append({'image': img}) + segs.append({'text': f'[RESPONSE]\n{a}'}) + segs.append({'image': query_image_path}) + qt = (query_text or '').strip() + segs.append({'text': f'[REQUEST]\n{qt}\n[RESPONSE]'}) + return segs + diff --git a/ICL/LV/code/core/eval/collect_all_scores.py b/ICL/LV/code/core/eval/collect_all_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4711791cbb4a9c7c292e839c29e2aa5f04f8be --- /dev/null +++ b/ICL/LV/code/core/eval/collect_all_scores.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Collect per-run metric scores (actual numbers) and compute top/bottom runs by overall score. + +Inputs: +- Either pass one or more run bases via --bases (each base should contain per-metric subfolders + with summary.json), e.g., runs/order_qwen-vl, runs/order_gemma3, ... +- Or omit --bases to auto-scan runs/order_* under CWD. + +Outputs (JSON): +- all_scores.json: a list of runs with per-metric by-order scores and per-metric averages, and overall_avg. +- top_bottom.json: top5 and bottom5 by overall_avg (ties broken by run_id). + +Notes: +- The per-metric average follows summarize_overall.py: mean over orders, ignoring nulls. +- overall_avg is the mean over metric averages, ignoring nulls. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional + + +DEFAULT_METRICS = [ + "vqa_tokenf1", + "vqa_bertscore", + "captioning_bertscore", + "captioning_cider", + "classification_accuracy", + "classification_f1", + "reasoning_accuracy", + "reasoning_ras", +] + + +def _read_json(path: Path) -> Optional[dict]: + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + + +def _mean(values: List[Optional[float]]) -> Optional[float]: + xs = [float(v) for v in values if isinstance(v, (int, float))] + return (sum(xs) / len(xs)) if xs else None + + +def collect_for_base(base: Path, metrics: List[str], *, ras_mul: float = 1.0, ras_auto_scale: bool = False) -> Optional[dict]: + if not base.is_dir(): + return None + out_metrics: Dict[str, dict] = {} + for m in metrics: + summ = _read_json(base / m / "summary.json") + if isinstance(summ, dict) and summ: + # by_order: as-is + by_order: Dict[str, Optional[float]] = {} + for k, v in summ.items(): + by_order[str(k)] = (None if v is None else float(v)) + # Optional scaling for RAS + if m == "reasoning_ras" and ras_mul != 1.0: + vals = [vv for vv in by_order.values() if isinstance(vv, (int, float))] + need_scale = True + if ras_auto_scale and vals: + try: + mx = max(float(x) for x in vals) + need_scale = (mx <= 1.05) + except Exception: + need_scale = True + if need_scale: + for kk, vv in list(by_order.items()): + if isinstance(vv, (int, float)): + by_order[kk] = float(vv) * float(ras_mul) + avg = _mean(list(by_order.values())) + else: + by_order = {} + avg = None + out_metrics[m] = { + "by_order": by_order, + "avg": (None if avg is None else float(avg)), + } + overall = _mean([v.get("avg") for v in out_metrics.values()]) + + meta = {} + ov = _read_json(base / "overall.json") + if isinstance(ov, dict): + if isinstance(ov.get("meta"), dict): + meta = ov["meta"] + + return { + "run_id": str(base), + "metrics": out_metrics, + "overall_avg": (None if overall is None else float(overall)), + "meta": meta or None, + } + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--bases", nargs="*", default=[], help="Run base dirs (e.g., runs/order_qwen-vl). Default: auto-scan runs/order_*") + ap.add_argument("--metrics", nargs="*", default=DEFAULT_METRICS, help="Metric folder names to include") + ap.add_argument("--output-dir", default="runs/summaries", help="Where to write JSON outputs") + ap.add_argument("--topk", type=int, default=5, help="How many top/bottom runs to keep") + ap.add_argument("--ras-mul", type=float, default=1.0, help="Multiply reasoning_ras scores by this factor (applied per-order and avg)") + ap.add_argument("--ras-auto-scale", action="store_true", help="Apply ras-mul only when RAS appears to be in 0..1 range") + args = ap.parse_args() + + cwd = Path.cwd() + if args.bases: + bases = [Path(b) if Path(b).is_absolute() else (cwd / b) for b in args.bases] + else: + bases = sorted((cwd / "runs").glob("order_*")) + + results: List[dict] = [] + for b in bases: + rec = collect_for_base(b, args.metrics, ras_mul=float(args.ras_mul), ras_auto_scale=bool(args.ras_auto_scale)) + if rec is not None: + results.append(rec) + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + all_path = out_dir / "all_scores.json" + all_path.write_text(json.dumps(results, ensure_ascii=False, indent=2), encoding="utf-8") + + # Rank by overall_avg (exclude None); break ties by run_id for determinism + def _key(d): + ov = d.get("overall_avg") + return (float(ov) if isinstance(ov, (int, float)) else float("nan")) + + ranked = [r for r in results if isinstance(r.get("overall_avg"), (int, float))] + ranked.sort(key=lambda d: (-d["overall_avg"], str(d.get("run_id", "")))) + topk = ranked[: max(0, args.topk)] + ranked.sort(key=lambda d: (d["overall_avg"], str(d.get("run_id", "")))) + bottomk = ranked[: max(0, args.topk)] + + tb = { + "top": topk, + "bottom": bottomk, + "k": int(args.topk), + } + tb_path = out_dir / "top_bottom.json" + tb_path.write_text(json.dumps(tb, ensure_ascii=False, indent=2), encoding="utf-8") + + print(f"[collect_all_scores] wrote {all_path}") + print(f"[collect_all_scores] wrote {tb_path}") + + +if __name__ == "__main__": + main() diff --git a/ICL/LV/code/core/eval/dump_per_sample_scores.py b/ICL/LV/code/core/eval/dump_per_sample_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..31755e21033e5a0cf51df6607bfb3ef5ae15003e --- /dev/null +++ b/ICL/LV/code/core/eval/dump_per_sample_scores.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 +""" +Dump per-sample scores for each metric under a run's output-base. + +For each metric and modal order, read cached predictions from + //_cache/__.jsonl +and write per-sample details to + /per_sample//.jsonl + +Notes: +- RAS: if roscoe is available, compute per-sample via core.metrics.roscoe_shim.evaluate_list; + otherwise fallback to BERTScore per-sample; otherwise token-F1 per-sample. +- BERTScore: requires bert-score; if unavailable, fallback to token-F1 per-sample. +- CIDEr: requires pycocoevalcap; if unavailable, writes NA for scores. +- Classification Macro-F1: not well-defined per-sample; we log pred/gold letters. + +Usage: + python -m core.eval.dump_per_sample_scores --output-base runs/order_qwen-vl +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import List, Optional + + +ORDERS = ("image-text", "text-image", "text-image-text") + + +def _read_cache_lines(base: Path, metric: str, category: str, order: str) -> List[dict]: + p = base / metric / "_cache" / f"{category}__{order}.jsonl" + if not p.exists(): + return [] + out = [] + for line in p.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + try: + out.append(json.loads(line)) + except Exception: + pass + return out + + +def _ensure_dir(p: Path) -> None: + p.parent.mkdir(parents=True, exist_ok=True) + + +def dump_vqa_tokenf1(base: Path, out_root: Path) -> None: + from core.metrics.metrics import _f1_score + metric = "vqa_tokenf1"; cat = "vqa" + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + out_lines = [] + for i, obj in enumerate(lines): + pred = obj.get('pred') or '' + ans = obj.get('answers') or [] + ref = (ans[0] if ans else '') + # Scale to 0..100 to stay consistent with summary and other per-sample metrics + score = (100.0 * float(_f1_score(pred, [ref]))) if ref else None + out_lines.append(json.dumps({ + 'i': i, + 'order': order, + 'metric': 'token_f1', + 'pred': pred, + 'ref': ref, + 'score': score, + 'meta': obj.get('meta') or None, + }, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def _bertscore_per_sample( + preds: List[str], + refs, # accepts List[str] or List[List[str]] + model_type: str, + rescale_with_baseline: bool, + batch_size: int, + num_layers: Optional[int], + lang: Optional[str], + strict: bool = False, +): + """Return per-sample BERTScore F1 (0..100) list. + If strict=True, raise on any failure instead of returning None. + + Note: bert_score.score expects refs as List[str] (one ref per cand) or + List[List[str]] (multiple refs). We normalize to List[str] by taking the + first non-empty reference per sample when nested is provided. + """ + try: + from bert_score import score as bert_score + except Exception as e: + if strict: + raise + return None + kwargs = dict(model_type=model_type, verbose=False, rescale_with_baseline=rescale_with_baseline, batch_size=batch_size) + if lang and rescale_with_baseline: + kwargs['lang'] = lang + if num_layers is not None: + kwargs['num_layers'] = int(num_layers) + # Normalize refs shape to List[str] for robustness + try: + if refs and isinstance(refs[0], (list, tuple)): + refs_norm = [(r[0] if (isinstance(r, (list, tuple)) and r) else '') for r in refs] + else: + refs_norm = [str(r) for r in refs] + except Exception: + refs_norm = [str(r) for r in refs] if isinstance(refs, list) else [] + try: + _, _, F1 = bert_score(preds, refs_norm, **kwargs) + return [100.0 * float(x) for x in F1] + except AssertionError: + # auto-disable baseline and retry + try: + kwargs['rescale_with_baseline'] = False + _, _, F1 = bert_score(preds, refs_norm, **kwargs) + return [100.0 * float(x) for x in F1] + except Exception: + if strict: + raise + return None + except Exception: + if strict: + raise + return None + + +def dump_vqa_bertscore(base: Path, out_root: Path, model_type: str, lang: Optional[str], *, strict: bool, batch_size: int, num_layers: Optional[int]) -> None: + from core.metrics.metrics import _f1_score + metric = "vqa_bertscore"; cat = "vqa" + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + preds, refs = [], [] + metas = [] + for obj in lines: + preds.append(obj.get('pred') or '') + ans = obj.get('answers') or [] + refs.append(ans[0] if ans else '') + metas.append(obj.get('meta') or None) + # Pass refs as a flat list (one reference per sample); _bertscore_per_sample + # will normalize shapes when needed. + scores = _bertscore_per_sample(preds, refs, model_type=model_type, rescale_with_baseline=False, batch_size=batch_size, num_layers=num_layers, lang=lang, strict=strict) + out_lines = [] + for i, (p, r, m) in enumerate(zip(preds, refs, metas)): + s = None + impl = 'bertscore_f1' + if scores is None: + s = (100.0 * float(_f1_score(p, [r]))) if r else None + impl = 'token_f1 (fallback)' + else: + # Treat empty reference as missing score to avoid misleading 0.0 + s = (None if (not r) else float(scores[i])) + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': impl, 'pred': p, 'ref': r, 'score': s, 'meta': m}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def dump_caption_bertscore(base: Path, out_root: Path, model_type: str, lang: Optional[str], *, strict: bool, batch_size: int, num_layers: Optional[int]) -> None: + # Same as VQA bertscore but under captioning category + from core.metrics.metrics import _f1_score + metric = "captioning_bertscore"; cat = "captioning" + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + preds, refs, metas = [], [], [] + for obj in lines: + preds.append(obj.get('pred') or '') + ans = obj.get('answers') or [] + refs.append(ans[0] if ans else '') + metas.append(obj.get('meta') or None) + scores = _bertscore_per_sample(preds, refs, model_type=model_type, rescale_with_baseline=False, batch_size=batch_size, num_layers=num_layers, lang=lang, strict=strict) + out_lines = [] + for i, (p, r, m) in enumerate(zip(preds, refs, metas)): + if scores is None: + s = (100.0 * float(_f1_score(p, [r]))) if r else None + impl = 'token_f1 (fallback)' + else: + s = (None if (not r) else float(scores[i])) + impl = 'bertscore_f1' + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': impl, 'pred': p, 'ref': r, 'score': s, 'meta': m}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def dump_caption_cider(base: Path, out_root: Path) -> None: + metric = "captioning_cider"; cat = "captioning" + try: + from pycocoevalcap.cider.cider import Cider # type: ignore + except Exception: + Cider = None # type: ignore + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + preds, refs, metas = [], [], [] + for obj in lines: + preds.append(obj.get('pred') or '') + ans = obj.get('answers') or [] + refs.append(ans[0] if ans else '') + metas.append(obj.get('meta') or None) + if Cider is None or not preds: + out_lines = [json.dumps({'i': i, 'order': order, 'metric': 'cider', 'pred': p, 'ref': r, 'score': None, 'meta': m}, ensure_ascii=False) for i, (p, r, m) in enumerate(zip(preds, refs, metas))] + out_path.write_text("\n".join(out_lines), encoding='utf-8') + continue + gts = {i: [refs[i]] for i in range(len(refs))} + res = {i: [preds[i]] for i in range(len(preds))} + scorer = Cider() + score, scores = scorer.compute_score(gts, res) + out_lines = [] + for i, (p, r, m) in enumerate(zip(preds, refs, metas)): + s = 100.0 * float(scores[i]) if scores is not None else None + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': 'cider', 'pred': p, 'ref': r, 'score': s, 'meta': m}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def dump_classification_accuracy(base: Path, out_root: Path) -> None: + from core.metrics.metrics import parse_choice_letter + import re as _re + def _norm(s: str) -> str: + s = (s or '').lower().strip() + s = _re.sub(r"[\.,!?;:\-\(\)\[\]\{\}\'\"/\\,。?!:;()]", " ", s) + s = _re.sub(r"\s+", " ", s).strip() + return s + def _parse_opts(s: str) -> dict: + s = s or '' + opts = {} + for m in _re.finditer(r"[\((]([A-Za-z])[\))]\s*([^\n]+)", s): + opts[m.group(1).lower()] = _norm(m.group(2)) + for m in _re.finditer(r"(?m)^\s*([A-Za-z])[\).:]\s*([^\n]+)", s): + opts.setdefault(m.group(1).lower(), _norm(m.group(2))) + return opts + metric = 'classification_accuracy'; cat = 'classification' + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + out_lines = [] + for i, obj in enumerate(lines): + pred_raw = obj.get('pred') or '' + ans = obj.get('answers') or [] + ref_raw = (ans[0] if ans else '') + meta = obj.get('meta') or {} + pr = _norm(pred_raw) + gr = _norm(ref_raw) + opts = _parse_opts(meta.get('inputs') or '') + gold_letter = None + gl = meta.get('gold_choice'); gold_letter = (gl.strip().lower() if isinstance(gl, str) and gl.strip() else None) + if not gold_letter and gr and opts: + for k, v in opts.items(): + if v and (v == gr or v in gr or gr in v): + gold_letter = k; break + pred_letter = parse_choice_letter(pred_raw) or '' + correct = None + if gold_letter: + correct = bool(pred_letter and pred_letter == gold_letter) + elif gr: + correct = bool(pr == gr) + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': 'accuracy', 'pred': pred_raw, 'ref': ref_raw, 'correct': correct, 'pred_letter': pred_letter or None, 'gold_letter': gold_letter, 'meta': meta or None}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def dump_classification_f1(base: Path, out_root: Path) -> None: + from core.metrics.metrics import parse_choice_letter + metric = 'classification_f1'; cat = 'classification' + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + out_lines = [] + for i, obj in enumerate(lines): + pred_raw = obj.get('pred') or '' + ans = obj.get('answers') or [] + ref_raw = (ans[0] if ans else '') + meta = obj.get('meta') or {} + gl = meta.get('gold_choice'); gold_letter = (gl.strip().lower() if isinstance(gl, str) and gl.strip() else None) + if not gold_letter and isinstance(ref_raw, str): + gold_letter = parse_choice_letter(ref_raw) + pred_letter = parse_choice_letter(pred_raw) + correct = bool(pred_letter and gold_letter and pred_letter == gold_letter) + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': 'macro_f1', 'pred': pred_raw, 'ref': ref_raw, 'pred_letter': pred_letter, 'gold_letter': gold_letter, 'correct': correct, 'meta': meta or None}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def dump_reasoning_accuracy(base: Path, out_root: Path) -> None: + from core.metrics.metrics import exact_match, parse_choice_letter + import re as _re + def _norm(s: str) -> str: + s = (s or '').lower().strip() + s = _re.sub(r"[\.,!?;:\-\(\)\[\]\{\}\'\"/\\,。?!:;()]", " ", s) + s = _re.sub(r"\s+", " ", s).strip() + return s + def _map_yesno(pn: str) -> str | None: + if any(k in pn for k in ('yes','yeah','yep','true','是')): + return 'yes' + if any(k in pn for k in ('no',"don't",'not','false','否')): + return 'no' + return None + def _parse_opts(s: str) -> dict: + import re as _re2 + s = s or '' + opts = {} + for m in _re2.finditer(r"[\((]([A-Za-z])[\))]\s*([^\n]+)", s): + opts[m.group(1).lower()] = _norm(m.group(2)) + for m in _re2.finditer(r"(?m)^\s*([A-Za-z])[\).:]\s*([^\n]+)", s): + opts.setdefault(m.group(1).lower(), _norm(m.group(2))) + return opts + metric = 'reasoning_accuracy'; cat = 'reasoning' + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + out_lines = [] + for i, obj in enumerate(lines): + pred_raw = obj.get('pred') or '' + ans = obj.get('answers') or [] + ref_raw = (ans[0] if ans else '') + meta = obj.get('meta') or {} + pred_n = _norm(pred_raw) + ref_n = _norm(ref_raw) + opts = _parse_opts(meta.get('inputs') or '') + gold_letter = None + gl = meta.get('gold_choice') + if isinstance(gl, str) and gl: + gold_letter = gl.strip().lower() + if not gold_letter and ref_n and opts: + for k, v in opts.items(): + if v == ref_n or (v and ref_n and (v in ref_n or ref_n in v)): + gold_letter = k; break + if not gold_letter and len(ref_n) == 1 and ref_n in 'abcdefghijklmnopqrstuvwxyz': + gold_letter = ref_n + pred_letter = parse_choice_letter(pred_raw) or '' + pred_label = '' + if not pred_letter and opts: + for k, v in opts.items(): + if v and pred_n and (v == pred_n or v in pred_n or pred_n in v): + pred_label = v; pred_letter = k; break + correct = None + if gold_letter: + correct = bool(pred_letter and pred_letter == gold_letter) + elif ref_n: + correct = bool(pred_n == ref_n or _map_yesno(pred_n) == ref_n) + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': 'accuracy', 'pred': pred_raw, 'ref': ref_raw, 'pred_letter': pred_letter or None, 'gold_letter': gold_letter, 'correct': correct, 'meta': meta or None}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def dump_reasoning_ras(base: Path, out_root: Path, bertscore_model: str, lang: Optional[str], roscoe_model_path: Optional[str], scale_100: bool = True, auto_scale: bool = True, *, strict_bertscore: bool = False, bertscore_batch_size: int = 32, bertscore_num_layers: Optional[int] = None) -> None: + from core.metrics.metrics import _f1_score + metric = 'reasoning_ras'; cat = 'reasoning' + for order in ORDERS: + lines = _read_cache_lines(base, metric, cat, order) + out_path = out_root / metric / f"{order}.jsonl" + _ensure_dir(out_path) + preds, refs, metas = [], [], [] + for obj in lines: + preds.append(obj.get('pred') or '') + ref_text = obj.get('ref_text') or '' + if not ref_text: + ans = obj.get('answers') or [] + ref_text = ans[0] if ans else '' + refs.append(ref_text) + metas.append(obj.get('meta') or None) + + used = 'ras' + per = None + # Try roscoe per-sample first + try: + from core.metrics.roscoe_shim import evaluate_list as roscoe_eval_list + per = roscoe_eval_list(preds, refs, model_path=roscoe_model_path) + if per is not None: + # optional scale + mx = max(per) if per else 0.0 + need_scale = (mx <= 1.05) if auto_scale else True + if scale_100 and need_scale: + per = [float(x) * 100.0 for x in per] + except Exception: + per = None + + # Fallback to bertscore + if per is None: + per = _bertscore_per_sample(preds, [[r] for r in refs], model_type=bertscore_model, rescale_with_baseline=False, batch_size=bertscore_batch_size, num_layers=bertscore_num_layers, lang=lang, strict=strict_bertscore) + used = 'bertscore_f1' + # Last resort token-f1 + if per is None: + used = 'token_f1 (fallback)' + per = [100.0 * float(_f1_score(p, [r])) if r else None for p, r in zip(preds, refs)] + + out_lines = [] + for i, (p, r, m) in enumerate(zip(preds, refs, metas)): + s = None if per is None else (None if per[i] is None else float(per[i])) + out_lines.append(json.dumps({'i': i, 'order': order, 'metric': used, 'pred': p, 'ref': r, 'score': s, 'meta': m}, ensure_ascii=False)) + out_path.write_text("\n".join(out_lines), encoding='utf-8') + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument('--output-base', required=True) + ap.add_argument('--bertscore-model', default='roberta-large') + ap.add_argument('--bertscore-lang', default='') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-num-layers', type=int, default=-1) + ap.add_argument('--strict-bertscore', action='store_true') + ap.add_argument('--roscoe-model-path', default='') + ap.add_argument('--ras-scale-100', action='store_true', help='Multiply RAS by 100 (auto-scale to 0..100 when values look like 0..1)') + args = ap.parse_args() + + base = Path(args.output_base) + out_root = base / 'per_sample' + dump_vqa_tokenf1(base, out_root) + nl = None if args.bertscore_num_layers < 0 else int(args.bertscore_num_layers) + dump_vqa_bertscore(base, out_root, model_type=args.bertscore_model, lang=(args.bertscore_lang or None), strict=bool(args.strict_bertscore), batch_size=int(args.bertscore_batch_size), num_layers=nl) + dump_caption_bertscore(base, out_root, model_type=args.bertscore_model, lang=(args.bertscore_lang or None), strict=bool(args.strict_bertscore), batch_size=int(args.bertscore_batch_size), num_layers=nl) + dump_caption_cider(base, out_root) + dump_classification_accuracy(base, out_root) + dump_classification_f1(base, out_root) + dump_reasoning_accuracy(base, out_root) + dump_reasoning_ras(base, out_root, bertscore_model=args.bertscore_model, lang=(args.bertscore_lang or None), roscoe_model_path=(args.roscoe_model_path or None), scale_100=bool(args.ras_scale_100), auto_scale=True, strict_bertscore=bool(args.strict_bertscore), bertscore_batch_size=int(args.bertscore_batch_size), bertscore_num_layers=nl) + print(f"[dump_per_sample_scores] wrote per-sample details under {out_root}") + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_multimodal_retriever_vqa.py b/ICL/LV/code/core/eval/eval_multimodal_retriever_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..10d0784637e68e39bfac145e4f3028d968f89c84 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_multimodal_retriever_vqa.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +import argparse +import json +import hashlib +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F +from transformers import AutoProcessor, AutoModel +from PIL import Image +import numpy as np +import random + +from core.datasets.m3it_reader import ( + iter_m3it_samples, + load_instructions, + read_jsonl, +) +from core.datasets.m3it_reader import _candidate_split_files as _cand_splits +from core.datasets.m3it_reader import _IMAGE_KEYS, _TEXT_IN_KEYS, _TEXT_OUT_KEYS, _b64_to_image_path +from core.metrics.metrics import token_f1, bertscore_f1 +from core.prompting.openai_segments import openai_to_list_format + + +VQA_SUBTASKS = [ + 'vqa/vqav2', 'vqa/docvqa', 'vqa/ocr-vqa', 'vqa/st-vqa', 'vqa/text-vqa', 'vqa/gqa', 'vqa/okvqa', 'vqa/a-okvqa', +] + + +def _extract_uid(rec: Dict, fallback: str) -> str: + for k in ('id', 'image_id'): + v = rec.get(k) + if isinstance(v, (str, int)): + return str(v) + mv = rec.get('meta', {}) if isinstance(rec.get('meta'), dict) else {} + for k in ('img_id', 'id', 'image_id'): + v = mv.get(k) + if isinstance(v, (str, int)): + return str(v) + return fallback + + +def _prefer_demo_splits(eval_split: str) -> Tuple[str, ...]: + base = ('val', 'validation', 'train', 'dev') + es = (eval_split or '').lower() + exclude = {es} + if es in {'val', 'validation'}: + exclude |= {'val', 'validation'} + return tuple(s for s in base if s not in exclude) + + +def _img_sig_from_rec(rec: Dict) -> str: + try: + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + return hashlib.sha1(v.encode('utf-8')).hexdigest() + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + return hashlib.sha1(mv[k].encode('utf-8')).hexdigest() + except Exception: + pass + return '' + + +_IMG_PSIG_CACHE: Dict[Tuple[str, int, int], str] = {} + + +def _image_path_sig(path: str, size: int = 32) -> str: + try: + st = Path(path).stat() + key = (path, int(st.st_mtime), int(st.st_size)) + if key in _IMG_PSIG_CACHE: + return _IMG_PSIG_CACHE[key] + im = Image.open(path).convert('RGB').resize((size, size), resample=Image.BICUBIC) + arr = np.asarray(im, dtype=np.uint8) + h = hashlib.sha1(arr.tobytes()).hexdigest() + _IMG_PSIG_CACHE[key] = h + return h + except Exception: + return '' + + +def load_pool_items(dataset_root: Path, subdir: str, cache_dir: Path, max_items: Optional[int] = None, + prefer: Optional[Sequence[str]] = None) -> Tuple[List[Dict], List[str]]: + path = None + for split in (tuple(prefer) if prefer is not None else ('val', 'validation', 'train', 'dev')): + for p in _cand_splits(dataset_root, subdir, split): + if p.exists(): + path = p; break + if path is not None: + break + if path is None: + raise FileNotFoundError(f'No demo split file found for {subdir}') + out_recs: List[Dict] = [] + img_paths: List[str] = [] + demo_cache = cache_dir / f'_demo_{subdir.replace("/", "_")}' + demo_cache.mkdir(parents=True, exist_ok=True) + idx = 0 + for rec in read_jsonl(path): + # image b64 + img_b64 = None + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + img_b64 = v; break + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + img_b64 = mv[k]; break + if not img_b64: + continue + # text in/out + text_in = '' + for k in _TEXT_IN_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + text_in = v; break + answers = None + for k in _TEXT_OUT_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + answers = [v]; break + if isinstance(v, list) and all(isinstance(x, str) for x in v): + answers = list(v); break + uid = _extract_uid(rec, f'{idx:08d}') + img_path = _b64_to_image_path(img_b64, demo_cache, uid) + img_sig = _img_sig_from_rec(rec) + out_recs.append({'uid': uid, 'text_in': text_in or '', 'text_out': (answers or [''])[0], 'img_sig': img_sig}) + img_paths.append(img_path) + idx += 1 + if max_items is not None and len(out_recs) >= max_items: + break + return out_recs, img_paths + + +def load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2','idefics','i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl','qwenvl','qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl','qwen3vl','qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3','gemma-3','gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + +def _infer_bt_target_crop(model) -> int: + return 336 + + +def _safe_image_args_for_processor(processor, model) -> Dict: + # For some BridgeTower variants, image size is fixed; allow passing a conservative size + # Here we keep defaults and let encode_pairs handle fallback if needed + return {} + + +def _manual_bt_image_preprocess(processor, pil_imgs: List[Image.Image], target: int) -> Dict[str, torch.Tensor]: + means = getattr(getattr(processor, 'image_processor', processor), 'image_mean', [0.48145466, 0.4578275, 0.40821073]) + stds = getattr(getattr(processor, 'image_processor', processor), 'image_std', [0.26862954, 0.26130258, 0.27577711]) + means = np.array(means, dtype=np.float32).reshape(3, 1, 1) + stds = np.array(stds, dtype=np.float32).reshape(3, 1, 1) + + batch = [] + for im in pil_imgs: + im = im.convert('RGB') + w, h = im.size + if w <= 0 or h <= 0: + w = h = max(1, int(target)) + im = im.resize((w, h), resample=Image.BICUBIC) + se = min(w, h) + scale = float(target) / float(max(1, se)) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + im = im.resize((new_w, new_h), resample=Image.BICUBIC) + left = max(0, (im.width - target) // 2) + top = max(0, (im.height - target) // 2) + right = left + target + bottom = top + target + im = im.crop((left, top, right, bottom)) + arr = np.asarray(im).astype(np.float32) / 255.0 + arr = arr.transpose(2, 0, 1) + arr = (arr - means) / stds + batch.append(arr) + x = np.stack(batch, axis=0) + return {'pixel_values': torch.from_numpy(x)} + + +def encode_pairs(processor, model, image_paths: List[str], texts: List[str], device: torch.device, batch_size: int = 8) -> torch.Tensor: + embs: List[torch.Tensor] = [] + model.eval() + with torch.no_grad(): + for i in range(0, len(texts), batch_size): + chunk_t = texts[i:i+batch_size] + chunk_i = image_paths[i:i+batch_size] + pil_imgs = [] + for p in chunk_i: + img = Image.open(p) + if img.mode != 'RGB': + img = img.convert('RGB') + pil_imgs.append(img) + # Build inputs by calling tokenizer and image_processor separately + img_args = _safe_image_args_for_processor(processor, model) + try: + tok = processor.tokenizer( + text=chunk_t, + padding=True, + truncation=True, + max_length=256, + return_tensors='pt', + ) + img = processor.image_processor(images=pil_imgs, return_tensors='pt', **img_args) + inputs = {**tok, **img} + except ValueError as e: + if 'height and width must be > 0' in str(e): + target = _infer_bt_target_crop(model) + img = _manual_bt_image_preprocess(processor, pil_imgs, target) + inputs = {**tok, **img} + else: + raise + inputs = {k: v.to(device) for k, v in inputs.items()} + try: + out = model(**inputs) + except RuntimeError as re: + msg = str(re) + if 'must match the size of tensor' in msg or 'The size of tensor a' in msg: + target = _infer_bt_target_crop(model) + img = _manual_bt_image_preprocess(processor, pil_imgs, target) + inputs.update({k: v.to(device) for k, v in img.items()}) + out = model(**inputs) + else: + raise + if hasattr(out, 'pooler_output') and out.pooler_output is not None: + pooled = out.pooler_output + elif isinstance(out, dict) and 'pooler_output' in out: + pooled = out['pooler_output'] + else: + last = out.last_hidden_state if hasattr(out, 'last_hidden_state') else (list(out.values())[0] if isinstance(out, dict) else out[0]) + pooled = last.mean(dim=1) + pooled = F.normalize(pooled, p=2, dim=-1) + embs.append(pooled.detach().cpu()) + return torch.cat(embs, dim=0) if embs else torch.empty(0, 1024) + + +def load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2','idefics','i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl','qwenvl','qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl','qwen3vl','qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3','gemma-3','gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True, help='idefics2 | qwen-vl | qwen3-vl | gemma3') + ap.add_argument('--model-path', required=True, help='generator model path') + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True, help='Path to BridgeTower-like model') + ap.add_argument('--output-dir', default='runs/m3it_multimodal_retriever_vqa') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.2) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=32) + ap.add_argument('--split', type=str, default='test') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--use-paper-instruction', action='store_true') + ap.add_argument('--no-instruction', action='store_true') + ap.add_argument('--auto-detect', action='store_true') + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + args = ap.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + cache_dir = out_dir / '_image_cache' + feat_cache = out_dir / '_feat_cache'; feat_cache.mkdir(parents=True, exist_ok=True) + + # Init retriever encoder + processor = AutoProcessor.from_pretrained(args.retriever_model_path, trust_remote_code=True) + mm_model = AutoModel.from_pretrained(args.retriever_model_path, trust_remote_code=True).to(device).eval() + + # Init generator adapter + adapter = load_adapter(args.adapter, args.model_path) + + tasks = VQA_SUBTASKS + if args.auto_detect: + root = Path(args.dataset_root) + present = [t for t in VQA_SUBTASKS if any(p.exists() for p in _cand_splits(root, t, args.split))] + if present: + tasks = present + + def _quota(total: int, n: int) -> List[int]: + base = total // n; rem = total % n + return [base + (1 if i < rem else 0) for i in range(n)] + per_task = _quota(args.total_samples, len(tasks)) + + all_preds: List[str] = [] + all_refs: List[List[str]] = [] + details: List[Dict] = [] + + PAPER_VQA_INTRO = 'Examine the image and answer the question.' + rng = random.Random(args.seed) + + for subdir, quota in zip(tasks, per_task): + demo_prefer = _prefer_demo_splits(args.split) + recs, imgs = load_pool_items(Path(args.dataset_root), subdir, cache_dir, prefer=demo_prefer) + texts = [r['text_in'] for r in recs] + feat_path = feat_cache / f'multimodal_{subdir.replace("/","_")}_pool.pt' + if feat_path.exists(): + pool_emb = torch.load(feat_path, map_location='cpu') + else: + pool_emb = encode_pairs(processor, mm_model, imgs, texts, device) + torch.save(pool_emb, feat_path) + + eval_pool = [s for s in iter_m3it_samples(args.dataset_root, subdir, split=args.split, cache_dir=str(cache_dir))] + if not eval_pool: + print(f'Skipping {subdir}: empty eval pool') + continue + select_n = min(quota, len(eval_pool)) + sel_indices = rng.sample(range(len(eval_pool)), k=select_n) + taken = 0 + for _idx in sel_indices: + smp = eval_pool[_idx] + q_emb = encode_pairs(processor, mm_model, [smp.image_path], [smp.text or ''], device) + sim = (q_emb @ pool_emb.T).squeeze(0) + q_uid = _extract_uid(smp.raw, '') if isinstance(smp.raw, dict) else '' + q_sig = _img_sig_from_rec(smp.raw) if isinstance(smp.raw, dict) else '' + q_psig = _image_path_sig(smp.image_path) + q_txt = (smp.text or '').strip().lower() + mask = torch.ones(sim.shape[0], dtype=torch.bool, device=sim.device) + for i, r in enumerate(recs): + if q_uid and r.get('uid') == q_uid: + mask[i] = False; continue + if q_sig and r.get('img_sig') and r['img_sig'] == q_sig: + mask[i] = False; continue + if q_txt and (r.get('text_in') or '').strip().lower() == q_txt: + mask[i] = False; continue + mask &= (sim < 0.999).to(dtype=mask.dtype, device=mask.device).bool() + sim[~mask] = -1e4 + pre_k = min(max(args.k_shots * 50, 500), sim.numel()) + cand = [i for i in torch.topk(sim, k=pre_k).indices.tolist() if mask[i].item()] + if q_psig: + cand = [i for i in cand if _image_path_sig(imgs[i]) != q_psig] + idxs = cand[:args.k_shots] + if len(idxs) < min(args.k_shots, sim.numel()): + rest = [i for i in range(sim.numel()) if mask[i].item() and i not in idxs] + if q_psig: + rest = [i for i in rest if _image_path_sig(imgs[i]) != q_psig] + rng.shuffle(rest) + idxs.extend(rest[:max(0, args.k_shots - len(idxs))]) + demos = [recs[i] for i in idxs] + demo_imgs = [imgs[i] for i in idxs] + + insts = [] if args.no_instruction else load_instructions(Path(args.dataset_root), subdir) + ds_inst = '' + if isinstance(insts, list) and insts: + ds_inst = '\n'.join([s for s in insts if isinstance(s, str) and s.strip()]) + base_inst = 'Examine the image and answer the question with a short answer.' + inst = (PAPER_VQA_INTRO + (('\n' + ds_inst) if ds_inst else '')) if args.use_paper_instruction else (ds_inst or base_inst) + + oa_items: List[Dict] = [] + if inst: + oa_items.append({'type': 'text', 'text': inst}) + for d, img in zip(demos, demo_imgs): + oa_items.append({'type': 'image_url', 'image_url': img}) + oa_items.append({'type': 'text', 'text': f"[REQUEST]\n{(d.get('text_in') or '').strip()}\n[RESPONSE]\n{(d.get('text_out') or '').strip()}"}) + oa_items.append({'type': 'image_url', 'image_url': smp.image_path}) + oa_items.append({'type': 'text', 'text': f"[REQUEST]\n{(smp.text or '').strip()}\n[RESPONSE]"}) + segs = openai_to_list_format(oa_items, cache_dir=cache_dir / '_oa_cache') + + response = adapter.generate_from_segments( + segs, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + ) + + all_preds.append(response) + all_refs.append(smp.answers or []) + details.append({'task': subdir, 'image_path': smp.image_path, 'text': smp.text, 'answers': smp.answers, 'demo_uids': [d['uid'] for d in demos]}) + taken += 1 + print(f'{subdir}: {taken} eval samples | demo-pool={len(recs)} | k={args.k_shots}') + + tf1 = token_f1(all_preds, all_refs) + try: + bsf1 = bertscore_f1( + all_preds, + all_refs, + model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + lang=(args.bertscore_lang or None), + ) + except Exception: + bsf1 = None + + out = { + 'setting': 'few-shot-multimodal-retriever', + 'k_shots': args.k_shots, + 'adapter': args.adapter, + 'model_path': args.model_path, + 'total': len(all_preds), + 'metrics': {'token_f1': tf1, 'bertscore_f1': bsf1}, + 'predictions': [ + {'pred': p, 'answers': r, 'meta': m} + for p, r, m in zip(all_preds, all_refs, details) + ] + } + (out_dir / f'vqa_multimodal_retriever_{args.k_shots}shot.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'Multimodal Retriever (k={args.k_shots}) Token-F1={tf1:.2f} BERTScore-F1=' + (f'{bsf1:.2f}' if bsf1 is not None else 'NA')) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_order_caption_cider.py b/ICL/LV/code/core/eval/eval_order_caption_cider.py new file mode 100644 index 0000000000000000000000000000000000000000..81169b82ef60581a89f5f1590617b4be5cf88ee5 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_caption_cider.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +"""Captioning CIDEr per modal order. +If pycocoevalcap is unavailable, returns NA. +""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions + + +def compute_cider(preds, refs): + try: + from pycocoevalcap.cider.cider import Cider # type: ignore + except Exception: + return None + refs_str = [r[0] if r else '' for r in refs] + gts = {i: [refs_str[i]] for i in range(len(refs_str))} + res = {i: [preds[i]] for i in range(len(preds))} + scorer = Cider() + score, _ = scorer.compute_score(gts, res) + return 100.0 * float(score) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'captioning_cider' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['captioning'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + summary = {} + for order in orders: + p, r = preds[order]['captioning'] + keep_idx = [] + for i, ri in enumerate(r): + first = (ri[0] if (isinstance(ri, (list, tuple)) and ri) else ri) + first = first if isinstance(first, str) else '' + if first.strip(): + keep_idx.append(i) + if keep_idx: + p = [p[i] for i in keep_idx] + r = [[(r[i][0] if r[i] else '')] for i in keep_idx] + else: + p, r = [], [] + + score = compute_cider(p, r) if p else None + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': 'captioning_cider', 'score': (None if score is None else float(score))} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'[CIDEr] order={order} score=' + (f'{score:.2f}' if score is not None else 'NA')) + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[CIDEr] all orders done:', summary) + + +if __name__ == '__main__': + main() + diff --git a/ICL/LV/code/core/eval/eval_order_classification_accuracy.py b/ICL/LV/code/core/eval/eval_order_classification_accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..6095eb72efe3dec52b70aa15e7707978698acb8a --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_classification_accuracy.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +"""Classification Accuracy per modal order. +Relies on cached meta.gold_choice when available; otherwise falls back to +parsing the first reference with choice-letter heuristics. +""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import parse_choice_letter + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'classification_accuracy' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['classification'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + # Helpers for robust classification scoring (options/yes-no/SNLI) + import re as _re + def _norm(s: str) -> str: + s = (s or '').lower().strip() + s = _re.sub(r"[\.,!?;:\-\(\)\[\]\{\}\'\"/\\,。?!:;()]", " ", s) + s = _re.sub(r"\s+", " ", s).strip() + return s + def _parse_opts(s: str) -> dict: + s = s or '' + opts = {} + for m in _re.finditer(r"[\((]([A-Za-z])[\))]\s*([^\n]+)", s): + opts[m.group(1).lower()] = _norm(m.group(2)) + for m in _re.finditer(r"(?m)^\s*([A-Za-z])[\).:]\s*([^\n]+)", s): + opts.setdefault(m.group(1).lower(), _norm(m.group(2))) + return opts + def _map_yesno(pn: str) -> str | None: + if any(k in pn for k in ('yes','yeah','yep','affirm','correct','right','true','是')): + return 'yes' + if any(k in pn for k in ('no',"don't",'not','incorrect','wrong','false','否')): + return 'no' + return None + def _map_snli(pn: str) -> str | None: + if any(k in pn for k in ('entail','supported','support','yes','true')): + return 'entailment' + if any(k in pn for k in ('contradict','not ',"can't",'cannot','no','false')): + return 'contradiction' + if any(k in pn for k in ('neutral','unknown','uncertain','maybe','cannot be determined','undetermined')): + return 'neutral' + return None + + summary = {} + for order in orders: + p, r = preds[order]['classification'] + cache_file = base_out / '_cache' / f'classification__{order}.jsonl' + lines = cache_file.read_text(encoding='utf-8').splitlines() if cache_file.exists() else [] + + total = 0 + correct = 0 + if lines: + for i, line in enumerate(lines): + if not line.strip(): + continue + obj = json.loads(line) + meta = obj.get('meta') or {} + task = (meta.get('task') or '').lower() + pred_raw = p[i] if i < len(p) else (obj.get('pred') or '') + ans = obj.get('answers') or [] + ref_raw = (ans[0] if ans else '') + pr = _norm(pred_raw) + gr = _norm(ref_raw) + opts = _parse_opts(meta.get('inputs') or '') + gold_letter = None + gl = meta.get('gold_choice') + if isinstance(gl, str) and gl.strip(): + gold_letter = gl.strip().lower() + if not gold_letter and gr and opts: + for k, v in opts.items(): + if v and (v == gr or v in gr or gr in v): + gold_letter = k; break + pred_letter = parse_choice_letter(pred_raw) or '' + if gold_letter: + total += 1 + if pred_letter and pred_letter == gold_letter: + correct += 1 + elif pred_letter and pred_letter in opts and gr: + if opts[pred_letter] == gr: + correct += 1 + continue + if 'snli' in task: + pm = _map_snli(pr) + if pm and gr: + total += 1 + if pm == gr: + correct += 1 + continue + yn = _map_yesno(pr) + if yn and gr in ('yes','no'): + total += 1 + if yn == gr: + correct += 1 + continue + if gr: + total += 1 + if pr == gr: + correct += 1 + + if total == 0: + gold_letters = [] + for ri in r: + first = (ri[0] if ri else '') + gl = parse_choice_letter(first) if isinstance(first, str) else None + gold_letters.append(gl) + for pred, gl in zip(p, gold_letters): + pl = parse_choice_letter(pred) + if gl and pl: + total += 1 + if gl.lower().strip() == pl.lower().strip(): + correct += 1 + + score = 100.0 * correct / max(1, total) + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': 'classification_accuracy', 'score': float(score)} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'[Classification Acc] order={order} score={score:.2f}') + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[Classification Acc] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_order_classification_f1.py b/ICL/LV/code/core/eval/eval_order_classification_f1.py new file mode 100644 index 0000000000000000000000000000000000000000..051e49d499bdaa2fa5560b15894a3104dcded2e3 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_classification_f1.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +"""Classification Macro-F1 per modal order. +Uses choice-letter mapping when available. +""" + +import argparse +import json +from collections import Counter, defaultdict +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import parse_choice_letter + + +def macro_f1(pred_letters, gold_letters) -> float: + labels = sorted({g for g in gold_letters if g}) + if not labels: + return 0.0 + per_label = [] + for lbl in labels: + tp = sum(1 for p, g in zip(pred_letters, gold_letters) if p == lbl and g == lbl) + fp = sum(1 for p, g in zip(pred_letters, gold_letters) if p == lbl and g != lbl) + fn = sum(1 for p, g in zip(pred_letters, gold_letters) if p != lbl and g == lbl) + if tp == 0 and fp == 0 and fn == 0: + f1 = 0.0 + else: + prec = tp / max(1, tp + fp) + rec = tp / max(1, tp + fn) + f1 = 0.0 if (prec + rec) == 0 else 2 * prec * rec / (prec + rec) + per_label.append(f1) + return 100.0 * (sum(per_label) / len(per_label)) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'classification_f1' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['classification'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + import re as _re + def _norm(s: str) -> str: + s = (s or '').lower().strip() + s = _re.sub(r"[\.,!?;:\-\(\)\[\]\{\}\'\"/\\,。?!:;()]", " ", s) + s = _re.sub(r"\s+", " ", s).strip() + return s + def _parse_opts(s: str) -> dict: + s = s or '' + opts = {} + for m in _re.finditer(r"[\((]([A-Za-z])[\))]\s*([^\n]+)", s): + opts[m.group(1).lower()] = _norm(m.group(2)) + for m in _re.finditer(r"(?m)^\s*([A-Za-z])[\).:]\s*([^\n]+)", s): + opts.setdefault(m.group(1).lower(), _norm(m.group(2))) + return opts + + summary = {} + for order in orders: + p, r = preds[order]['classification'] + cache_file = base_out / '_cache' / f'classification__{order}.jsonl' + gold_letters = [] + # Try derive gold letters from cache (gold_choice or matching options against ref text) + if cache_file.exists(): + lines = cache_file.read_text(encoding='utf-8').splitlines() + for i, line in enumerate(lines): + if not line.strip(): + continue + obj = json.loads(line) + meta = obj.get('meta') or {} + gl = meta.get('gold_choice') + if isinstance(gl, str) and gl.strip(): + gold_letters.append(gl.strip().lower()) + else: + opts = _parse_opts(meta.get('inputs') or '') + ans = obj.get('answers') or [] + ref = _norm((ans[0] if ans else '')) + if ref and opts: + found = None + for k, v in opts.items(): + if v and (v == ref or v in ref or ref in v): + found = k; break + gold_letters.append(found) + else: + gold_letters.append(None) + # Fallback: derive from refs when cache didn't include valid gold letters + if not any(isinstance(gl, str) and gl for gl in gold_letters): + gold_letters = [] + for ri in r: + first = (ri[0] if ri else '') + gl = parse_choice_letter(first) if isinstance(first, str) else None + gold_letters.append(gl) + pred_letters = [parse_choice_letter(x) for x in p] + # Only keep indices with both gold and pred letters + keep = [i for i, (pl, gl) in enumerate(zip(pred_letters, gold_letters)) if pl and gl] + if keep: + pred_letters = [pred_letters[i] for i in keep] + gold_letters = [gold_letters[i] for i in keep] + score = macro_f1(pred_letters, gold_letters) if keep else 0.0 + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': 'classification_macro_f1', 'score': float(score)} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'[Classification Macro-F1] order={order} score={score:.2f}') + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[Classification Macro-F1] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_order_reasoning_accuracy.py b/ICL/LV/code/core/eval/eval_order_reasoning_accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..4663817e27d5df2463a40dd2db269da6d68b5d64 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_reasoning_accuracy.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Reasoning Accuracy per modal order. +Heuristic: when a gold choice letter is available, compute accuracy on letters; +otherwise fall back to exact match against provided textual answers. +""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import exact_match, parse_choice_letter + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'reasoning_accuracy' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['reasoning'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + # More robust scoring aligned with QWEN3VL-code + import re as _re + def _norm(s: str) -> str: + s = (s or '').lower().strip() + s = _re.sub(r"[\.,!?;:\-\(\)\[\]\{\}\'\"/\\,。?!:;()]", " ", s) + s = _re.sub(r"\s+", " ", s).strip() + return s + def _map_yesno(pn: str) -> str | None: + if any(k in pn for k in ('yes','yeah','yep','true','是')): + return 'yes' + if any(k in pn for k in ('no',"don't",'not','false','否')): + return 'no' + return None + def _extract_first_number(s: str) -> str | None: + m = _re.search(r"[-+]?\d+", s or '') + return m.group(0) if m else None + # Note: numeric words handling removed for strict parity with original logic + def _parse_opts(s: str) -> dict: + s = s or '' + opts = {} + for m in _re.finditer(r"[\((]([A-Za-z])[\))]\s*([^\n]+)", s): + opts[m.group(1).lower()] = _norm(m.group(2)) + for m in _re.finditer(r"(?m)^\s*([A-Za-z])[\).:]\s*([^\n]+)", s): + opts.setdefault(m.group(1).lower(), _norm(m.group(2))) + return opts + # Note: task-aware typed extraction and canonicalization removed, restoring original strict logic + + summary = {} + for order in orders: + p, r = preds[order]['reasoning'] + cache_file = base_out / '_cache' / f'reasoning__{order}.jsonl' + total = 0 + correct = 0 + if cache_file.exists(): + for i, line in enumerate(cache_file.read_text(encoding='utf-8').splitlines()): + if not line.strip(): + continue + obj = json.loads(line) + pred_raw = p[i] if i < len(p) else (obj.get('pred') or '') + pred_n = _norm(pred_raw) + meta = obj.get('meta') or {} + ans = obj.get('answers') or [] + ref_raw = ans[0] if ans else '' + ref_n = _norm(ref_raw) + opts = _parse_opts(meta.get('inputs') or '') + + # gold letter from cache + gold_letter = None + gl = meta.get('gold_choice') + if isinstance(gl, str) and gl: + gold_letter = gl.strip().lower() + if not gold_letter and ref_n and opts: + for k, v in opts.items(): + if v == ref_n or (v and ref_n and (v in ref_n or ref_n in v)): + gold_letter = k; break + if not gold_letter and len(ref_n) == 1 and ref_n in 'abcdefghijklmnopqrstuvwxyz': + gold_letter = ref_n + + pred_letter = parse_choice_letter(pred_raw) or '' + pred_label = '' + if not pred_letter and opts: + for k, v in opts.items(): + if v and pred_n and (v == pred_n or v in pred_n or pred_n in v): + pred_label = v; pred_letter = k; break + + if gold_letter: + total += 1 + if pred_letter and pred_letter == gold_letter: + correct += 1 + elif pred_label and opts.get(gold_letter,'') and pred_label == opts.get(gold_letter,''): + correct += 1 + elif ref_n: + total += 1 + if pred_n == ref_n: + correct += 1 + else: + if _re.fullmatch(r"[-+]?\d+", ref_n): + num = _extract_first_number(pred_n) + if num is not None and num == ref_n: + correct += 1 + else: + yn = _map_yesno(pred_n) + if yn and yn == ref_n: + correct += 1 + # Fallback when cache yielded nothing + if total == 0: + total2 = 0 + correct2 = 0 + for pred, rr in zip(p, r): + ref = (rr[0] if rr else '') + if not (ref and ref.strip()): + continue + total2 += 1 + pl = parse_choice_letter(pred) or '' + rl = parse_choice_letter(ref) or '' + if pl and rl: + if pl == rl: + correct2 += 1 + else: + if _norm(pred) == _norm(ref): + correct2 += 1 + score = 100.0 * correct2 / max(1, total2) + else: + score = 100.0 * correct / max(1, total) + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': 'reasoning_accuracy', 'score': float(score)} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'[Reasoning Acc] order={order} score={score:.2f}') + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[Reasoning Acc] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_order_reasoning_ras.py b/ICL/LV/code/core/eval/eval_order_reasoning_ras.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a78075b7196a5f401cd6d3fa208e33e3bd5a70 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_reasoning_ras.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +"""Reasoning RAS per modal order. +Backends: +- roscoe (optional, via a local module/function path) +- bertscore (fallback) +- token-f1 (last resort) +""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import token_f1, bertscore_f1 + + +def ras_roscoe(preds, refs, model_path=None, module_name=None, func_name=None, module_path=None): + try: + import importlib, sys, os + # Optional: add a filesystem path to sys.path so we can import roscoe without installation + if module_path: + raw = str(module_path) + try: + # Build candidate paths: raw, raw/src + cand = [] + p = raw + if os.path.isfile(p): + p = os.path.dirname(p) + if p and os.path.isdir(p): + cand.append(p) + sub = os.path.join(p, 'src') + if os.path.isdir(sub): + cand.append(sub) + for c in cand: + if c not in sys.path: + sys.path.insert(0, c) + # Also export ROSCOE_PY_PATH so shim's file-fallback can discover files under this root + os.environ['ROSCOE_PY_PATH'] = p + except Exception as _e: + print('[RAS-roscoe] warn: failed to process roscoe-path:', _e) + mn = module_name or 'roscoe' + fn = func_name or 'evaluate' + try: + mod = importlib.import_module(mn) + except Exception as _ie: + print(f"[RAS-roscoe] import failed for module='{mn}'. sys.path head=", sys.path[:3], 'err=', _ie) + return None + try: + fnc = getattr(mod, fn) + except Exception as _ge: + print(f"[RAS-roscoe] function '{fn}' not found in module '{mn}':", _ge) + return None + # Try call with model_path kw; if signature doesn't accept it, retry without + try: + out = fnc(preds, refs, model_path=model_path) + except TypeError: + try: + out = fnc(preds, refs) + except Exception as _ce: + print('[RAS-roscoe] evaluation call failed:', _ce) + return None + if isinstance(out, dict): + for k in ('ras', 'score', 'mean'): + if k in out: + return float(out[k]) + if isinstance(out, (int, float)): + return float(out) + return None + except Exception: + return None + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + # Backend options + ap.add_argument('--ras-backend', type=str, default='auto', choices=['auto','bertscore','roscoe']) + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-num-layers', type=int, default=-1) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + ap.add_argument('--roscoe-model-path', type=str, default='') + # Default to our in-repo shim so users don't have to remember to pass it every time + ap.add_argument('--roscoe-module', type=str, default='core.metrics.roscoe_shim') + ap.add_argument('--roscoe-func', type=str, default='evaluate') + ap.add_argument('--roscoe-path', type=str, default='', help='Filesystem path to roscoe Python package root; added to sys.path before import (or set env ROSCOE_PY_PATH)') + ap.add_argument('--ras-strict', action='store_true', help='If set, do NOT fallback when roscoe backend fails; raise instead') + ap.add_argument('--no-bertscore-fallback', action='store_true', help='Do NOT fallback to token-F1 when BERTScore fails/unavailable; raise instead') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'reasoning_ras' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['reasoning'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + summary = {} + for order in orders: + p, r = preds[order]['reasoning'] + # Prefer ref_text cached by core; fallback to first textual answer + cache_file = base_out / '_cache' / f'reasoning__{order}.jsonl' + raw_refs = [] + if cache_file.exists(): + for line in cache_file.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + obj = json.loads(line) + t = obj.get('ref_text') or '' + if not t: + ans = obj.get('answers') or [] + t = ans[0] if ans else '' + raw_refs.append(t) + else: + raw_refs = [x[0] if x else '' for x in r] + keep = [i for i, t in enumerate(raw_refs) if isinstance(t, str) and t.strip()] + p_eval = [p[i] for i in keep] if keep else [] + ref_texts = [raw_refs[i] for i in keep] if keep else [] + + used_metric = 'reasoning_ras' + score = None + if args.ras_backend in ('roscoe', 'auto') and p_eval: + import os + mp = args.roscoe_model_path or os.environ.get('ROSCOE_MODEL_PATH') + mn = args.roscoe_module or None + fnm = args.roscoe_func or None + rpy = args.roscoe_path or os.environ.get('ROSCOE_PY_PATH') + score = ras_roscoe(p_eval, ref_texts, model_path=mp, module_name=mn, func_name=fnm, module_path=rpy) + if score is None and args.ras_backend == 'roscoe': + if args.ras_strict: + raise RuntimeError('ROSCOE backend requested but unavailable/failed; strict mode forbids fallback') + print('[Reasoning RAS] roscoe backend unavailable; falling back to bert-score backend (disable with --ras-strict)') + if (args.ras_backend in ('bertscore', 'auto')) and (score is None) and p_eval: + nl = None if args.bertscore_num_layers < 0 else int(args.bertscore_num_layers) + try: + import os + if args.no_bertscore_fallback: + os.environ['BERTSCORE_STRICT'] = '1' + score = bertscore_f1( + p_eval, + [[x] for x in ref_texts], + model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + num_layers=nl, + lang=(args.bertscore_lang or None), + ) + except Exception as _be: + print('[Reasoning RAS] bert-score failed:', _be) + score = None + if score is None: + if args.no_bertscore_fallback: + raise RuntimeError('BERTScore returned None and --no-bertscore-fallback forbids fallback') + # Always ensure we have a numeric fallback + score = token_f1(p_eval, [[x] for x in ref_texts]) + used_metric = 'reasoning_token_f1 (fallback)' + + # Scale to 0..100 if score looks like 0..1 + s_out = None + if p_eval and isinstance(score, (int, float)): + try: + s = float(score) + if s <= 1.05: + s *= 100.0 + s_out = float(s) + except Exception: + s_out = None + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': used_metric, 'score': (None if (not p_eval) else s_out)} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + if not p_eval: + print(f'[Reasoning RAS] order={order} score=NA (no valid refs)') + summary[order] = None + else: + # Defensive: format score only if numeric + if used_metric.startswith('reasoning_token_f1'): + print(f'[Reasoning RAS] bert-score unavailable; used token-F1 fallback. order={order} score={float(score):.2f}') + else: + s_val = 'NA' if (s_out is None) else f'{float(s_out):.2f}' + print(f'[Reasoning RAS] order={order} score={s_val}') + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[Reasoning RAS] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_order_vqa_bertscore.py b/ICL/LV/code/core/eval/eval_order_vqa_bertscore.py new file mode 100644 index 0000000000000000000000000000000000000000..36c16e5ce64a12dec14504cce1d1180f5da488a2 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_vqa_bertscore.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""VQA BERTScore-F1 per modal order.""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import bertscore_f1, token_f1 + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-num-layers', type=int, default=-1) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + ap.add_argument('--strict-bertscore', action='store_true', help='Do not fallback to token-F1; raise when BERTScore fails/unavailable') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'vqa_bertscore' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['vqa'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + summary = {} + for order in orders: + p, r = preds[order]['vqa'] + keep_idx = [] + for i, ri in enumerate(r): + first = (ri[0] if (isinstance(ri, (list, tuple)) and ri) else ri) + first = first if isinstance(first, str) else '' + if first.strip(): + keep_idx.append(i) + if keep_idx: + p = [p[i] for i in keep_idx] + r = [[(r[i][0] if r[i] else '')] for i in keep_idx] + else: + p, r = [], [] + + # Enable strict behavior via env for the metrics helper + import os + if args.strict_bertscore: + os.environ['BERTSCORE_STRICT'] = '1' + nl = None if args.bertscore_num_layers < 0 else int(args.bertscore_num_layers) + if p: + score = bertscore_f1( + p, + r, + model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + num_layers=nl, + lang=(args.bertscore_lang or None), + ) + else: + score = None + used_metric = 'vqa_bertscore_f1' + if score is None and p: + if args.strict_bertscore: + raise RuntimeError('BERTScore returned None and --strict-bertscore forbids fallback') + from core.metrics.metrics import token_f1 as _tf1 + score = _tf1(p, r) + used_metric = 'vqa_token_f1 (fallback)' + + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': used_metric, 'score': (None if score is None else float(score))} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + if score is None: + print(f'[VQA BERTScore-F1] order={order} score=NA (no valid refs)') + else: + if used_metric.startswith('vqa_token_f1'): + print(f'[VQA BERTScore-F1] bert-score unavailable; used token-F1 fallback. order={order} score={score:.2f}') + else: + print(f'[VQA BERTScore-F1] order={order} score={score:.2f}') + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[VQA BERTScore-F1] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_order_vqa_tokenf1.py b/ICL/LV/code/core/eval/eval_order_vqa_tokenf1.py new file mode 100644 index 0000000000000000000000000000000000000000..fda8082919b86dd2504e47d8b7e0d1703e4a7df9 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_order_vqa_tokenf1.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +"""VQA Token-F1 per modal order (uses cached predictions from order_eval_core).""" + +import argparse +import json +from pathlib import Path + +from core.eval.order_eval_core import run_predictions +from core.metrics.metrics import token_f1 + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True) + ap.add_argument('--output-dir', default='runs/order_metrics') + ap.add_argument('--orders', type=str, default='image-text,text-image,text-image-text') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.6) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=128) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--reuse-cache', action='store_true') + args = ap.parse_args() + + orders = [o.strip().lower() for o in args.orders.split(',') if o.strip()] + base_out = Path(args.output_dir) / 'vqa_tokenf1' + + preds = run_predictions( + adapter=args.adapter, + model_path=args.model_path, + dataset_root=args.dataset_root, + retriever_model_path=args.retriever_model_path, + output_dir=str(base_out), + orders=orders, + categories=['vqa'], + total_samples=args.total_samples, + k_shots=args.k_shots, + split=args.split, + seed=args.seed, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + auto_detect=True, + reuse_cache=args.reuse_cache, + ) + + summary = {} + for order in orders: + p, r = preds[order]['vqa'] + # filter empty refs (use only first ref) + keep_idx = [] + for i, ri in enumerate(r): + first = (ri[0] if (isinstance(ri, (list, tuple)) and ri) else ri) + first = first if isinstance(first, str) else '' + if first.strip(): + keep_idx.append(i) + if keep_idx: + p = [p[i] for i in keep_idx] + r = [[(r[i][0] if r[i] else '')] for i in keep_idx] + else: + p, r = [], [] + + score = None if not p else token_f1(p, r) + out_dir = base_out / order; out_dir.mkdir(parents=True, exist_ok=True) + out = {'order': order, 'metric': 'vqa_token_f1', 'score': (None if score is None else float(score))} + (out_dir / 'result.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + if score is None: + print(f'[VQA Token-F1] order={order} score=NA (no valid refs)') + else: + print(f'[VQA Token-F1] order={order} score={score:.2f}') + summary[order] = out['score'] + + (base_out / 'summary.json').write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding='utf-8') + print('[VQA Token-F1] all orders done:', summary) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_textual_retriever_vqa.py b/ICL/LV/code/core/eval/eval_textual_retriever_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..f20b9259a989a9313bec0d1e83fc132331ae80e8 --- /dev/null +++ b/ICL/LV/code/core/eval/eval_textual_retriever_vqa.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +import argparse +import json +import hashlib +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModel +from PIL import Image +import numpy as np +import random + +from core.datasets.m3it_reader import ( + iter_m3it_samples, + load_instructions, + read_jsonl, +) +from core.datasets.m3it_reader import _candidate_split_files as _cand_splits +from core.datasets.m3it_reader import _IMAGE_KEYS, _TEXT_IN_KEYS, _TEXT_OUT_KEYS, _b64_to_image_path +from core.metrics.metrics import token_f1, bertscore_f1 +from core.prompting.openai_segments import openai_to_list_format + + +VQA_SUBTASKS = [ + 'vqa/vqav2', 'vqa/docvqa', 'vqa/ocr-vqa', 'vqa/st-vqa', 'vqa/text-vqa', 'vqa/gqa', 'vqa/okvqa', 'vqa/a-okvqa', +] + + +def _extract_uid(rec: Dict, fallback: str) -> str: + for k in ('id', 'image_id'): + v = rec.get(k) + if isinstance(v, (str, int)): + return str(v) + mv = rec.get('meta', {}) if isinstance(rec.get('meta'), dict) else {} + for k in ('img_id', 'id', 'image_id'): + v = mv.get(k) + if isinstance(v, (str, int)): + return str(v) + return fallback + + +def _prefer_demo_splits(eval_split: str) -> Tuple[str, ...]: + base = ('val', 'validation', 'train', 'dev') + es = (eval_split or '').lower() + exclude = {es} + if es in {'val', 'validation'}: + exclude |= {'val', 'validation'} + return tuple(s for s in base if s not in exclude) + + +def _img_sig_from_rec(rec: Dict) -> str: + try: + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + return hashlib.sha1(v.encode('utf-8')).hexdigest() + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + return hashlib.sha1(mv[k].encode('utf-8')).hexdigest() + except Exception: + pass + return '' + + +_IMG_PSIG_CACHE: Dict[Tuple[str, int, int], str] = {} + + +def _image_path_sig(path: str, size: int = 32) -> str: + try: + st = Path(path).stat() + key = (path, int(st.st_mtime), int(st.st_size)) + if key in _IMG_PSIG_CACHE: + return _IMG_PSIG_CACHE[key] + im = Image.open(path).convert('RGB').resize((size, size), resample=Image.BICUBIC) + arr = np.asarray(im, dtype=np.uint8) + h = hashlib.sha1(arr.tobytes()).hexdigest() + _IMG_PSIG_CACHE[key] = h + return h + except Exception: + return '' + + +def load_pool_items(dataset_root: Path, subdir: str, cache_dir: Path, max_items: Optional[int] = None, + prefer: Optional[Sequence[str]] = None) -> Tuple[List[Dict], List[str]]: + path = None + for split in (tuple(prefer) if prefer is not None else ('val', 'validation', 'train', 'dev')): + for p in _cand_splits(dataset_root, subdir, split): + if p.exists(): + path = p; break + if path is not None: + break + if path is None: + raise FileNotFoundError(f'No demo split file found for {subdir}') + out_recs: List[Dict] = [] + img_paths: List[str] = [] + demo_cache = cache_dir / f'_demo_{subdir.replace("/", "_")}' + demo_cache.mkdir(parents=True, exist_ok=True) + idx = 0 + for rec in read_jsonl(path): + # image b64 + img_b64 = None + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + img_b64 = v; break + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + img_b64 = mv[k]; break + if not img_b64: + continue + # text in/out + text_in = '' + for k in _TEXT_IN_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + text_in = v; break + answers = None + for k in _TEXT_OUT_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + answers = [v]; break + if isinstance(v, list) and all(isinstance(x, str) for x in v): + answers = list(v); break + uid = _extract_uid(rec, f'{idx:08d}') + img_path = _b64_to_image_path(img_b64, demo_cache, uid) + img_sig = _img_sig_from_rec(rec) + out_recs.append({'uid': uid, 'text_in': text_in or '', 'text_out': (answers or [''])[0], 'img_sig': img_sig}) + img_paths.append(img_path) + idx += 1 + if max_items is not None and len(out_recs) >= max_items: + break + return out_recs, img_paths + + +def load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2','idefics','i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl','qwenvl','qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl','qwen3vl','qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3','gemma-3','gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + +def encode_texts(tokenizer, model, texts: List[str], device: torch.device, batch_size: int = 64) -> torch.Tensor: + embs: List[torch.Tensor] = [] + model.eval() + with torch.no_grad(): + for i in range(0, len(texts), batch_size): + chunk = texts[i:i+batch_size] + toks = tokenizer(chunk, padding=True, truncation=True, max_length=256, return_tensors='pt').to(device) + out = model(**toks) + if hasattr(out, 'last_hidden_state'): + cls = out.last_hidden_state[:, 0, :] + else: + cls = out[0][:, 0, :] + cls = F.normalize(cls, p=2, dim=-1) + embs.append(cls.detach().cpu()) + return torch.cat(embs, dim=0) if embs else torch.empty(0, 768) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True, help='idefics2 | qwen-vl | qwen3-vl | gemma3') + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True, help='Path to a text encoder (e.g., roberta-large)') + ap.add_argument('--output-dir', default='runs/m3it_textual_retriever_vqa') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.2) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=32) + ap.add_argument('--split', type=str, default='test') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--use-paper-instruction', action='store_true') + ap.add_argument('--no-instruction', action='store_true') + ap.add_argument('--auto-detect', action='store_true') + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + ap.add_argument('--instruction-image', type=str, default=None) + ap.add_argument('--dump-first', type=int, default=0) + args = ap.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + cache_dir = out_dir / '_image_cache' + feat_cache = out_dir / '_feat_cache'; feat_cache.mkdir(parents=True, exist_ok=True) + + # Init retriever encoder + txt_tok = AutoTokenizer.from_pretrained(args.retriever_model_path, trust_remote_code=True) + txt_model = AutoModel.from_pretrained(args.retriever_model_path, trust_remote_code=True).to(device).eval() + + # Init generator + adapter = load_adapter(args.adapter, args.model_path) + + # Select tasks present + tasks = VQA_SUBTASKS + if args.auto_detect: + root = Path(args.dataset_root) + present = [] + for t in VQA_SUBTASKS: + if any(p.exists() for p in _cand_splits(root, t, args.split)): + present.append(t) + if present: + tasks = present + + def _quota(total: int, n: int) -> List[int]: + base = total // n; rem = total % n + return [base + (1 if i < rem else 0) for i in range(n)] + per_task = _quota(args.total_samples, len(tasks)) + + all_preds: List[str] = [] + all_refs: List[List[str]] = [] + details: List[Dict] = [] + + PAPER_VQA_INTRO = 'Examine the image and answer the question.' + rng = random.Random(args.seed) + + for subdir, quota in zip(tasks, per_task): + # Build demo pool and its embeddings; never draw demos from the eval split + demo_prefer = _prefer_demo_splits(args.split) + recs, imgs = load_pool_items(Path(args.dataset_root), subdir, cache_dir, prefer=demo_prefer) + texts = [r['text_in'] for r in recs] + feat_path = feat_cache / f'textual_{subdir.replace("/","_")}_pool.pt' + if feat_path.exists(): + pool_emb = torch.load(feat_path, map_location='cpu') + else: + pool_emb = encode_texts(txt_tok, txt_model, texts, device) + torch.save(pool_emb, feat_path) + + eval_pool = [s for s in iter_m3it_samples(args.dataset_root, subdir, split=args.split, cache_dir=str(cache_dir))] + if not eval_pool: + print(f'Skipping {subdir}: empty eval pool') + continue + select_n = min(quota, len(eval_pool)) + sel_indices = rng.sample(range(len(eval_pool)), k=select_n) + taken = 0 + for _idx in sel_indices: + smp = eval_pool[_idx] + q_emb = encode_texts(txt_tok, txt_model, [smp.text or ''], device) + sim = (q_emb @ pool_emb.T).squeeze(0) + q_uid = _extract_uid(smp.raw, '') if isinstance(smp.raw, dict) else '' + q_sig = _img_sig_from_rec(smp.raw) if isinstance(smp.raw, dict) else '' + q_psig = _image_path_sig(smp.image_path) + q_txt = (smp.text or '').strip().lower() + mask = torch.ones(sim.shape[0], dtype=torch.bool, device=sim.device) + for i, r in enumerate(recs): + if q_uid and r.get('uid') == q_uid: + mask[i] = False; continue + if q_sig and r.get('img_sig') and r['img_sig'] == q_sig: + mask[i] = False; continue + if q_txt and (r.get('text_in') or '').strip().lower() == q_txt: + mask[i] = False; continue + mask &= (sim < 0.999).to(dtype=mask.dtype, device=mask.device).bool() + sim[~mask] = -1e4 + pre_k = min(max(args.k_shots * 50, 500), sim.numel()) + cand = [i for i in torch.topk(sim, k=pre_k).indices.tolist() if mask[i].item()] + if q_psig: + cand = [i for i in cand if _image_path_sig(imgs[i]) != q_psig] + idxs = cand[:args.k_shots] + if len(idxs) < min(args.k_shots, sim.numel()): + rest = [i for i in range(sim.numel()) if mask[i].item() and i not in idxs] + if q_psig: + rest = [i for i in rest if _image_path_sig(imgs[i]) != q_psig] + rng.shuffle(rest) + idxs.extend(rest[:max(0, args.k_shots - len(idxs))]) + demos = [recs[i] for i in idxs] + demo_imgs = [imgs[i] for i in idxs] + + insts = [] if args.no_instruction else load_instructions(Path(args.dataset_root), subdir) + ds_inst = '' + if isinstance(insts, list) and insts: + ds_inst = '\n'.join([s for s in insts if isinstance(s, str) and s.strip()]) + base_inst = 'Examine the image and answer the question with a short answer.' + inst = (PAPER_VQA_INTRO + (('\n' + ds_inst) if ds_inst else '')) if args.use_paper_instruction else (ds_inst or base_inst) + + oa_items: List[Dict] = [] + if inst: + oa_items.append({'type': 'text', 'text': inst}) + if args.instruction_image: + oa_items.append({'type': 'image_url', 'image_url': args.instruction_image}) + for d, img in zip(demos, demo_imgs): + oa_items.append({'type': 'image_url', 'image_url': img}) + oa_items.append({'type': 'text', 'text': f"[REQUEST]\n{(d.get('text_in') or '').strip()}\n[RESPONSE]\n{(d.get('text_out') or '').strip()}"}) + oa_items.append({'type': 'image_url', 'image_url': smp.image_path}) + oa_items.append({'type': 'text', 'text': f"[REQUEST]\n{(smp.text or '').strip()}\n[RESPONSE]"}) + segs = openai_to_list_format(oa_items, cache_dir=cache_dir / '_oa_cache') + + # Generate + response = adapter.generate_from_segments( + segs, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + ) + + all_preds.append(response) + all_refs.append(smp.answers or []) + details.append({'task': subdir, 'image_path': smp.image_path, 'text': smp.text, 'answers': smp.answers, 'demo_uids': [d['uid'] for d in demos]}) + taken += 1 + print(f'{subdir}: {taken} eval samples | demo-pool={len(recs)} | k={args.k_shots}') + + tf1 = token_f1(all_preds, all_refs) + try: + bsf1 = bertscore_f1( + all_preds, + all_refs, + model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + lang=(args.bertscore_lang or None), + ) + except Exception: + bsf1 = None + + out = { + 'setting': 'few-shot-textual-retriever', + 'k_shots': args.k_shots, + 'adapter': args.adapter, + 'model_path': args.model_path, + 'total': len(all_preds), + 'metrics': {'token_f1': tf1, 'bertscore_f1': bsf1}, + 'predictions': [ + {'pred': p, 'answers': r, 'meta': m} + for p, r, m in zip(all_preds, all_refs, details) + ] + } + (out_dir / f'vqa_textual_retriever_{args.k_shots}shot.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'Text Retriever (k={args.k_shots}) Token-F1={tf1:.2f} BERTScore-F1=' + (f'{bsf1:.2f}' if bsf1 is not None else 'NA')) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/eval_visual_retriever_vqa.py b/ICL/LV/code/core/eval/eval_visual_retriever_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..ae66d5477ba23f0a091c805643c3b4063ecbe07b --- /dev/null +++ b/ICL/LV/code/core/eval/eval_visual_retriever_vqa.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +import argparse +import json +import hashlib +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import torch +from transformers import CLIPProcessor, CLIPModel +from PIL import Image +import numpy as np +import random + +from core.datasets.m3it_reader import ( + iter_m3it_samples, + load_instructions, + read_jsonl, +) +from core.datasets.m3it_reader import _candidate_split_files as _cand_splits +from core.datasets.m3it_reader import _IMAGE_KEYS, _TEXT_IN_KEYS, _TEXT_OUT_KEYS, _b64_to_image_path +from core.metrics.metrics import token_f1, bertscore_f1 +from core.prompting.openai_segments import openai_to_list_format + + +VQA_SUBTASKS = [ + 'vqa/vqav2', 'vqa/docvqa', 'vqa/ocr-vqa', 'vqa/st-vqa', 'vqa/text-vqa', 'vqa/gqa', 'vqa/okvqa', 'vqa/a-okvqa', +] + + +def _extract_uid(rec: Dict, fallback: str) -> str: + for k in ('id', 'image_id'): + v = rec.get(k) + if isinstance(v, (str, int)): + return str(v) + mv = rec.get('meta', {}) if isinstance(rec.get('meta'), dict) else {} + for k in ('img_id', 'id', 'image_id'): + v = mv.get(k) + if isinstance(v, (str, int)): + return str(v) + return fallback + + +def _prefer_demo_splits(eval_split: str) -> Tuple[str, ...]: + base = ('val', 'validation', 'train', 'dev') + es = (eval_split or '').lower() + exclude = {es} + if es in {'val', 'validation'}: + exclude |= {'val', 'validation'} + return tuple(s for s in base if s not in exclude) + + +def _img_sig_from_rec(rec: Dict) -> str: + try: + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + return hashlib.sha1(v.encode('utf-8')).hexdigest() + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + return hashlib.sha1(mv[k].encode('utf-8')).hexdigest() + except Exception: + pass + return '' + + +_IMG_PSIG_CACHE: Dict[Tuple[str, int, int], str] = {} + + +def _image_path_sig(path: str, size: int = 32) -> str: + try: + st = Path(path).stat() + key = (path, int(st.st_mtime), int(st.st_size)) + if key in _IMG_PSIG_CACHE: + return _IMG_PSIG_CACHE[key] + im = Image.open(path).convert('RGB').resize((size, size), resample=Image.BICUBIC) + arr = np.asarray(im, dtype=np.uint8) + h = hashlib.sha1(arr.tobytes()).hexdigest() + _IMG_PSIG_CACHE[key] = h + return h + except Exception: + return '' + + +def load_pool_items(dataset_root: Path, subdir: str, cache_dir: Path, max_items: Optional[int] = None, + prefer: Optional[Sequence[str]] = None) -> Tuple[List[Dict], List[str]]: + path = None + for split in (tuple(prefer) if prefer is not None else ('val', 'validation', 'train', 'dev')): + for p in _cand_splits(dataset_root, subdir, split): + if p.exists(): + path = p; break + if path is not None: + break + if path is None: + raise FileNotFoundError(f'No demo split file found for {subdir}') + out_recs: List[Dict] = [] + img_paths: List[str] = [] + demo_cache = cache_dir / f'_demo_{subdir.replace("/", "_")}' + demo_cache.mkdir(parents=True, exist_ok=True) + idx = 0 + for rec in read_jsonl(path): + # image b64 + img_b64 = None + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + img_b64 = v; break + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + img_b64 = mv[k]; break + if not img_b64: + continue + # text in/out + text_in = '' + for k in _TEXT_IN_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + text_in = v; break + answers = None + for k in _TEXT_OUT_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + answers = [v]; break + if isinstance(v, list) and all(isinstance(x, str) for x in v): + answers = list(v); break + uid = _extract_uid(rec, f'{idx:08d}') + img_path = _b64_to_image_path(img_b64, demo_cache, uid) + img_sig = _img_sig_from_rec(rec) + out_recs.append({'uid': uid, 'text_in': text_in or '', 'text_out': (answers or [''])[0], 'img_sig': img_sig}) + img_paths.append(img_path) + idx += 1 + if max_items is not None and len(out_recs) >= max_items: + break + return out_recs, img_paths + + +def load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2','idefics','i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl','qwenvl','qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl','qwen3vl','qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3','gemma-3','gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + +def encode_images(clip_proc, clip_model, image_paths: List[str], device: torch.device, batch_size: int = 32) -> torch.Tensor: + embs: List[torch.Tensor] = [] # type: ignore[name-defined] + clip_model.eval() + with torch.no_grad(): + for i in range(0, len(image_paths), batch_size): + chunk = image_paths[i:i+batch_size] + pil_imgs = [] + for p in chunk: + img = Image.open(p) + if img.mode != 'RGB': + img = img.convert('RGB') + pil_imgs.append(img) + inputs = clip_proc(images=pil_imgs, return_tensors='pt') + inputs = {k: v.to(device) for k, v in inputs.items()} + feats = clip_model.get_image_features(**inputs) + feats = torch.nn.functional.normalize(feats, p=2, dim=-1) + embs.append(feats.detach().cpu()) + return torch.cat(embs, dim=0) if embs else torch.empty(0, 768) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True, help='idefics2 | qwen-vl | qwen3-vl | gemma3') + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--retriever-model-path', required=True, help='Path to CLIP model (e.g., ViT-L-14)') + ap.add_argument('--output-dir', default='runs/m3it_visual_retriever_vqa') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=3) + ap.add_argument('--temperature', type=float, default=0.2) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=32) + ap.add_argument('--split', type=str, default='test') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--use-paper-instruction', action='store_true') + ap.add_argument('--no-instruction', action='store_true') + ap.add_argument('--auto-detect', action='store_true') + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + ap.add_argument('--instruction-image', type=str, default=None) + args = ap.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + cache_dir = out_dir / '_image_cache' + feat_cache = out_dir / '_feat_cache'; feat_cache.mkdir(parents=True, exist_ok=True) + + # Init retriever encoder + clip_model = CLIPModel.from_pretrained(args.retriever_model_path, trust_remote_code=True).to(device).eval() + clip_proc = CLIPProcessor.from_pretrained(args.retriever_model_path, trust_remote_code=True) + + # Init generator + adapter = load_adapter(args.adapter, args.model_path) + + tasks = VQA_SUBTASKS + if args.auto_detect: + root = Path(args.dataset_root) + present = [t for t in VQA_SUBTASKS if any(p.exists() for p in _cand_splits(root, t, args.split))] + if present: + tasks = present + + def _quota(total: int, n: int) -> List[int]: + base = total // n; rem = total % n + return [base + (1 if i < rem else 0) for i in range(n)] + per_task = _quota(args.total_samples, len(tasks)) + + all_preds: List[str] = [] + all_refs: List[List[str]] = [] + details: List[Dict] = [] + + PAPER_VQA_INTRO = 'Examine the image and answer the question.' + rng = random.Random(args.seed) + + for subdir, quota in zip(tasks, per_task): + demo_prefer = _prefer_demo_splits(args.split) + recs, imgs = load_pool_items(Path(args.dataset_root), subdir, cache_dir, prefer=demo_prefer) + feat_path = feat_cache / f'visual_{subdir.replace("/","_")}_pool.pt' + if feat_path.exists(): + pool_emb = torch.load(feat_path, map_location='cpu') + else: + pool_emb = encode_images(clip_proc, clip_model, imgs, device) + torch.save(pool_emb, feat_path) + + eval_pool = [s for s in iter_m3it_samples(args.dataset_root, subdir, split=args.split, cache_dir=str(cache_dir))] + if not eval_pool: + print(f'Skipping {subdir}: empty eval pool') + continue + select_n = min(quota, len(eval_pool)) + sel_indices = rng.sample(range(len(eval_pool)), k=select_n) + taken = 0 + for _idx in sel_indices: + smp = eval_pool[_idx] + q_emb = encode_images(clip_proc, clip_model, [smp.image_path], device) + sim = (q_emb @ pool_emb.T).squeeze(0) + q_uid = _extract_uid(smp.raw, '') if isinstance(smp.raw, dict) else '' + q_sig = _img_sig_from_rec(smp.raw) if isinstance(smp.raw, dict) else '' + q_psig = _image_path_sig(smp.image_path) + q_txt = (smp.text or '').strip().lower() + mask = torch.ones(sim.shape[0], dtype=torch.bool, device=sim.device) + for i, r in enumerate(recs): + if q_uid and r.get('uid') == q_uid: + mask[i] = False; continue + if q_sig and r.get('img_sig') and r['img_sig'] == q_sig: + mask[i] = False; continue + if q_txt and (r.get('text_in') or '').strip().lower() == q_txt: + mask[i] = False; continue + mask &= (sim < 0.999).to(dtype=mask.dtype, device=mask.device).bool() + sim[~mask] = -1e4 + pre_k = min(max(args.k_shots * 50, 500), sim.numel()) + cand = [i for i in torch.topk(sim, k=pre_k).indices.tolist() if mask[i].item()] + if q_psig: + cand = [i for i in cand if _image_path_sig(imgs[i]) != q_psig] + idxs = cand[:args.k_shots] + if len(idxs) < min(args.k_shots, sim.numel()): + rest = [i for i in range(sim.numel()) if mask[i].item() and i not in idxs] + if q_psig: + rest = [i for i in rest if _image_path_sig(imgs[i]) != q_psig] + rng.shuffle(rest) + idxs.extend(rest[:max(0, args.k_shots - len(idxs))]) + demos = [recs[i] for i in idxs] + demo_imgs = [imgs[i] for i in idxs] + + insts = [] if args.no_instruction else load_instructions(Path(args.dataset_root), subdir) + ds_inst = '' + if isinstance(insts, list) and insts: + ds_inst = '\n'.join([s for s in insts if isinstance(s, str) and s.strip()]) + base_inst = 'Examine the image and answer the question with a short answer.' + inst = (PAPER_VQA_INTRO + (('\n' + ds_inst) if ds_inst else '')) if args.use_paper_instruction else (ds_inst or base_inst) + + oa_items: List[Dict] = [] + if inst: + oa_items.append({'type': 'text', 'text': inst}) + if args.instruction_image: + oa_items.append({'type': 'image_url', 'image_url': args.instruction_image}) + for d, img in zip(demos, demo_imgs): + oa_items.append({'type': 'image_url', 'image_url': img}) + oa_items.append({'type': 'text', 'text': f"[REQUEST]\n{(d.get('text_in') or '').strip()}\n[RESPONSE]\n{(d.get('text_out') or '').strip()}"}) + oa_items.append({'type': 'image_url', 'image_url': smp.image_path}) + oa_items.append({'type': 'text', 'text': f"[REQUEST]\n{(smp.text or '').strip()}\n[RESPONSE]"}) + segs = openai_to_list_format(oa_items, cache_dir=cache_dir / '_oa_cache') + + response = adapter.generate_from_segments( + segs, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + ) + + all_preds.append(response) + all_refs.append(smp.answers or []) + details.append({'task': subdir, 'image_path': smp.image_path, 'text': smp.text, 'answers': smp.answers, 'demo_uids': [d['uid'] for d in demos]}) + taken += 1 + print(f'{subdir}: {taken} eval samples | demo-pool={len(recs)} | k={args.k_shots}') + + tf1 = token_f1(all_preds, all_refs) + try: + bsf1 = bertscore_f1( + all_preds, + all_refs, + model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + lang=(args.bertscore_lang or None), + ) + except Exception: + bsf1 = None + + out = { + 'setting': 'few-shot-visual-retriever', + 'k_shots': args.k_shots, + 'adapter': args.adapter, + 'model_path': args.model_path, + 'total': len(all_preds), + 'metrics': {'token_f1': tf1, 'bertscore_f1': bsf1}, + 'predictions': [ + {'pred': p, 'answers': r, 'meta': m} + for p, r, m in zip(all_preds, all_refs, details) + ] + } + (out_dir / f'vqa_visual_retriever_{args.k_shots}shot.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'Visual Retriever (k={args.k_shots}) Token-F1={tf1:.2f} BERTScore-F1=' + (f'{bsf1:.2f}' if bsf1 is not None else 'NA')) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/order_eval_core.py b/ICL/LV/code/core/eval/order_eval_core.py new file mode 100644 index 0000000000000000000000000000000000000000..56460140361e8fad7307d7595d0f1e2b4b303c5b --- /dev/null +++ b/ICL/LV/code/core/eval/order_eval_core.py @@ -0,0 +1,822 @@ +""" +Adapter-agnostic core for modal-order evaluation. + +Produces predictions and references per order/category and caches +per-sample details under /_cache/__.jsonl +so metric scripts can consume them independently. +""" + +from __future__ import annotations + +import json +import os +import random +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import torch + +from core.datasets.m3it_reader import ( + iter_m3it_samples, + load_instructions, + read_jsonl, +) +from core.datasets.m3it_reader import _candidate_split_files as _cand_splits +from core.datasets.m3it_reader import _IMAGE_KEYS, _TEXT_IN_KEYS, _TEXT_OUT_KEYS, _b64_to_image_path +from core.datasets.m3it_reader import _schema_answer_key_candidates as _schema_ans_keys +from core.datasets.m3it_reader import _get_dotted as _get_dot +from core.eval._modal_order import build_image_text, build_text_image, build_text_image_text +from transformers import AutoProcessor, AutoModel + + +def _build_prompts() -> Dict[str, str]: + return { + # Keep these short and model-agnostic; dataset-specific instructions.json + # (when present) will be appended by run_predictions(). + 'captioning': 'Describe the image briefly.', + 'vqa': 'Look at the image and answer the question concisely.', + 'classification': ( + 'Look at the image and answer the classification. ' + 'If [REQUEST] lists options, answer with only the option letter (A/B/C/D/...). ' + 'Otherwise, answer concisely with the class name or ID.' + ), + 'reasoning': ( + 'Look at the image and the text. Reason step by step and conclude with a clear answer. ' + 'If options are provided in [REQUEST], end with only the option letter.' + ), + } + + +def _prefer_demo_splits(eval_split: str) -> Tuple[str, ...]: + base = ('val', 'validation', 'train', 'dev') + es = (eval_split or '').lower() + # Exclude the eval split from the demo pool preference + exclude = {es} + if es in {'val', 'validation'}: + exclude |= {'val', 'validation'} + return tuple(s for s in base if s not in exclude) + + +def _extract_uid(rec: Dict, fallback: str) -> str: + for k in ('id', 'image_id'): + v = rec.get(k) + if isinstance(v, (str, int)): + return str(v) + mv = rec.get('meta', {}) if isinstance(rec.get('meta'), dict) else {} + for k in ('img_id', 'id', 'image_id', 'Flickr30k_image_id', 'pair_id'): + v = mv.get(k) + if isinstance(v, (str, int)): + return str(v) + return fallback + + +def _img_sig_from_rec(rec: Dict) -> str: + import hashlib + try: + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + return hashlib.sha1(v.encode('utf-8')).hexdigest() + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + return hashlib.sha1(mv[k].encode('utf-8')).hexdigest() + except Exception: + pass + return '' + + +_IMG_PSIG_CACHE: Dict[Tuple[str, int, int], str] = {} + + +def _image_path_sig(path: str, size: int = 32) -> str: + import hashlib + from PIL import Image + try: + st = os.stat(path) + key = (path, int(st.st_mtime), int(st.st_size)) + if key in _IMG_PSIG_CACHE: + return _IMG_PSIG_CACHE[key] + im = Image.open(path).convert('RGB').resize((size, size), resample=Image.BICUBIC) + import numpy as np + arr = np.asarray(im, dtype=np.uint8) + h = hashlib.sha1(arr.tobytes()).hexdigest() + _IMG_PSIG_CACHE[key] = h + return h + except Exception: + return '' + + +def load_pool_items(dataset_root: Path, subdir: str, cache_dir: Path, max_items: Optional[int] = None, + prefer: Optional[Sequence[str]] = None, + category: Optional[str] = None) -> Tuple[List[Dict], List[str]]: + path = None + for split in (tuple(prefer) if prefer is not None else ('val', 'validation', 'train', 'dev')): + for p in _cand_splits(dataset_root, subdir, split): + if p.exists(): + path = p; break + if path is not None: + break + if path is None: + raise FileNotFoundError(f'No demo split file found for {subdir}') + out_recs: List[Dict] = [] + img_paths: List[str] = [] + demo_cache = cache_dir / f'_demo_{subdir.replace("/", "_")}' + demo_cache.mkdir(parents=True, exist_ok=True) + idx = 0 + for rec in read_jsonl(path): + # image b64 + img_b64 = None + def _join_if_list(v): + if isinstance(v, list) and v and all(isinstance(x, str) for x in v): + return ''.join(v) + return v + for k in _IMAGE_KEYS: + v = _join_if_list(rec.get(k)) + if isinstance(v, str) and len(v) > 100: + img_b64 = v; break + mv = rec.get('meta') if isinstance(rec.get('meta'), dict) else None + if mv: + vv = _join_if_list(mv.get(k)) + if isinstance(vv, str) and len(vv) > 100: + img_b64 = vv; break + if not img_b64: + continue + # text in/out + text_in = '' + for k in _TEXT_IN_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + text_in = v; break + if (not text_in.strip()) and (category in ('classification', 'reasoning')): + try: + syn = _extract_inputs(rec) + if isinstance(syn, str) and syn.strip(): + text_in = syn + except Exception: + pass + answers = None + # Prefer schema-specified answer keys for strict alignment + for dk in _schema_ans_keys(subdir, split='val'): + vv = _get_dot(rec, dk) + if isinstance(vv, str): + answers = [vv]; break + if isinstance(vv, list) and all(isinstance(x, str) for x in vv): + answers = list(vv); break + if answers is None: + for k in _TEXT_OUT_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v, str): + answers = [v]; break + if isinstance(v, list) and all(isinstance(x, str) for x in v): + answers = list(v); break + uid = _extract_uid(rec, f'{idx:08d}') + img_path = _b64_to_image_path(img_b64, demo_cache, uid) + img_sig = _img_sig_from_rec(rec) + out_recs.append({'uid': uid, 'text_in': text_in or '', 'text_out': (answers or [''])[0], 'img_sig': img_sig}) + img_paths.append(img_path) + idx += 1 + if max_items is not None and len(out_recs) >= max_items: + break + return out_recs, img_paths + + +def _detect_tasks(root: Path, prefix: str) -> List[str]: + base = root / 'data' / prefix + out: List[str] = [] + if not base.exists(): + return out + for sub in base.iterdir(): + if not sub.is_dir(): + continue + for sp in ('test.jsonl', 'val.jsonl', 'validation.jsonl', 'test.json', 'val.json', 'validation.json'): + if (sub / sp).exists(): + out.append(f'{prefix}/{sub.name}') + break + return out + + +def _load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2', 'idefics', 'i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl', 'qwenvl', 'qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl', 'qwen3vl', 'qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3', 'gemma-3', 'gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + + +def _encode_pairs(processor, model, image_paths: List[str], texts: List[str], device: torch.device, batch_size: int = 8) -> torch.Tensor: + """Generic multimodal pair encoding using AutoProcessor/AutoModel. + - Aligns image resize/crop to BridgeTower's expected grid; falls back to manual preproc on mismatch. + """ + import numpy as np + from PIL import Image + import torch.nn.functional as F + + def _infer_bt_target_crop(m) -> int: + """Infer the expected square crop for BridgeTower's ViT. + Defaults to 336 if probing fails. + """ + default = 336 + try: + vis = getattr(m, 'vision_model', None) + visual = getattr(vis, 'visual', vis) + emb_mod = getattr(getattr(visual, 'embeddings', visual), 'position_embedding', None) + if emb_mod is None: + return default + pos_len = int(getattr(emb_mod, 'num_embeddings', 0) or getattr(getattr(emb_mod, 'weight', None), 'shape', [0])[0]) + if pos_len <= 1: + return default + grid_tokens = max(1, pos_len - 1) + grid_n = int(grid_tokens ** 0.5) + if grid_n * grid_n != grid_tokens: + return default + patch_sz = None + try: + patch_proj = getattr(getattr(getattr(visual, 'embeddings', visual), 'patch_embedding', None), 'projection', None) + if patch_proj is not None and hasattr(patch_proj, 'weight') and getattr(patch_proj.weight, 'ndim', 0) >= 4: + patch_sz = int(patch_proj.weight.shape[-1]) + except Exception: + patch_sz = None + if not isinstance(patch_sz, int) or patch_sz <= 0: + cfg_patch = getattr(getattr(m, 'config', None), 'vision_config', None) + cfg_patch = getattr(cfg_patch, 'patch_size', None) if cfg_patch is not None else None + if isinstance(cfg_patch, int) and cfg_patch > 0: + patch_sz = int(cfg_patch) + else: + patch_sz = 14 + return int(grid_n * int(patch_sz)) + except Exception: + return default + + def _safe_image_args_for_processor(proc, m=None): + def _clamp_pos(v, default): + if v is None: + return default + try: + v = int(v) + except Exception: + return default + return max(1, v) + ip = getattr(proc, 'image_processor', None) + target = _infer_bt_target_crop(m) if m is not None else 336 + raw_size = getattr(ip, 'size', None) if ip is not None else None + raw_crop = getattr(ip, 'crop_size', None) if ip is not None else None + crop_h = crop_w = None + if isinstance(raw_crop, dict): + crop_h = raw_crop.get('shortest_edge') or raw_crop.get('height') or raw_crop.get('width') + crop_w = crop_h + elif isinstance(raw_crop, (int, float)): + crop_h = crop_w = int(raw_crop) + crop_h = _clamp_pos(crop_h, target) + crop_w = _clamp_pos(crop_w, crop_h) + se = max(crop_h, crop_w) + size_se = None + if isinstance(raw_size, dict) and 'shortest_edge' in raw_size: + size_se = _clamp_pos(raw_size.get('shortest_edge'), se) + size = {'shortest_edge': (size_se if size_se is not None else se)} + # BridgeTower expects resize/crop sizes as 'shortest_edge'; + # also pin channels_last to avoid 1x1x3 ambiguity when H=W=1. + return { + 'do_resize': True, + 'size': {'shortest_edge': se}, + 'do_center_crop': True, + 'crop_size': {'shortest_edge': se}, + 'input_data_format': 'channels_last', + } + + def _manual_image_preprocess(proc, pil_imgs: List[Image.Image], target: int) -> Dict[str, torch.Tensor]: + means = getattr(getattr(processor, 'image_processor', processor), 'image_mean', [0.48145466, 0.4578275, 0.40821073]) + stds = getattr(getattr(processor, 'image_processor', processor), 'image_std', [0.26862954, 0.26130258, 0.27577711]) + means = np.array(means, dtype=np.float32).reshape(3, 1, 1) + stds = np.array(stds, dtype=np.float32).reshape(3, 1, 1) + batch = [] + for im in pil_imgs: + im = im.convert('RGB') + w, h = im.size + if w <= 0 or h <= 0: + w = h = max(1, int(target)) + im = im.resize((w, h), resample=Image.BICUBIC) + se = min(w, h) + scale = float(target) / float(max(1, se)) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + im = im.resize((new_w, new_h), resample=Image.BICUBIC) + left = max(0, (im.width - target) // 2) + top = max(0, (im.height - target) // 2) + right = left + target + bottom = top + target + im = im.crop((left, top, right, bottom)) + arr = np.asarray(im).astype(np.float32) / 255.0 + arr = arr.transpose(2, 0, 1) + arr = (arr - means) / stds + batch.append(arr) + x = np.stack(batch, axis=0) + return {'pixel_values': torch.from_numpy(x)} + + embs: List[torch.Tensor] = [] + model.eval() + with torch.no_grad(): + for i in range(0, len(texts), batch_size): + chunk_t = texts[i:i+batch_size] + chunk_i = image_paths[i:i+batch_size] + pil_imgs = [] + for p in chunk_i: + img = Image.open(p) + if img.mode != 'RGB': + img = img.convert('RGB') + pil_imgs.append(img) + try: + tok = processor.tokenizer(text=chunk_t, padding=True, truncation=True, max_length=256, return_tensors='pt') + img_args = _safe_image_args_for_processor(processor, model) + img = processor.image_processor(images=pil_imgs, return_tensors='pt', **img_args) + inputs = {**tok, **img} + except Exception as e: + # Fallback path for edge cases: tiny images (1x1), ambiguous channel dim, or other preproc errors + em = str(e).lower() + if ('height and width must be > 0' in em) or ('channel dimension is ambiguous' in em) or ('mean must have' in em): + target = _infer_bt_target_crop(model) + img = _manual_image_preprocess(processor, pil_imgs, target) + tok = processor.tokenizer(text=chunk_t, padding=True, truncation=True, max_length=256, return_tensors='pt') + inputs = {**tok, **img} + else: + raise + inputs = {k: v.to(device) for k, v in inputs.items()} + try: + out = model(**inputs) + except RuntimeError as re: + msg = str(re) + if ('must match the size of tensor' in msg) or ('The size of tensor a' in msg) or ('position_embedding' in msg) or ('position ids' in msg): + target = _infer_bt_target_crop(model) + img = _manual_image_preprocess(processor, pil_imgs, target) + new_inputs = {**tok, **img} + new_inputs = {k: v.to(device) for k, v in new_inputs.items()} + out = model(**new_inputs) + else: + raise + if hasattr(out, 'pooler_output') and out.pooler_output is not None: + pooled = out.pooler_output + elif isinstance(out, dict) and 'pooler_output' in out: + pooled = out['pooler_output'] + else: + last = out.last_hidden_state if hasattr(out, 'last_hidden_state') else (list(out.values())[0] if isinstance(out, dict) else out[0]) + pooled = last.mean(dim=1) + pooled = F.normalize(pooled, p=2, dim=-1) + embs.append(pooled.detach().cpu()) + return torch.cat(embs, dim=0) if embs else torch.empty(0, 1024) + + +def run_predictions( + adapter: str, + model_path: str, + dataset_root: str, + retriever_model_path: str, + output_dir: str, + orders: Sequence[str], # 'image-text' | 'text-image' | 'text-image-text' + categories: Sequence[str], # subset of: captioning,vqa,classification,reasoning + total_samples: int = 800, + k_shots: int = 3, + split: str = 'test', + seed: int = 0, + temperature: float = 0.6, + top_p: float = 1.0, + max_new_tokens: int = 32, + auto_detect: bool = True, + reuse_cache: bool = True, +) -> Dict[str, Dict[str, Tuple[List[str], List[List[str]]]]]: + """Return predictions and references as a nested dict: + {order: {category: (preds, refs)}} and write cache files under + /_cache/__.jsonl. + """ + + out_dir = Path(output_dir); out_dir.mkdir(parents=True, exist_ok=True) + cache_dir = out_dir / '_image_cache' + feat_cache = out_dir / '_feat_cache'; feat_cache.mkdir(parents=True, exist_ok=True) + pred_cache = out_dir / '_cache'; pred_cache.mkdir(parents=True, exist_ok=True) + + # Normalize orders/categories up front for potential cache reuse + orders = [o.strip().lower() for o in orders if str(o).strip()] + cats = [c.strip().lower() for c in categories if str(c).strip()] + + # Fast path: if reuse_cache and all required cache files exist, load and return + if reuse_cache: + all_exist = True + for order_name in orders: + for cat in cats: + if not (pred_cache / f'{cat}__{order_name}.jsonl').exists(): + all_exist = False + break + if not all_exist: + break + if all_exist: + results: Dict[str, Dict[str, Tuple[List[str], List[List[str]]]]] = {} + for order_name in orders: + order_out: Dict[str, Tuple[List[str], List[List[str]]]] = {} + for cat in cats: + cf = pred_cache / f'{cat}__{order_name}.jsonl' + preds: List[str] = [] + refs: List[List[str]] = [] + for line in cf.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + obj = json.loads(line) + preds.append(obj.get('pred', '')) + refs.append(obj.get('answers') or []) + order_out[cat] = (preds, refs) + results[order_name] = order_out + return results + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Normalize retriever_model_path to support Windows-style paths on Linux + def _normalize_model_path(p: str) -> str: + if not isinstance(p, str): + return p + q = p.strip() + # If directory exists as-is, keep + try: + if q and os.path.isdir(q): + return q + except Exception: + pass + # Best-effort Windows -> WSL path (e.g., E:/foo or E:\\foo -> /mnt/e/foo) + if len(q) >= 2 and q[1] == ':' and q[0].isalpha(): + drive = q[0].lower() + rest = q[2:].replace('\\', '/').lstrip('/') + cand = f"/mnt/{drive}/{rest}" + try: + if os.path.isdir(cand): + return cand + except Exception: + pass + return q + + retriever_model_path = _normalize_model_path(retriever_model_path) + processor = AutoProcessor.from_pretrained(retriever_model_path, trust_remote_code=True) + mm_model = AutoModel.from_pretrained(retriever_model_path, trust_remote_code=True).to(device).eval() + + runner = _load_adapter(adapter, model_path) + + order_map: Dict[str, Callable] = { + 'image-text': build_image_text, + 'text-image': build_text_image, + 'text-image-text': build_text_image_text, + } + prompts = _build_prompts() + + root = Path(dataset_root) + default_tasks = { + 'captioning': _detect_tasks(root, 'captioning'), + 'vqa': _detect_tasks(root, 'vqa'), + 'classification': _detect_tasks(root, 'classification'), + 'reasoning': _detect_tasks(root, 'reasoning'), + } + # Basic fallbacks if detection returns empty + default_tasks.setdefault('vqa', ['vqa/vqav2']) + + def _quota(total: int, n: int) -> List[int]: + base = total // n; rem = total % n + return [base + (1 if i < rem else 0) for i in range(n)] + + rng = random.Random(seed) + + demo_limit_env = os.getenv('DEMO_POOL_LIMIT') + demo_limit = int(demo_limit_env) if (demo_limit_env or '').strip().isdigit() else None + + # Build demo pools + demo_pools: Dict[str, Tuple[List[dict], List[str], torch.Tensor]] = {} + if k_shots > 0: + for cat in cats: + for subdir in default_tasks.get(cat, []): + demo_prefer = _prefer_demo_splits(split) + try: + recs, imgs = load_pool_items(root, subdir, cache_dir, prefer=demo_prefer, max_items=demo_limit, category=cat) + except Exception: + demo_pools[subdir] = ([], [], None) # type: ignore + continue + feat_path = feat_cache / f'multimodal_{subdir.replace("/","_")}_pool.pt' + if feat_path.exists(): + pool_emb = torch.load(feat_path, map_location='cpu') + else: + texts = [r['text_in'] for r in recs] + pool_emb = _encode_pairs(processor, mm_model, imgs, texts, device) + torch.save(pool_emb, feat_path) + demo_pools[subdir] = (recs, imgs, pool_emb) # type: ignore + + # Build eval plan + eval_plans: Dict[str, List[int]] = {} + eval_samples: Dict[str, List[Any]] = {} + per_cat = _quota(total_samples, len(cats) or 1) + for cat, quota in zip(cats, per_cat): + tasks = default_tasks.get(cat, []) + if not tasks: + continue + per_task = _quota(quota, len(tasks)) + for subdir, q in zip(tasks, per_task): + try: + pool = [s for s in iter_m3it_samples(dataset_root, subdir, split=split, cache_dir=str(cache_dir), max_samples=None)] + except Exception: + continue + if not pool: + continue + k = min(q, len(pool)) + idxs = rng.sample(range(len(pool)), k=k) + eval_plans[subdir] = idxs + eval_samples[subdir] = pool + + results: Dict[str, Dict[str, Tuple[List[str], List[List[str]]]]] = {} + + for order_name in orders: + builder = order_map.get(order_name) + if builder is None: + continue + order_out: Dict[str, Tuple[List[str], List[List[str]]]] = {} + for cat in cats: + preds: List[str] = [] + refs: List[List[str]] = [] + lines: List[str] = [] + tasks = default_tasks.get(cat, []) + for subdir in tasks: + # Compose instruction: built-in prompt + dataset-specific instructions.json if present + ds_insts = load_instructions(root, subdir) + ds_text = '' + if isinstance(ds_insts, list) and ds_insts: + ds_text = '\n'.join([s for s in ds_insts if isinstance(s, str) and s.strip()]) + inst = prompts.get(cat, '') + if ds_text: + inst = (inst + ('\n' + ds_text)).strip() if inst else ds_text + idxs = eval_plans.get(subdir, []) + recs, imgs, pool_emb = demo_pools.get(subdir, ([], [], None)) + for _idx in idxs: + smp = eval_samples[subdir][_idx] + # Select in-context demos per query + if k_shots > 0 and recs and isinstance(pool_emb, torch.Tensor) and pool_emb.numel() > 0: + q_emb = _encode_pairs(processor, mm_model, [smp.image_path], [smp.text or ''], device) + sim = (q_emb @ pool_emb.T).squeeze(0) + q_uid = _extract_uid(smp.raw, '') if isinstance(smp.raw, dict) else '' + q_sig = _img_sig_from_rec(smp.raw) if isinstance(smp.raw, dict) else '' + q_psig = _image_path_sig(smp.image_path) + q_txt = (smp.text or '').strip().lower() + mask = torch.ones(sim.shape[0], dtype=torch.bool, device=sim.device) + for i, r in enumerate(recs): + if q_uid and r.get('uid') == q_uid: + mask[i] = False; continue + if q_sig and r.get('img_sig') and r['img_sig'] == q_sig: + mask[i] = False; continue + if q_txt and (r.get('text_in') or '').strip().lower() == q_txt: + mask[i] = False; continue + mask &= (sim < 0.999).to(dtype=mask.dtype, device=mask.device).bool() + sim[~mask] = -1e4 + pre_k = min(max(k_shots * 50, 500), sim.numel()) + cand = [i for i in torch.topk(sim, k=pre_k).indices.tolist() if mask[i].item()] + if q_psig: + cand = [i for i in cand if _image_path_sig(imgs[i]) != q_psig] + idxs2 = cand[:k_shots] + if len(idxs2) < min(k_shots, sim.numel()): + rest = [i for i in range(sim.numel()) if mask[i].item() and i not in idxs2] + if q_psig: + rest = [i for i in rest if _image_path_sig(imgs[i]) != q_psig] + rng2 = random.Random(seed) + rng2.shuffle(rest) + idxs2.extend(rest[:max(0, k_shots - len(idxs2))]) + demos = [recs[i] for i in idxs2] + demo_imgs = [imgs[i] for i in idxs2] + else: + demos, demo_imgs = [], [] + + # Build query text; ensure the final user turn carries a concrete + # [REQUEST] even when the dataset provides an empty prompt (e.g., captioning). + q_text = smp.text or '' + if cat in ('classification', 'reasoning'): + # For classification/reasoning add parsed options/inputs under [REQUEST] + inp = _extract_inputs(smp.raw) + if inp: + q_text = (q_text.rstrip() + '\n' + inp).strip() + # When no per-sample text is present (common in captioning), + # reuse the dataset-specific instruction as the actual request. + if not q_text.strip(): + base_req = (inst or '').strip() + # Fallback to a minimal request if both are empty + if not base_req: + base_req = _build_prompts().get(cat, '').strip() + q_text = base_req + + segs = builder(inst, demos, demo_imgs, smp.image_path, q_text) + resp = runner.generate_from_segments( + segs, + temperature=temperature, + top_p=top_p, + max_new_tokens=max_new_tokens, + ) + + preds.append(resp) + refs.append(smp.answers or []) + lines.append(json.dumps({ + 'pred': resp, + 'answers': smp.answers or [], + 'ref_text': _extract_ref_text(smp.raw), + 'meta': { + 'task': subdir, + 'image_path': smp.image_path, + 'text': q_text, + 'inputs': _extract_inputs(smp.raw), + 'gold_choice': _extract_gold_choice(smp.raw), + 'order': order_name, + } + }, ensure_ascii=False)) + cache_file = pred_cache / f'{cat}__{order_name}.jsonl' + if lines: + cache_file.write_text('\n'.join(lines), encoding='utf-8') + order_out[cat] = (preds, refs) + results[order_name] = order_out + return results + + +def _extract_ref_text(sample_raw: Dict) -> str: + if not isinstance(sample_raw, dict): + return '' + keys = [ + 'rationale', 'rationales', 'explanation', 'explanations', 'justification', + 'reason', 'reasons', 'solution', 'solutions', + 'chain_of_thought', 'cot', 'rationale_text', 'explain', 'explanation_text', + 'final_rationale' + ] + for k in keys: + v = sample_raw.get(k) + if isinstance(v, str) and v.strip(): + return v.strip() + if isinstance(v, list): + for x in v: + if isinstance(x, str) and x.strip(): + return x.strip() + mv = sample_raw.get('meta') if isinstance(sample_raw.get('meta'), dict) else None + if isinstance(mv, dict): + for k in keys: + v = mv.get(k) + if isinstance(v, str) and v.strip(): + return v.strip() + if isinstance(v, list): + for x in v: + if isinstance(x, str) and x.strip(): + return x.strip() + for k in ('output', 'outputs'): + v = sample_raw.get(k) + if isinstance(v, list): + v = v[0] if v else '' + if isinstance(v, str) and v: + s = v.strip() + import re as _re + m = _re.search(r"\bBECAUSE\s*:\s*(.*)$", s, flags=_re.IGNORECASE | _re.DOTALL) + if m: + return m.group(1).strip() + if isinstance(mv, dict): + for k in ('output', 'outputs'): + v = mv.get(k) + if isinstance(v, list): + v = v[0] if v else '' + if isinstance(v, str) and v: + s = v.strip() + import re as _re + m = _re.search(r"\bBECAUSE\s*:\s*(.*)$", s, flags=_re.IGNORECASE | _re.DOTALL) + if m: + return m.group(1).strip() + return '' + + +def _extract_inputs(sample_raw: Dict) -> str: + import re as _re + if not isinstance(sample_raw, dict): + return '' + # Structured pairs common in SNLI-VE/VE-like datasets + parts: List[str] = [] + def _pick(k: str) -> Optional[str]: + v = sample_raw.get(k) + if isinstance(v, str) and v.strip(): + return v.strip() + mv = sample_raw.get('meta') if isinstance(sample_raw.get('meta'), dict) else None + if isinstance(mv, dict): + vv = mv.get(k) + if isinstance(vv, str) and vv.strip(): + return vv.strip() + return None + prem = _pick('premise') or _pick('caption') or _pick('sentence1') + hyp = _pick('hypothesis') or _pick('statement') or _pick('sentence2') + if prem: + parts.append(f'Premise: {prem}') + if hyp: + parts.append(f'Hypothesis: {hyp}') + # user_prompt/instruction can be appended if present + up = _pick('user_prompt') or _pick('instruction') + if up: + parts.append(up) + # If we collected structured text, keep it as the base before other inputs/options + base_struct = '\n'.join(parts).strip() + v = sample_raw.get('inputs') + if isinstance(v, str) and v.strip(): + s = v.strip() + return (base_struct + ('\n' + s if base_struct else s)) + vv = sample_raw.get('input') + if isinstance(vv, str) and vv.strip(): + s = vv.strip() + if _re.search(r"\([A-Za-z]\)|(?m)^\s*[A-Za-z][\).:]+\s+", s): + return (base_struct + ('\n' + s if base_struct else s)) + for key in ('choices', 'options'): + lst = sample_raw.get(key) + if isinstance(lst, list) and lst: + letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + lines = [] + for i, item in enumerate(lst): + if isinstance(item, str): + lines.append(f'({letters[i]}) {item}') + elif isinstance(item, dict): + t = item.get('text') or item.get('label') or item.get('value') or '' + lines.append(f'({letters[i]}) {t}') + s = '\n'.join(lines) + return (base_struct + ('\n' + s if base_struct else s)) + # candidates (ITM-style) + lst = sample_raw.get('candidates') + if isinstance(lst, list) and lst: + letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + lines = [] + for i, item in enumerate(lst): + if isinstance(item, str): + lines.append(f'({letters[i]}) {item}') + elif isinstance(item, dict): + t = item.get('text') or item.get('label') or item.get('value') or '' + lines.append(f'({letters[i]}) {t}') + s = '\n'.join(lines) + return (base_struct + ('\n' + s if base_struct else s)) + mv = sample_raw.get('meta') if isinstance(sample_raw.get('meta'), dict) else None + if isinstance(mv, dict): + vv = mv.get('inputs') + if isinstance(vv, str) and vv.strip(): + s = vv.strip() + return (base_struct + ('\n' + s if base_struct else s)) + vv2 = mv.get('input') + if isinstance(vv2, str) and vv2.strip(): + s = vv2.strip() + if _re.search(r"\([A-Za-z]\)|(?m)^\s*[A-Za-z][\).:]+\s+", s): + return (base_struct + ('\n' + s if base_struct else s)) + for key in ('choices', 'options'): + lst = mv.get(key) + if isinstance(lst, list) and lst: + letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + lines = [] + for i, item in enumerate(lst): + if isinstance(item, str): + lines.append(f'({letters[i]}) {item}') + elif isinstance(item, dict): + t = item.get('text') or item.get('label') or item.get('value') or '' + lines.append(f'({letters[i]}) {t}') + s = '\n'.join(lines) + return (base_struct + ('\n' + s if base_struct else s)) + return base_struct + + +def _extract_gold_choice(sample_raw: Dict) -> Optional[str]: + try: + mv = sample_raw.get('meta') if isinstance(sample_raw.get('meta'), dict) else None + if isinstance(mv, dict) and 'correct_choice_idx' in mv: + ci = mv.get('correct_choice_idx') + if isinstance(ci, int) and ci >= 0: + letters = 'abcdefghijklmnopqrstuvwxyz' + if ci < len(letters): + return letters[ci] + def _choices_from(obj: Dict): + if not isinstance(obj, dict): + return None + for key in ('choices', 'options'): + lst = obj.get(key) + if isinstance(lst, list) and lst: + out = [] + for it in lst: + if isinstance(it, str): + out.append(it) + elif isinstance(it, dict): + out.append(it.get('text') or it.get('label') or it.get('value') or '') + return out + return None + ch = _choices_from(sample_raw) or _choices_from(mv) + ans = None + for k in ('answer', 'label', 'target', 'output', 'outputs'): + if isinstance(sample_raw.get(k), str): + ans = sample_raw.get(k); break + if isinstance(mv, dict) and isinstance(mv.get(k), str): + ans = mv.get(k); break + if isinstance(ans, list): + ans = (ans[0] if ans else None) + if isinstance(ans, str) and ch: + a = ans.strip().lower() + for i, t in enumerate(ch): + if isinstance(t, str) and t.strip() and t.strip().lower() == a: + letters = 'abcdefghijklmnopqrstuvwxyz' + if i < len(letters): + return letters[i] + except Exception: + pass + return None diff --git a/ICL/LV/code/core/eval/random_k_shot_vqa.py b/ICL/LV/code/core/eval/random_k_shot_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..4826323ed4259331f26603d28e40805e7e631594 --- /dev/null +++ b/ICL/LV/code/core/eval/random_k_shot_vqa.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +import argparse +import json +import random +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Sequence + +from core.datasets.m3it_reader import iter_m3it_samples, load_instructions +from core.metrics.metrics import token_f1, bertscore_f1 +from core.prompting.openai_segments import openai_to_list_format + + +VQA_SUBTASKS = [ + 'vqa/vqav2','vqa/docvqa','vqa/ocr-vqa','vqa/st-vqa','vqa/text-vqa','vqa/gqa','vqa/okvqa','vqa/a-okvqa', +] + +PAPER_VQA_INTRO = 'Examine the image and answer the question. Please only provide the short answer without explanation.' + + +def distribute_quota(total: int, n: int) -> List[int]: + base = total // n; rem = total % n + return [base + (1 if i < rem else 0) for i in range(n)] + + +@dataclass +class DemoItem: + image_path: str + text_in: str + text_out: str + uid: str + + +def _extract_uid(raw: Dict, fallback: str) -> str: + for k in ('id','image_id'): + v = raw.get(k); + if isinstance(v,(str,int)): return str(v) + mv = raw.get('meta',{}) if isinstance(raw.get('meta'),dict) else {} + for k in ('img_id','id','image_id'): + v = mv.get(k) + if isinstance(v,(str,int)): return str(v) + return fallback + + +def load_demo_pool(dataset_root: Path, subdir: str, cache_dir: Path) -> List[DemoItem]: + from core.datasets.m3it_reader import read_jsonl, _candidate_split_files, _IMAGE_KEYS, _TEXT_IN_KEYS, _TEXT_OUT_KEYS, _b64_to_image_path + path = None + for s in ('val','validation','train','dev','test'): + for p in _candidate_split_files(dataset_root, subdir, s): + if p.exists(): path = p; break + if path is not None: break + if path is None: raise FileNotFoundError('No demo split') + out: List[DemoItem] = [] + cache = cache_dir / f'_demo_{subdir.replace("/","_")}' + cache.mkdir(parents=True, exist_ok=True) + idx = 0 + for rec in read_jsonl(path): + img_b64 = None + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v,str) and len(v)>100: img_b64=v; break + mv = rec.get('meta') if isinstance(rec.get('meta'),dict) else None + if mv and isinstance(mv.get(k),str) and len(mv[k])>100: img_b64=mv[k]; break + if not img_b64: continue + # text in/out + text_in = '' + for k in _TEXT_IN_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v,str): text_in=v; break + answers = None + for k in _TEXT_OUT_KEYS: + v = rec.get(k) or (rec.get('meta') or {}).get(k) + if isinstance(v,str): answers=[v]; break + if isinstance(v,list) and all(isinstance(x,str) for x in v): answers=list(v); break + img_path = _b64_to_image_path(img_b64, cache, str(rec.get('image_id') or rec.get('id') or idx)) + uid = _extract_uid(rec, f'{idx:08d}') + out.append(DemoItem(img_path, text_in or '', (answers or [''])[0], uid)) + idx += 1 + return out + + +def load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2','idefics','i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl','qwenvl','qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl','qwen3vl','qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3','gemma-3','gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True) + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--output-dir', default='runs/unified_random_k_shot_vqa') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--k-shots', type=int, default=1) + ap.add_argument('--temperature', type=float, default=0.2) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=32) + ap.add_argument('--split', type=str, default='val') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--use-paper-instruction', action='store_true') + ap.add_argument('--instruction-image', type=str, default=None) + ap.add_argument('--auto-detect', action='store_true') + ap.add_argument('--dump-first', type=int, default=0) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + args = ap.parse_args() + + out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + cache_dir = out_dir / '_image_cache' + + if args.temperature <= 0.0: + raise ValueError('temperature must be > 0 for few-shot random per setting') + + adapter = load_adapter(args.adapter, args.model_path) + + tasks = VQA_SUBTASKS + if args.auto_detect: + from core.datasets.m3it_reader import _candidate_split_files as _cand + root = Path(args.dataset_root) + present = [t for t in VQA_SUBTASKS if any(p.exists() for p in _cand(root, t, args.split))] + if present: tasks = present + + per_task = distribute_quota(args.total_samples, len(tasks)) + + all_preds: List[str] = []; all_refs: List[List[str]] = []; details: List[Dict] = [] + + rng = random.Random(args.seed) + for subdir, quota in zip(tasks, per_task): + # instruction text + insts = load_instructions(Path(args.dataset_root), subdir) + ds_inst = '' + if isinstance(insts, list) and insts: + ds_inst = '\n'.join([s for s in insts if isinstance(s, str) and s.strip()]) + inst = (PAPER_VQA_INTRO + (('\n' + ds_inst) if ds_inst else '')) if args.use_paper_instruction else (ds_inst or 'Examine the image and answer the question with a short answer.') + + try: + demo_pool = load_demo_pool(Path(args.dataset_root), subdir, cache_dir) + except FileNotFoundError: + print(f'Skipping {subdir}: no demo split') + continue + if not demo_pool: + print(f'Skipping {subdir}: empty demo pool') + continue + eval_pool = [s for s in iter_m3it_samples(args.dataset_root, subdir, split=args.split, cache_dir=str(cache_dir))] + if not eval_pool: + print(f'Skipping {subdir}: empty eval pool') + continue + select_n = min(quota, len(eval_pool)) + sel_indices = rng.sample(range(len(eval_pool)), k=select_n) + + taken = 0 + for _idx in sel_indices: + smp = eval_pool[_idx] + # sample demos (exclude same uid when available) + q_uid = _extract_uid(smp.raw, '') if isinstance(smp.raw, dict) else '' + cand = [d for d in demo_pool if d.uid != q_uid] + if args.k_shots >= len(cand): + rng.shuffle(cand); demos = cand + else: + demos = rng.sample(cand, k=args.k_shots) + + oa = [] + if inst: oa.append({'type':'text','text':inst}) + if args.instruction_image: oa.append({'type':'image_url','image_url':args.instruction_image}) + for d in demos: + oa.append({'type':'image_url','image_url':d.image_path}) + oa.append({'type':'text','text':f"[REQUEST]\n{d.text_in.strip()}\n[RESPONSE]\n{d.text_out.strip()}"}) + oa.append({'type':'image_url','image_url':smp.image_path}) + oa.append({'type':'text','text':f"[REQUEST]\n{(smp.text or '').strip()}\n[RESPONSE]"}) + segs = openai_to_list_format(oa, cache_dir=cache_dir / '_oa_cache') + + pred = adapter.generate_from_segments(segs, temperature=args.temperature, top_p=args.top_p, max_new_tokens=args.max_new_tokens) + all_preds.append(pred) + all_refs.append(smp.answers or []) + details.append({'task': subdir, 'k_shots': args.k_shots, 'image_path': smp.image_path, 'text': smp.text, 'answers': smp.answers, 'demo_count': len(demos)}) + taken += 1 + + if args.dump_first and len(details) <= args.dump_first: + (out_dir / f'debug_{subdir.replace("/","_")}_{len(details)-1:04d}_openai.json').write_text(json.dumps(oa, ensure_ascii=False, indent=2), encoding='utf-8') + (out_dir / f'debug_{subdir.replace("/","_")}_{len(details)-1:04d}_list.json').write_text(json.dumps(segs, ensure_ascii=False, indent=2), encoding='utf-8') + + print(f'{subdir}: {taken} eval samples | demo-pool={len(demo_pool)} | k={args.k_shots}') + + tf1 = token_f1(all_preds, all_refs) + try: + bsf1 = bertscore_f1(all_preds, all_refs, lang=(args.bertscore_lang or None)) + except Exception: + bsf1 = None + out = {'setting':'few-shot-random','adapter':args.adapter,'k_shots':args.k_shots,'model_path':args.model_path,'total':len(all_preds), + 'metrics':{'token_f1':tf1,'bertscore_f1':bsf1}, + 'predictions':[{'pred':p,'answers':r,'meta':m} for p,r,m in zip(all_preds, all_refs, details)]} + (out_dir / f'vqa_random_{args.k_shots}shot.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'Few-shot (Random, k={args.k_shots}) Token-F1={tf1:.2f}', 'BERTScore-F1=' + (f'{bsf1:.2f}' if bsf1 is not None else 'NA')) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/repair_cache_answers.py b/ICL/LV/code/core/eval/repair_cache_answers.py new file mode 100644 index 0000000000000000000000000000000000000000..6f7dd1860a81ad883b6120605cec5c9cb7cdabe4 --- /dev/null +++ b/ICL/LV/code/core/eval/repair_cache_answers.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +""" +Repair cached prediction files by backfilling missing references (answers) from the dataset, +without rerunning model inference. + +We match each cached sample to a dataset sample by image signature and copy over its references. + +Usage: + python -m core.eval.repair_cache_answers \ + --output-base runs/order_idefics2 \ + --dataset-root /z_data/datasets/M3IT \ + --categories captioning \ + --orders image-text,text-image,text-image-text + +Notes: + - Only updates cache lines whose 'answers' are missing/empty. + - Writes back in-place. + - Safe to run multiple times. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from core.datasets.m3it_reader import read_jsonl, _IMAGE_KEYS, _IMAGE_PATH_KEYS +from core.eval.order_eval_core import _image_path_sig + + +def _img_sig_from_rec(rec: dict, dataset_root: Path, subdir: str) -> str: + # Try base64-like strings + try: + import hashlib + for k in _IMAGE_KEYS: + v = rec.get(k) + if isinstance(v, str) and len(v) > 100: + return hashlib.sha1(v.encode("utf-8")).hexdigest() + mv = rec.get("meta") if isinstance(rec.get("meta"), dict) else None + if mv and isinstance(mv.get(k), str) and len(mv[k]) > 100: + return hashlib.sha1(mv[k].encode("utf-8")).hexdigest() + except Exception: + pass + # Try path keys (resolve common roots) + from core.datasets.m3it_reader import _resolve_image_path + for k in _IMAGE_PATH_KEYS: + p = rec.get(k) + if isinstance(p, str) and p.strip(): + rp = _resolve_image_path(dataset_root, subdir, p) + if rp: + return _image_path_sig(rp) + mv = rec.get("meta") if isinstance(rec.get("meta"), dict) else None + if isinstance(mv, dict): + p = mv.get(k) + if isinstance(p, str) and p.strip(): + rp = _resolve_image_path(dataset_root, subdir, p) + if rp: + return _image_path_sig(rp) + return "" + + +def _build_sig2answers(dataset_root: Path, subdir: str) -> Dict[str, List[str]]: + from core.datasets.m3it_reader import _candidate_split_files + files = _candidate_split_files(dataset_root, subdir, "val") + file = None + for f in files: + if f.exists(): + file = f + break + if file is None: + return {} + sig2ans: Dict[str, List[str]] = {} + for rec in read_jsonl(file): + sig = _img_sig_from_rec(rec, dataset_root, subdir) + if not sig: + continue + # Collect references across common keys + def _to_str_list(v) -> Optional[List[str]]: + if v is None: + return None + if isinstance(v, (str, int, float, bool)): + return [str(v)] + if isinstance(v, list): + out = [] + for x in v: + if isinstance(x, (str, int, float, bool)): + out.append(str(x)) + return out if out else None + return None + answers: Optional[List[str]] = None + keys = ( + "output", "outputs", "answer", "answers", "target", "label", + "caption", "captions", "caption_text", + ) + for k in keys: + if k in rec: + tmp = _to_str_list(rec[k]) + if tmp: + answers = tmp; break + mv = rec.get("meta") if isinstance(rec.get("meta"), dict) else None + if isinstance(mv, dict) and k in mv: + tmp = _to_str_list(mv[k]) + if tmp: + answers = tmp; break + if answers: + sig2ans.setdefault(sig, answers) + return sig2ans + + +def _iter_cache_files(base: Path, categories: List[str], orders: List[str]) -> List[Tuple[str, Path]]: + out: List[Tuple[str, Path]] = [] + # Known metric folders per category + per_cat = { + "captioning": ["captioning_bertscore", "captioning_cider"], + "vqa": ["vqa_bertscore", "vqa_tokenf1"], + } + for cat in categories: + for folder in per_cat.get(cat, []): + for o in orders: + p = base / folder / "_cache" / f"{cat}__{o}.jsonl" + if p.exists(): + out.append((cat, p)) + return out + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--output-base", required=True) + ap.add_argument("--dataset-root", required=True) + ap.add_argument("--orders", default="image-text,text-image,text-image-text") + ap.add_argument("--categories", default="captioning", help="comma-separated: captioning,vqa") + args = ap.parse_args() + + base = Path(args.output_base) + dsroot = Path(args.dataset_root) + orders = [o.strip() for o in args.orders.split(",") if o.strip()] + cats = [c.strip() for c in args.categories.split(",") if c.strip()] + + # Collect unique subtasks from caches to build sig->answers maps once per subdir + cache_files = _iter_cache_files(base, cats, orders) + subdirs: Dict[str, Dict[str, List[str]]] = {} + for _, cf in cache_files: + for line in cf.read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + try: + obj = json.loads(line) + except Exception: + continue + meta = obj.get("meta") or {} + task = meta.get("task") or "" + if isinstance(task, str) and task: + subdirs[task] = {} + + # Build maps + for sub in list(subdirs.keys()): + subdirs[sub] = _build_sig2answers(dsroot, sub) + + # Patch caches + total = 0 + fixed = 0 + for cat, cf in cache_files: + lines = cf.read_text(encoding="utf-8").splitlines() + out_lines: List[str] = [] + updated = 0 + for line in lines: + if not line.strip(): + continue + try: + obj = json.loads(line) + except Exception: + out_lines.append(line) + continue + total += 1 + cur_ans = obj.get("answers") or [] + need = (not isinstance(cur_ans, list)) or (len(cur_ans) == 0) + if not need: + out_lines.append(json.dumps(obj, ensure_ascii=False)) + continue + meta = obj.get("meta") or {} + task = meta.get("task") or "" + ipath = meta.get("image_path") or "" + if not task or not ipath: + out_lines.append(json.dumps(obj, ensure_ascii=False)) + continue + sig = _image_path_sig(ipath) + ans = subdirs.get(task, {}).get(sig) + if ans: + obj["answers"] = ans + updated += 1 + out_lines.append(json.dumps(obj, ensure_ascii=False)) + if updated: + cf.write_text("\n".join(out_lines), encoding="utf-8") + fixed += updated + print(f"[repair] updated {updated} lines in {cf}") + print(f"[repair] scanned {total} cache lines; filled answers for {fixed} lines") + + +if __name__ == "__main__": + main() + diff --git a/ICL/LV/code/core/eval/select_extremes.py b/ICL/LV/code/core/eval/select_extremes.py new file mode 100644 index 0000000000000000000000000000000000000000..44a9d68e5503692a1056716140c2e676d4177314 --- /dev/null +++ b/ICL/LV/code/core/eval/select_extremes.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +Select per-sample extremes from per_sample//.jsonl dumps. + +Outputs a JSON with: +- diff_topk: top-K samples with largest |score(order_a) - score(order_b)| +- overall_topk: top-K by mean score across --overall-orders (ignoring missing) +- overall_bottomk: bottom-K by mean score across --overall-orders + +Notes: +- Works best with metrics that expose a numeric 'score' per sample (e.g., vqa_*f1, + caption_*f1/cider, reasoning_ras). For classification_accuracy/f1 which dump 'correct' + booleans, this script maps True->100, False->0. +- Sample identity is aligned by the 'i' field in per-sample files and assumed consistent + across orders for the same metric. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +def _read_per_sample(base: Path, metric: str, order: str) -> Dict[int, dict]: + path = base / 'per_sample' / metric / f'{order}.jsonl' + out: Dict[int, dict] = {} + if not path.exists(): + return out + for line in path.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + try: + obj = json.loads(line) + except Exception: + continue + i = obj.get('i') + if not isinstance(i, int): + continue + out[i] = obj + return out + + +def _score_val(obj: dict) -> Optional[float]: + v = obj.get('score') + if isinstance(v, (int, float)): + return float(v) + # classification_accuracy dump uses 'correct' + if isinstance(obj.get('correct'), bool): + return 100.0 if obj['correct'] else 0.0 + return None + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument('--output-base', required=True, help='Run output base, e.g., runs/order_qwen-vl') + ap.add_argument('--metric', default='reasoning_ras') + ap.add_argument('--orders-a', default='image-text') + ap.add_argument('--orders-b', default='text-image') + ap.add_argument('--overall-orders', default='image-text,text-image,text-image-text') + ap.add_argument('--k', type=int, default=5) + ap.add_argument('--write', default='') + args = ap.parse_args() + + base = Path(args.output_base) + metric = args.metric + ord_a = args.orders_a.strip() + ord_b = args.orders_b.strip() + over_orders = [o.strip() for o in args.overall_orders.split(',') if o.strip()] + + A = _read_per_sample(base, metric, ord_a) + B = _read_per_sample(base, metric, ord_b) + + # diff top-k across intersection + diffs: List[Tuple[float, int, float, float]] = [] # (absdiff, i, a, b) + for i in sorted(set(A.keys()) & set(B.keys())): + sa = _score_val(A[i]) + sb = _score_val(B[i]) + if sa is None or sb is None: + continue + diffs.append((abs(sa - sb), i, sa, sb)) + diffs.sort(key=lambda x: (-x[0], x[1])) + diff_topk = [] + for d, i, sa, sb in diffs[: max(0, args.k)]: + meta = A.get(i, {}).get('meta') or B.get(i, {}).get('meta') + diff_topk.append({ + 'i': i, + 'score_a': sa, + 'score_b': sb, + 'abs_diff': d, + 'order_a': ord_a, + 'order_b': ord_b, + 'meta': meta, + }) + + # overall top/bottom: mean across provided orders + by_order: Dict[str, Dict[int, dict]] = {o: _read_per_sample(base, metric, o) for o in over_orders} + # Gather all indices that appear in at least one order + all_idx = set() + for o, M in by_order.items(): + all_idx |= set(M.keys()) + overall_list: List[Tuple[float, int, Dict[str, Optional[float]]]] = [] # (mean, i, per_order) + for i in sorted(all_idx): + per: Dict[str, Optional[float]] = {} + vals: List[float] = [] + for o, M in by_order.items(): + v = _score_val(M.get(i, {})) if i in M else None + per[o] = (None if v is None else float(v)) + if isinstance(v, (int, float)): + vals.append(float(v)) + if not vals: + continue + mean = sum(vals) / len(vals) + overall_list.append((mean, i, per)) + overall_list.sort(key=lambda x: (-x[0], x[1])) + overall_topk = [] + for mean, i, per in overall_list[: max(0, args.k)]: + # Prefer meta from first available order + meta = None + for o in over_orders: + meta = by_order.get(o, {}).get(i, {}).get('meta') + if meta: + break + overall_topk.append({'i': i, 'mean': float(mean), 'per_order': per, 'meta': meta}) + overall_list.sort(key=lambda x: (x[0], x[1])) + overall_bottomk = [] + for mean, i, per in overall_list[: max(0, args.k)]: + meta = None + for o in over_orders: + meta = by_order.get(o, {}).get(i, {}).get('meta') + if meta: + break + overall_bottomk.append({'i': i, 'mean': float(mean), 'per_order': per, 'meta': meta}) + + out = { + 'run_id': str(base), + 'metric': metric, + 'orders_a': ord_a, + 'orders_b': ord_b, + 'overall_orders': over_orders, + 'k': int(args.k), + 'diff_topk': diff_topk, + 'overall_topk': overall_topk, + 'overall_bottomk': overall_bottomk, + } + + if args.write: + out_path = Path(args.write) + else: + out_path = base / 'per_sample' / f'select_extremes_{metric}.json' + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f"[select_extremes] wrote {out_path}") + + +if __name__ == '__main__': + main() + diff --git a/ICL/LV/code/core/eval/select_overall_extremes.py b/ICL/LV/code/core/eval/select_overall_extremes.py new file mode 100644 index 0000000000000000000000000000000000000000..24a8e84dbee618a61340d42d38e9d9484a8cedc3 --- /dev/null +++ b/ICL/LV/code/core/eval/select_overall_extremes.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +""" +Select extremes across ALL metrics using per-sample dumps. + +Assumptions/Notes: +- Per-sample files live under /per_sample//.jsonl +- Numeric per-sample values are taken from 'score' when present (0..100); + for classification/reasoning accuracy/F1 dumps we map 'correct': True->100, False->0. +- 'Total score' at sample-level across all 8 metrics is not strictly definable + (different categories use different datasets). Here we treat each per-metric + sample independently, and compute extremes across the union of all metrics. + +Outputs (JSON): +- diff_topk: top-K entries with largest |score(image-text) - score(text-image)| + Each entry is keyed by (metric, i) and includes both scores + meta. +- order_topk: per-order top-K entries by score across all metrics. +- order_bottomk: per-order bottom-K entries by score across all metrics. + +Usage: + python -m core.eval.select_overall_extremes \ + --output-base runs/order_qwen-vl \ + --k 5 \ + --write runs/order_qwen-vl/per_sample/extremes_overall.json +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +DEFAULT_METRICS = [ + 'vqa_tokenf1', + 'vqa_bertscore', + 'captioning_bertscore', + 'captioning_cider', + 'classification_accuracy', + 'classification_f1', + 'reasoning_accuracy', + 'reasoning_ras', +] +ORDERS = ('image-text', 'text-image', 'text-image-text') + + +def _read_per_sample(base: Path, metric: str, order: str) -> Dict[int, dict]: + p = base / 'per_sample' / metric / f'{order}.jsonl' + out: Dict[int, dict] = {} + if not p.exists(): + return out + for line in p.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + try: + obj = json.loads(line) + except Exception: + continue + i = obj.get('i') + if isinstance(i, int): + out[i] = obj + return out + + +def _score_of(obj: dict) -> Optional[float]: + v = obj.get('score') + if isinstance(v, (int, float)): + return float(v) + c = obj.get('correct') + if isinstance(c, bool): + return 100.0 if c else 0.0 + return None + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument('--output-base', required=True) + ap.add_argument('--metrics', nargs='*', default=DEFAULT_METRICS) + ap.add_argument('--orders', default=','.join(ORDERS)) + ap.add_argument('--k', type=int, default=5) + ap.add_argument('--write', default='') + args = ap.parse_args() + + base = Path(args.output_base) + orders = [o.strip() for o in args.orders.split(',') if o.strip()] + metrics = [m.strip() for m in args.metrics if m.strip()] + + # Load per-sample maps per metric/order + data: Dict[str, Dict[str, Dict[int, dict]]] = {} + for m in metrics: + om: Dict[str, Dict[int, dict]] = {} + for o in orders: + om[o] = _read_per_sample(base, m, o) + data[m] = om + + # Largest diff between image-text and text-image + a, b = 'image-text', 'text-image' + diffs: List[dict] = [] + for m in metrics: + A = data[m].get(a, {}) + B = data[m].get(b, {}) + common = set(A.keys()) & set(B.keys()) + for i in common: + sa = _score_of(A[i]) + sb = _score_of(B[i]) + if sa is None or sb is None: + continue + diffs.append({ + 'metric': m, + 'i': i, + 'score_a': float(sa), + 'score_b': float(sb), + 'abs_diff': abs(float(sa) - float(sb)), + 'order_a': a, + 'order_b': b, + 'meta': (A[i].get('meta') or B[i].get('meta') or None), + }) + diffs.sort(key=lambda d: (-d['abs_diff'], d['metric'], d['i'])) + diff_topk = diffs[: max(0, args.k)] + + # Per-order overall top/bottom across all metrics + order_topk: Dict[str, List[dict]] = {} + order_bottomk: Dict[str, List[dict]] = {} + for o in orders: + entries: List[dict] = [] + for m in metrics: + M = data[m].get(o, {}) + for i, obj in M.items(): + sv = _score_of(obj) + if not isinstance(sv, (int, float)): + continue + entries.append({ + 'metric': m, + 'i': i, + 'order': o, + 'score': float(sv), + 'meta': obj.get('meta') or None, + }) + entries.sort(key=lambda e: (-e['score'], e['metric'], e['i'])) + order_topk[o] = entries[: max(0, args.k)] + entries.sort(key=lambda e: (e['score'], e['metric'], e['i'])) + order_bottomk[o] = entries[: max(0, args.k)] + + out = { + 'run_id': str(base), + 'k': int(args.k), + 'orders': orders, + 'metrics': metrics, + 'diff_topk': diff_topk, + 'order_topk': order_topk, + 'order_bottomk': order_bottomk, + 'note': 'Per-sample union across metrics; scores in [0,100]; classification/reasoning use correct->100/0.', + } + + if args.write: + out_path = Path(args.write) + else: + out_path = base / 'per_sample' / 'extremes_overall.json' + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f"[select_overall_extremes] wrote {out_path}") + + +if __name__ == '__main__': + main() + diff --git a/ICL/LV/code/core/eval/select_taskpair_extremes.py b/ICL/LV/code/core/eval/select_taskpair_extremes.py new file mode 100644 index 0000000000000000000000000000000000000000..1025fd89cd62b43d9651528938976e418fc500db --- /dev/null +++ b/ICL/LV/code/core/eval/select_taskpair_extremes.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +""" +Per-task 2-metric totals (0..200) extremes. + +For each task we pair two metrics evaluated on the SAME samples, join per-sample +by meta.image_path (fallback to task+i), and compute: + total_0_200 = clamp01to100(m1) + clamp01to100(m2) + +Tasks and metric pairs: + - vqa: vqa_tokenf1 + vqa_bertscore + - captioning: captioning_bertscore + captioning_cider + - classification:classification_accuracy + classification_f1 + - reasoning: reasoning_accuracy + reasoning_ras + +We then select for each task: + - order_topk: per-order Top-K totals + - order_bottomk: per-order Bottom-K totals + - diff_topk: keys with largest |total(image-text) - total(text-image)| + +Filtering/normalization: + - Ignore rows whose score is missing (None). + - --ignore-bert-zero: for *bertscore metrics only, treat score==0.0 as missing. + - Auto-scale safeguard: if we detect a metric appears to be 0..1 (max<=1.05), + scale by 100 before clamping to [0,100]. + - Clamp every metric to [0,100] before summation, ensuring total is 0..200. + +Usage: + python -m core.eval.select_taskpair_extremes \ + --output-base runs/order_qwen-vl \ + --k 5 \ + --ignore-bert-zero + +Output: + /per_sample/extremes_taskpair.json +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +ORDERS = ("image-text", "text-image", "text-image-text") + +TASK_PAIRS = { + 'vqa': ('vqa_tokenf1', 'vqa_bertscore'), + 'captioning': ('captioning_bertscore', 'captioning_cider'), + 'classification': ('classification_accuracy', 'classification_f1'), + 'reasoning': ('reasoning_accuracy', 'reasoning_ras'), +} + + +def _read_rows(base: Path, metric: str, order: str) -> List[dict]: + p = base / 'per_sample' / metric / f'{order}.jsonl' + if not p.exists(): + return [] + out: List[dict] = [] + for line in p.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + try: + out.append(json.loads(line)) + except Exception: + pass + return out + + +def _key_of(obj: dict) -> str: + """Build a stable key to join across metrics within the same task. + Prefer the suffix under '_image_cache/' if present, since different metrics + write caches under their own /_image_cache but share the same suffix. + Fallback to full image_path, then to (task,i). + """ + m = obj.get('meta') or {} + img = m.get('image_path') if isinstance(m, dict) else None + def _norm_path(p: str) -> str: + p = p.replace('\\', '/') + tag = '_image_cache/' + idx = p.rfind(tag) + if idx >= 0: + return p[idx+len(tag):] + # fallback: just basename + try: + from pathlib import Path as _P + return _P(p).name + except Exception: + return p + if isinstance(img, str) and img: + return f'img::{_norm_path(img)}' + task = m.get('task') if isinstance(m, dict) else None + i = obj.get('i') + return f"task::{task or 'NA'}::i={i}" + + +def _score_of(obj: dict) -> Optional[float]: + v = obj.get('score') + if isinstance(v, (int, float)): + return float(v) + c = obj.get('correct') + if isinstance(c, bool): + return 100.0 if c else 0.0 + return None + + +def _metric_is_bertscore(name: str) -> bool: + return name.endswith('bertscore') + + +def _auto_scale_if_needed(vals: List[float]) -> float: + # Return scaling factor to bring 0..1 scale up to 0..100 when detected + try: + mx = max(vals) if vals else 0.0 + return 100.0 if mx <= 1.05 else 1.0 + except Exception: + return 1.0 + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument('--output-base', required=True) + ap.add_argument('--orders', default=','.join(ORDERS)) + ap.add_argument('--k', type=int, default=5) + ap.add_argument('--write', default='') + ap.add_argument('--ignore-bert-zero', action='store_true') + ap.add_argument('--min-per-metric', type=float, default=-1.0, help='If >=0, drop pairs where score1<=thr or score2<=thr (post-scale/clamp). Use 0 to drop any per-metric zero.') + ap.add_argument('--min-total', type=float, default=0.0, help='Filter out joined samples whose total_0_200 <= this value (e.g., 0.0 to drop zero totals).') + args = ap.parse_args() + + base = Path(args.output_base) + orders = [o.strip() for o in args.orders.split(',') if o.strip()] + + results = {} + stats = {} + + for task, (m1, m2) in TASK_PAIRS.items(): + # Build per-order key->(score1, score2, meta) + per_order: Dict[str, Dict[str, dict]] = {o: {} for o in orders} + # Pre-read for auto-scale detection + raw_by_order_metric: Dict[Tuple[str, str], List[float]] = {} + for o in orders: + vals1, vals2 = [], [] + for obj in _read_rows(base, m1, o): + s = _score_of(obj) + if s is not None: + vals1.append(float(s)) + for obj in _read_rows(base, m2, o): + s = _score_of(obj) + if s is not None: + vals2.append(float(s)) + raw_by_order_metric[(o, m1)] = vals1 + raw_by_order_metric[(o, m2)] = vals2 + + # Scaling factors per (order,metric) + scales: Dict[Tuple[str, str], float] = {} + for o in orders: + scales[(o, m1)] = _auto_scale_if_needed(raw_by_order_metric[(o, m1)]) + scales[(o, m2)] = _auto_scale_if_needed(raw_by_order_metric[(o, m2)]) + + # Build joined map + task_counts = {o: {'rows_m1': 0, 'rows_m2': 0, 'joined': 0} for o in orders} + for o in orders: + map1: Dict[str, Tuple[float, dict]] = {} + for obj in _read_rows(base, m1, o): + s = _score_of(obj) + if not isinstance(s, (int, float)): + continue + # Filter BERTScore zeros if requested + if args.ignore_bert_zero and _metric_is_bertscore(m1) and float(s) == 0.0: + continue + s = float(s) * scales[(o, m1)] + # clamp per-metric to [0,100] + if s < 0.0: s = 0.0 + if s > 100.0: s = 100.0 + k = _key_of(obj) + map1[k] = (s, obj.get('meta') or {}) + task_counts[o]['rows_m1'] += 1 + + map2: Dict[str, Tuple[float, dict]] = {} + for obj in _read_rows(base, m2, o): + s = _score_of(obj) + if not isinstance(s, (int, float)): + continue + # Filter BERTScore zeros if requested + if args.ignore_bert_zero and _metric_is_bertscore(m2) and float(s) == 0.0: + continue + s = float(s) * scales[(o, m2)] + if s < 0.0: s = 0.0 + if s > 100.0: s = 100.0 + k = _key_of(obj) + map2[k] = (s, obj.get('meta') or {}) + task_counts[o]['rows_m2'] += 1 + + joined = {} + common = set(map1.keys()) & set(map2.keys()) + for k in common: + s1, mta = map1[k] + s2, mtb = map2[k] + # Optional per-metric thresholding + if float(args.min_per_metric) >= 0.0: + thr = float(args.min_per_metric) + if (s1 <= thr) or (s2 <= thr): + continue + tot = float(s1 + s2) + if tot <= float(args.min_total): + continue + joined[k] = { + 'key': k, + 'order': o, + 'metric1': m1, + 'metric2': m2, + 'score1': float(s1), + 'score2': float(s2), + 'total_0_200': tot, + 'meta': (mta or mtb or None), + } + per_order[o] = joined + task_counts[o]['joined'] = len(joined) + stats[task] = task_counts + + # Rank per-order + order_topk: Dict[str, List[dict]] = {} + order_bottomk: Dict[str, List[dict]] = {} + for o in orders: + entries = list(per_order[o].values()) + entries.sort(key=lambda e: (-e['total_0_200'], e['key'])) + order_topk[o] = entries[: max(0, args.k)] + entries.sort(key=lambda e: (e['total_0_200'], e['key'])) + order_bottomk[o] = entries[: max(0, args.k)] + + # image-text vs text-image diffs + a, b = 'image-text', 'text-image' + common = set(per_order.get(a, {}).keys()) & set(per_order.get(b, {}).keys()) + diffs: List[dict] = [] + for k in sorted(common): + ea = per_order[a][k]['total_0_200'] + eb = per_order[b][k]['total_0_200'] + meta = per_order[a][k].get('meta') or per_order[b][k].get('meta') or None + diffs.append({ + 'key': k, + 'order_a': a, + 'order_b': b, + 'sum_a': float(ea), + 'sum_b': float(eb), + 'abs_diff': abs(float(ea) - float(eb)), + 'meta': meta, + }) + diffs.sort(key=lambda d: (-d['abs_diff'], d['key'])) + diff_topk = diffs[: max(0, args.k)] + + results[task] = { + 'task': task, + 'pair': [m1, m2], + 'diff_topk': diff_topk, + 'order_topk': order_topk, + 'order_bottomk': order_bottomk, + } + + out = { + 'run_id': str(base), + 'orders': orders, + 'k': int(args.k), + 'tasks': {t: {'pair': list(TASK_PAIRS[t])} for t in TASK_PAIRS}, + 'ignore_bert_zero': bool(args.ignore_bert_zero), + 'results': results, + 'stats': stats, + 'note': 'Per-task totals = clamp([0,100]) sum of the two metrics (0..200); auto-scaled 0..1->0..100; BERTScore zeros filtered if requested.', + } + + out_path = Path(args.write) if args.write else (base / 'per_sample' / 'extremes_taskpair.json') + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f"[select_taskpair_extremes] wrote {out_path}") + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/select_total_8sum.py b/ICL/LV/code/core/eval/select_total_8sum.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9cb170cf85f4f2e1ad18f8e39456fa0c8d6933 --- /dev/null +++ b/ICL/LV/code/core/eval/select_total_8sum.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +""" +Exact 8-metric total per sample (0..800) using intersection across metrics. + +Definition: +- For each order (image-text, text-image, text-image-text), read per-sample dumps + from /per_sample//.jsonl. +- Use meta.image_path as the sample key (fallback to task+i if missing). +- Keep only samples that have valid scores in ALL requested metrics (intersection). +- Score per metric is taken from 'score' when numeric; for rows with 'correct' + boolean, map True->100/False->0. We assume per-sample dumps already scale to 0..100 + (RAS auto ×100 when needed; token-F1 now scaled to 0..100). +- BERTScore filtering: with --ignore-bert-zero, rows from *bertscore metrics with + score==0.0 are treated as missing; with --skip-bert-fallback, rows whose + 'metric' string contains 'token_f1' are treated as missing. + +Outputs: +- extremes_total_8sum.json under /per_sample/ containing: + - diff_topk_total_8sum: largest |SUM(image-text) - SUM(text-image)| keys + - order_topk_total_8sum: per-order top-K by SUM + - order_bottomk_total_8sum: per-order bottom-K by SUM + - stats: sizes of intersections per order, how many rows skipped, etc. + +Usage: + python -m core.eval.select_total_8sum \ + --output-base runs/order_qwen-vl \ + --k 5 \ + --ignore-bert-zero \ + --skip-bert-fallback +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +DEFAULT_METRICS = [ + 'vqa_tokenf1', + 'vqa_bertscore', + 'captioning_bertscore', + 'captioning_cider', + 'classification_accuracy', + 'classification_f1', + 'reasoning_accuracy', + 'reasoning_ras', +] +ORDERS = ('image-text', 'text-image', 'text-image-text') + + +def _read_rows(base: Path, metric: str, order: str) -> List[dict]: + p = base / 'per_sample' / metric / f'{order}.jsonl' + if not p.exists(): + return [] + out: List[dict] = [] + for line in p.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + try: + out.append(json.loads(line)) + except Exception: + pass + return out + + +def _key_of(obj: dict) -> str: + m = obj.get('meta') or {} + img = None + if isinstance(m, dict): + img = m.get('image_path') + if isinstance(img, str) and img: + return f'img::{img}' + task = m.get('task') if isinstance(m, dict) else None + i = obj.get('i') + return f"task::{task or 'NA'}::i={i}" + + +def _score_of(obj: dict) -> Optional[float]: + v = obj.get('score') + if isinstance(v, (int, float)): + return float(v) + c = obj.get('correct') + if isinstance(c, bool): + return 100.0 if c else 0.0 + return None + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument('--output-base', required=True) + ap.add_argument('--metrics', nargs='*', default=DEFAULT_METRICS) + ap.add_argument('--orders', default=','.join(ORDERS)) + ap.add_argument('--k', type=int, default=5) + ap.add_argument('--write', default='') + ap.add_argument('--ignore-bert-zero', action='store_true') + ap.add_argument('--skip-bert-fallback', action='store_true') + ap.add_argument('--clamp-100', action='store_true', help='Clamp each metric score to [0,100] before summing (enforce 0..800 total).') + args = ap.parse_args() + + base = Path(args.output_base) + orders = [o.strip() for o in args.orders.split(',') if o.strip()] + metrics = [m.strip() for m in args.metrics if m.strip()] + + # Load rows: per order -> per metric -> filtered dict[key] = (score, meta) + filtered: Dict[str, Dict[str, Dict[str, Tuple[float, dict]]]] = {o: {} for o in orders} + skipped_zero = 0 + skipped_fallback = 0 + total_rows = 0 + for o in orders: + for m in metrics: + M: Dict[str, Tuple[float, dict]] = {} + vals_tmp = [] + for obj in _read_rows(base, m, o): + total_rows += 1 + s = _score_of(obj) + if not isinstance(s, (int, float)): + continue + impl = str(obj.get('metric') or '').lower() + if args.skip_bert_fallback and m.endswith('bertscore') and 'token_f1' in impl: + skipped_fallback += 1 + continue + if args.ignore_bert_zero and m.endswith('bertscore') and float(s) == 0.0: + skipped_zero += 1 + continue + vals_tmp.append(s) + # Optional auto-scale for known metrics that might be 0..1 in old dumps + scale = 1.0 + if m == 'vqa_tokenf1': + try: + mx = max(vals_tmp) if vals_tmp else 0.0 + if mx <= 1.05: + scale = 100.0 + except Exception: + scale = 1.0 + # Reiterate to build filtered map with scaling and optional clamp + for obj in _read_rows(base, m, o): + s = _score_of(obj) + if not isinstance(s, (int, float)): + continue + impl = str(obj.get('metric') or '').lower() + if args.skip_bert_fallback and m.endswith('bertscore') and 'token_f1' in impl: + continue + if args.ignore_bert_zero and m.endswith('bertscore') and float(s) == 0.0: + continue + s = float(s) * float(scale) + if args.clamp_100: + if s < 0.0: + s = 0.0 + if s > 100.0: + s = 100.0 + key = _key_of(obj) + M[key] = (float(s), obj.get('meta') or {}) + filtered[o][m] = M + + # Intersection keys per order (keys that exist in ALL metrics for that order) + inter_keys: Dict[str, List[str]] = {} + for o in orders: + sets = [set(filtered[o][m].keys()) for m in metrics] + if not sets: + inter = set() + else: + inter = sets[0] + for st in sets[1:]: + inter &= st + inter_keys[o] = sorted(inter) + + # Build per-order totals and per-metric breakdown for intersection keys + totals_by_order: Dict[str, Dict[str, dict]] = {o: {} for o in orders} + for o in orders: + for k in inter_keys[o]: + per_metric = {m: filtered[o][m][k][0] for m in metrics} + total_sum = sum(per_metric.values()) + # choose a representative meta + meta = None + for m in metrics: + meta = filtered[o][m][k][1] or meta + if meta: + break + totals_by_order[o][k] = { + 'key': k, + 'order': o, + 'total_sum': float(total_sum), + 'per_metric': per_metric, + 'meta': meta, + } + + # Per-order top/bottom + order_topk_total_8sum: Dict[str, List[dict]] = {} + order_bottomk_total_8sum: Dict[str, List[dict]] = {} + for o in orders: + entries = list(totals_by_order[o].values()) + entries.sort(key=lambda e: (-e['total_sum'], e['key'])) + order_topk_total_8sum[o] = entries[: max(0, args.k)] + entries.sort(key=lambda e: (e['total_sum'], e['key'])) + order_bottomk_total_8sum[o] = entries[: max(0, args.k)] + + # image-text vs text-image diffs by SUM (only keys present in both orders' intersections) + a, b = 'image-text', 'text-image' + common = set(inter_keys.get(a, [])) & set(inter_keys.get(b, [])) + diffs: List[dict] = [] + for k in sorted(common): + ea = totals_by_order[a][k]['total_sum'] + eb = totals_by_order[b][k]['total_sum'] + meta = totals_by_order[a][k].get('meta') or totals_by_order[b][k].get('meta') or None + diffs.append({ + 'key': k, + 'order_a': a, + 'order_b': b, + 'sum_a': float(ea), + 'sum_b': float(eb), + 'abs_diff': abs(float(ea) - float(eb)), + 'meta': meta, + }) + diffs.sort(key=lambda d: (-d['abs_diff'], d['key'])) + diff_topk_total_8sum = diffs[: max(0, args.k)] + + out = { + 'run_id': str(base), + 'k': int(args.k), + 'orders': orders, + 'metrics': metrics, + 'ignore_bert_zero': bool(args.ignore_bert_zero), + 'skip_bert_fallback': bool(args.skip_bert_fallback), + 'stats': { + 'total_rows_seen': total_rows, + 'skipped_bert_zero': skipped_zero, + 'skipped_fallback': skipped_fallback, + 'intersection_sizes': {o: len(inter_keys[o]) for o in orders}, + }, + 'diff_topk_total_8sum': diff_topk_total_8sum, + 'order_topk_total_8sum': order_topk_total_8sum, + 'order_bottomk_total_8sum': order_bottomk_total_8sum, + 'note': 'Exact 8-metric total per key (0..800) using intersection across metrics; BERTScore zeros/fallback filtered if requested.', + } + + out_path = Path(args.write) if args.write else (base / 'per_sample' / 'extremes_total_8sum.json') + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f"[select_total_8sum] wrote {out_path}") + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/eval/select_total_extremes.py b/ICL/LV/code/core/eval/select_total_extremes.py new file mode 100644 index 0000000000000000000000000000000000000000..b3635e7165831b540207b7b1eea2ab8630eef7c7 --- /dev/null +++ b/ICL/LV/code/core/eval/select_total_extremes.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +""" +Compute per-sample 'total scores' by aggregating across metrics for each order, +then select extremes by that total. + +Why this file: +- Users often want 'overall total' per sample, but our metrics come from + different datasets. We approximate a per-sample total by grouping rows using + a stable key (prefer meta.image_path; fall back to (task,i) otherwise), + aggregating the metrics available for that key, and normalizing. + +Definitions: +- Per-sample key: meta.image_path if present (most tasks provide it). If not, + we use a fallback key: f"{task}::i={i}". +- Score for a row: use 'score' if numeric; else if 'correct' is a boolean, + map True->100, False->0; else skip the row. +- BERTScore handling: by default we IGNORE BERTScore entries whose score == 0.0 + (these are usually caused by empty refs). Optionally we can also skip + BERTScore rows that are marked as 'token_f1 (fallback)'. + +Aggregation (per key per order): +- Collect scores for any of the requested metrics that are present for that key. +- Compute mean over the present metrics, then scale to an 'equivalent total' + for the full metric set: total_equiv = mean * len(metrics). + This keeps comparability even when some metrics are missing for a key. + +Outputs (JSON): +- order_topk_total / order_bottomk_total: per-order top-K/bottom-K by total_equiv. +- diff_topk_total: top-K keys with largest |total_equiv(image-text) - total_equiv(text-image)|. + +Usage: + python -m core.eval.select_total_extremes \ + --output-base runs/order_qwen-vl \ + --k 5 \ + --write runs/order_qwen-vl/per_sample/extremes_total.json + +Notes: +- This approximation avoids any single metric (e.g., CIDEr) dominating purely + because of scale differences by first averaging over metrics present for a key + and then scaling by the metric count. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +DEFAULT_METRICS = [ + 'vqa_tokenf1', + 'vqa_bertscore', + 'captioning_bertscore', + 'captioning_cider', + 'classification_accuracy', + 'classification_f1', + 'reasoning_accuracy', + 'reasoning_ras', +] +ORDERS = ('image-text', 'text-image', 'text-image-text') + + +def _read_per_sample(base: Path, metric: str, order: str) -> List[dict]: + p = base / 'per_sample' / metric / f'{order}.jsonl' + if not p.exists(): + return [] + out: List[dict] = [] + for line in p.read_text(encoding='utf-8').splitlines(): + if not line.strip(): + continue + try: + out.append(json.loads(line)) + except Exception: + pass + return out + + +def _score_of(obj: dict) -> Optional[float]: + v = obj.get('score') + if isinstance(v, (int, float)): + return float(v) + c = obj.get('correct') + if isinstance(c, bool): + return 100.0 if c else 0.0 + return None + + +def _key_of(obj: dict) -> str: + m = obj.get('meta') or {} + # Prefer a real image path; else fallback to (task,i) to avoid cross-dataset collisions + img = m.get('image_path') if isinstance(m, dict) else None + if isinstance(img, str) and img: + return f'img::{img}' + task = m.get('task') if isinstance(m, dict) else None + i = obj.get('i') + return f"task::{task or 'NA'}::i={i}" + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument('--output-base', required=True) + ap.add_argument('--metrics', nargs='*', default=DEFAULT_METRICS) + ap.add_argument('--orders', default=','.join(ORDERS)) + ap.add_argument('--k', type=int, default=5) + ap.add_argument('--write', default='') + ap.add_argument('--ignore-bert-zero', action='store_true', help='Skip BERTScore rows with score==0.0') + ap.add_argument('--skip-bert-fallback', action='store_true', help='Skip rows where per-sample metric implementation contains token_f1 (fallback)') + args = ap.parse_args() + + base = Path(args.output_base) + orders = [o.strip() for o in args.orders.split(',') if o.strip()] + metrics = [m.strip() for m in args.metrics if m.strip()] + + # Load rows per metric/order + rows: Dict[str, Dict[str, List[dict]]] = {} + for m in metrics: + per_order: Dict[str, List[dict]] = {} + for o in orders: + per_order[o] = _read_per_sample(base, m, o) + rows[m] = per_order + + # Build per-order aggregation map: key -> {metric: score} + per_order_scores: Dict[str, Dict[str, Dict[str, float]]] = {o: {} for o in orders} + total_row_count = 0 + skipped_bert_zero = 0 + skipped_fallback = 0 + for m in metrics: + for o in orders: + for obj in rows[m].get(o, []): + total_row_count += 1 + s = _score_of(obj) + if not isinstance(s, (int, float)): + continue + impl = str(obj.get('metric') or '').lower() + is_bert_metric = (m.endswith('bertscore')) + if args.skip_bert_fallback and is_bert_metric and 'token_f1' in impl: + skipped_fallback += 1 + continue + if args.ignore_bert_zero and is_bert_metric and float(s) == 0.0: + skipped_bert_zero += 1 + continue + key = _key_of(obj) + per_order_scores[o].setdefault(key, {})[m] = float(s) + + # Aggregate per key per order + aggregated: Dict[str, Dict[str, dict]] = {o: {} for o in orders} + for o in orders: + key2metrics = per_order_scores[o] + for key, m2s in key2metrics.items(): + present = [m2s[m] for m in metrics if m in m2s] + if not present: + continue + sum_present = float(sum(present)) + cnt = len(present) + mean = sum_present / cnt + total_equiv = mean * len(metrics) + # capture a representative meta for reference (first row we can find) + # Search any contributing metric's row to collect meta + meta = None + # try to find a meta by scanning original rows for this key + if not meta: + for m in metrics: + for obj in rows[m].get(o, []): + if _key_of(obj) == key: + meta = obj.get('meta') or None + if meta: + break + if meta: + break + aggregated[o][key] = { + 'key': key, + 'order': o, + 'num_metrics': cnt, + 'sum_present': sum_present, + 'mean_present': mean, + 'total_equiv': total_equiv, + 'per_metric': {k: float(v) for k, v in m2s.items()}, + 'meta': meta, + } + + # Compute diffs between image-text and text-image by total_equiv + a, b = 'image-text', 'text-image' + diffs: List[dict] = [] + keys = set(aggregated.get(a, {}).keys()) & set(aggregated.get(b, {}).keys()) + for k in keys: + ea = aggregated[a][k]['total_equiv'] + eb = aggregated[b][k]['total_equiv'] + diffs.append({ + 'key': k, + 'order_a': a, + 'order_b': b, + 'total_a': float(ea), + 'total_b': float(eb), + 'abs_diff': abs(float(ea) - float(eb)), + 'meta': aggregated[a][k].get('meta') or aggregated[b][k].get('meta') or None, + }) + diffs.sort(key=lambda d: (-d['abs_diff'], d['key'])) + diff_topk_total = diffs[: max(0, args.k)] + + # Per-order top/bottom by total_equiv + order_topk_total: Dict[str, List[dict]] = {} + order_bottomk_total: Dict[str, List[dict]] = {} + for o in orders: + entries = list(aggregated[o].values()) + entries.sort(key=lambda e: (-e['total_equiv'], e['key'])) + order_topk_total[o] = entries[: max(0, args.k)] + entries.sort(key=lambda e: (e['total_equiv'], e['key'])) + order_bottomk_total[o] = entries[: max(0, args.k)] + + out = { + 'run_id': str(base), + 'k': int(args.k), + 'orders': orders, + 'metrics': metrics, + 'ignore_bert_zero': bool(args.ignore_bert_zero), + 'skip_bert_fallback': bool(args.skip_bert_fallback), + 'stats': { + 'total_rows_seen': total_row_count, + 'skipped_bert_zero': skipped_bert_zero, + 'skipped_fallback': skipped_fallback, + 'keys_per_order': {o: len(aggregated[o]) for o in orders}, + }, + 'diff_topk_total': diff_topk_total, + 'order_topk_total': order_topk_total, + 'order_bottomk_total': order_bottomk_total, + 'note': 'Per key we average across available metrics then scale to len(metrics) for a comparable total; BERTScore zeros ignored if requested.', + } + + out_path = Path(args.write) if args.write else (base / 'per_sample' / 'extremes_total.json') + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f"[select_total_extremes] wrote {out_path}") + + +if __name__ == '__main__': + main() + diff --git a/ICL/LV/code/core/eval/summarize_by_order.py b/ICL/LV/code/core/eval/summarize_by_order.py new file mode 100644 index 0000000000000000000000000000000000000000..f9403829a74f2f065f75f202205bf4703e3b3daa --- /dev/null +++ b/ICL/LV/code/core/eval/summarize_by_order.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +Summarize scores by modal order. + +Outputs: +- /by_order/.json: per-order metric scores (0..100 scale) and per-order avg/total. +- /overall.json: a dict containing by_order summaries (replaces the old '统账' overall). + +Notes: +- reasoning_ras is scaled to percentage (×100) when values look like 0..1 (max<=1.05) or when --ras-auto-scale is disabled we always multiply. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, List, Optional + + +DEFAULT_METRICS = [ + "vqa_tokenf1", + "vqa_bertscore", + "captioning_bertscore", + "captioning_cider", + "classification_accuracy", + "classification_f1", + "reasoning_accuracy", + "reasoning_ras", +] + + +def _read_json(path: Path) -> Optional[dict]: + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + + +def _mean(xs: List[Optional[float]]) -> Optional[float]: + vals = [float(x) for x in xs if isinstance(x, (int, float))] + return (sum(vals) / len(vals)) if vals else None + + +def _scale_ras(val: Optional[float], mul: float, auto: bool) -> Optional[float]: + if not isinstance(val, (int, float)): + return None + try: + v = float(val) + if not auto: + return v * float(mul) + # auto-scale only when value looks like 0..1 + return (v * float(mul)) if v <= 1.05 else v + except Exception: + return None + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--output-base", required=True) + ap.add_argument("--orders", default="image-text,text-image,text-image-text") + ap.add_argument("--metrics", nargs="*", default=DEFAULT_METRICS) + ap.add_argument("--ras-mul", type=float, default=100.0) + ap.add_argument("--ras-auto-scale", action="store_true") + ap.add_argument("--write-filename", default="overall.json") + # Optional meta to carry over + ap.add_argument("--adapter", default="") + ap.add_argument("--model-path", default="") + ap.add_argument("--k-shots", type=int, default=-1) + ap.add_argument("--split", default="") + args = ap.parse_args() + + base = Path(args.output_base) + orders = [o.strip() for o in args.orders.split(",") if o.strip()] + + # Read all metric summaries + metric_summaries: Dict[str, dict] = {} + for m in args.metrics: + summ = _read_json(base / m / "summary.json") or {} + metric_summaries[m] = summ + + # Build per-order reports + by_order: Dict[str, dict] = {} + for o in orders: + mvals: Dict[str, Optional[float]] = {} + for m in args.metrics: + v = metric_summaries.get(m, {}).get(o) + if m == "reasoning_ras": + v = _scale_ras(v, mul=float(args.ras_mul), auto=bool(args.ras_auto_scale)) + v = (None if v is None else float(v)) + mvals[m] = v + avg = _mean(list(mvals.values())) + total = None + xs = [float(x) for x in mvals.values() if isinstance(x, (int, float))] + if xs: + total = float(sum(xs)) + by_order[o] = { + "order": o, + "metrics": mvals, + "avg": (None if avg is None else float(avg)), + "total": total, + "note": "All metrics on 0..100 scale; reasoning_ras auto×100 when needed.", + } + + # Write per-order files + out_dir = base / "by_order" + out_dir.mkdir(parents=True, exist_ok=True) + for o, rec in by_order.items(): + (out_dir / f"{o}.json").write_text(json.dumps(rec, ensure_ascii=False, indent=2), encoding="utf-8") + + # Write overall.json (by-order view) + ov = { + "by_order": by_order, + "meta": { + "adapter": (args.adapter or None), + "model_path": (args.model_path or None), + "orders": ",".join(orders), + "k_shots": (None if args.k_shots < 0 else int(args.k_shots)), + "split": (args.split or None), + }, + } + (base / args.write_filename).write_text(json.dumps(ov, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"[summarize_by_order] wrote {base/args.write_filename} and {out_dir}/*.json") + + +if __name__ == "__main__": + main() + diff --git a/ICL/LV/code/core/eval/summarize_overall.py b/ICL/LV/code/core/eval/summarize_overall.py new file mode 100644 index 0000000000000000000000000000000000000000..088041d2ca8db75f7b747bd709f010be3ad9ba45 --- /dev/null +++ b/ICL/LV/code/core/eval/summarize_overall.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +""" +Aggregate an overall score for a modal‑order run and save to a single JSON. + +Inputs: expects per-metric summaries under //summary.json +where each summary.json maps order -> numeric score (0..100 scale). + +Output: writes / (default: overall.json) with + { + "metrics": {metric_name: mean_over_orders_or_null}, + "overall_avg": mean_over_metrics_or_null, + "meta": {optional run metadata} + } + +This lets downstream tooling easily pick the top/bottom runs by overall score. +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Dict, Optional + + +DEFAULT_METRICS = [ + "vqa_tokenf1", + "vqa_bertscore", + "captioning_bertscore", + "captioning_cider", + "classification_accuracy", + "classification_f1", + "reasoning_accuracy", + "reasoning_ras", +] + + +def _read_json(path: Path) -> Optional[dict]: + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + + +def _mean(xs): + vals = [float(x) for x in xs if isinstance(x, (int, float))] + return (sum(vals) / len(vals)) if vals else None + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--output-base", required=True, help="Path containing per-metric subfolders") + ap.add_argument("--write-filename", default="overall.json", help="Filename to write under output-base") + ap.add_argument("--metrics", nargs="*", default=DEFAULT_METRICS, help="Metric folder names to include") + # Optional metadata (purely recorded into the JSON) + ap.add_argument("--adapter", default="") + ap.add_argument("--model-path", default="") + ap.add_argument("--orders", default="") + ap.add_argument("--k-shots", type=int, default=-1) + ap.add_argument("--split", default="") + args = ap.parse_args() + + base = Path(args.output_base) + metrics: Dict[str, Optional[float]] = {} + for m in args.metrics: + summ = _read_json(base / m / "summary.json") + if isinstance(summ, dict) and summ: + mv = _mean([v for v in summ.values()]) + else: + mv = None + metrics[m] = (None if mv is None else float(mv)) + + overall = _mean([v for v in metrics.values() if isinstance(v, (int, float))]) + + out = { + "metrics": metrics, + "overall_avg": (None if overall is None else float(overall)), + "meta": { + "adapter": (args.adapter or None), + "model_path": (args.model_path or None), + "orders": (args.orders or None), + "k_shots": (None if args.k_shots < 0 else int(args.k_shots)), + "split": (args.split or None), + }, + } + + out_path = base / args.write_filename + out_path.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"[summarize_overall] wrote {out_path}") + + +if __name__ == "__main__": + main() + diff --git a/ICL/LV/code/core/eval/zero_shot_vqa.py b/ICL/LV/code/core/eval/zero_shot_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..82a389ebf499d03ef48946ed25c3d933fc8225ce --- /dev/null +++ b/ICL/LV/code/core/eval/zero_shot_vqa.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 +import argparse +import json +import random +from pathlib import Path +from typing import Dict, List + +from core.datasets.m3it_reader import iter_m3it_samples, load_instructions +from core.metrics.metrics import token_f1, bertscore_f1 +from core.prompting.openai_segments import openai_to_list_format + + +VQA_SUBTASKS = [ + 'vqa/vqav2','vqa/docvqa','vqa/ocr-vqa','vqa/st-vqa','vqa/text-vqa','vqa/gqa','vqa/okvqa','vqa/a-okvqa', +] + +PAPER_VQA_INTRO = 'Examine the image and answer the question by closely following the structure shown in the example provided.' + + +def distribute_quota(total: int, n: int) -> List[int]: + base = total // n; rem = total % n + return [base + (1 if i < rem else 0) for i in range(n)] + + +def load_adapter(name: str, model_path: str): + name = (name or '').lower() + if name in ('idefics2','idefics','i2'): + from adapters import idefics2_adapter as A + elif name in ('qwen-vl','qwenvl','qwen'): + from adapters import qwen_vl_adapter as A + elif name in ('qwen3-vl','qwen3vl','qwen3'): + from adapters import qwen3vl_adapter as A + elif name in ('gemma3','gemma-3','gemma'): + from adapters import gemma3_adapter as A + else: + raise ValueError(f'Unknown adapter: {name}') + return A.create(model_path) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--adapter', required=True, help='idefics2 | qwen-vl | qwen3-vl | gemma3') + ap.add_argument('--model-path', required=True) + ap.add_argument('--dataset-root', required=True) + ap.add_argument('--output-dir', default='runs/unified_zero_shot_vqa') + ap.add_argument('--total-samples', type=int, default=500) + ap.add_argument('--temperature', type=float, default=0.2) + ap.add_argument('--top-p', type=float, default=1.0) + ap.add_argument('--max-new-tokens', type=int, default=32) + ap.add_argument('--use-paper-instruction', action='store_true') + ap.add_argument('--no-instruction', action='store_true') + ap.add_argument('--split', type=str, default='test') + ap.add_argument('--auto-detect', action='store_true') + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--sequential', action='store_true') + ap.add_argument('--bertscore-model', type=str, default='roberta-large') + ap.add_argument('--no-bertscore-baseline', action='store_true') + ap.add_argument('--bertscore-batch-size', type=int, default=32) + ap.add_argument('--bertscore-lang', type=str, default='', help="Language code for BERTScore baseline rescaling (e.g., 'en', 'zh')") + ap.add_argument('--instruction-image', type=str, default=None) + ap.add_argument('--dump-first', type=int, default=0) + args = ap.parse_args() + + out_dir = Path(args.output_dir); out_dir.mkdir(parents=True, exist_ok=True) + cache_dir = out_dir / '_image_cache' + + adapter = load_adapter(args.adapter, args.model_path) + + tasks = VQA_SUBTASKS + if args.auto_detect: + from core.datasets.m3it_reader import _candidate_split_files as _cand + root = Path(args.dataset_root) + present = [t for t in VQA_SUBTASKS if any(p.exists() for p in _cand(root, t, args.split))] + if present: tasks = present + + per_task = distribute_quota(args.total_samples, len(tasks)) + + all_preds: List[str] = []; all_refs: List[List[str]] = []; details: List[Dict] = [] + rng = random.Random(args.seed) + for subdir, quota in zip(tasks, per_task): + if args.no_instruction: + inst = '' + else: + insts = load_instructions(Path(args.dataset_root), subdir) + ds = '' + if isinstance(insts, list) and insts: + ds = '\n'.join([s for s in insts if isinstance(s, str) and s.strip()]) + base = 'Given the image, answer the question.' + inst = (PAPER_VQA_INTRO + (('\n' + ds) if ds else '')) if args.use_paper_instruction else (ds or base) + + pool = [s for s in iter_m3it_samples(args.dataset_root, subdir, split=args.split, cache_dir=str(cache_dir))] + if not pool: continue + take = min(quota, len(pool)) + idxs = list(range(len(pool))) if args.sequential else rng.sample(range(len(pool)), k=take) + idxs = idxs[:take] + for i in idxs: + smp = pool[i] + oa = [] + if inst: oa.append({'type':'text','text':inst}) + if args.instruction_image: oa.append({'type':'image_url','image_url':args.instruction_image}) + oa.append({'type':'image_url','image_url':smp.image_path}) + oa.append({'type':'text','text':f"[REQUEST]\n{(smp.text or '').strip()}\n[RESPONSE]"}) + segs = openai_to_list_format(oa, cache_dir=cache_dir / '_oa_cache') + + pred = adapter.generate_from_segments( + segs, + temperature=args.temperature, + top_p=args.top_p, + max_new_tokens=args.max_new_tokens, + ) + + if args.dump_first and len(details) < args.dump_first: + (out_dir / f'debug_{subdir.replace("/","_")}_{len(details):04d}_openai.json').write_text(json.dumps(oa, ensure_ascii=False, indent=2), encoding='utf-8') + (out_dir / f'debug_{subdir.replace("/","_")}_{len(details):04d}_list.json').write_text(json.dumps(segs, ensure_ascii=False, indent=2), encoding='utf-8') + + all_preds.append(pred) + all_refs.append(smp.answers or []) + details.append({'task': subdir, 'image_path': smp.image_path, 'text': smp.text, 'answers': smp.answers}) + print(f"{subdir}: {len(idxs)} samples") + + tf1 = token_f1(all_preds, all_refs) + try: + bsf1 = bertscore_f1(all_preds, all_refs, model_type=args.bertscore_model, + rescale_with_baseline=not args.no_bertscore_baseline, + batch_size=args.bertscore_batch_size, + lang=(args.bertscore_lang or None)) + except Exception: + bsf1 = None + out = {'adapter': args.adapter, 'model_path': args.model_path, 'total': len(all_preds), 'metrics': {'token_f1': tf1, 'bertscore_f1': bsf1}, + 'predictions': [{'pred': p, 'answers': r, 'meta': m} for p, r, m in zip(all_preds, all_refs, details)]} + (out_dir / 'vqa_zero_shot.json').write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding='utf-8') + print(f'Token-F1={tf1:.2f}', 'BERTScore-F1=' + (f'{bsf1:.2f}' if bsf1 is not None else 'NA')) + + +if __name__ == '__main__': + main() diff --git a/ICL/LV/code/core/metrics/__init__.py b/ICL/LV/code/core/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/core/metrics/roscoe_shim.py b/ICL/LV/code/core/metrics/roscoe_shim.py new file mode 100644 index 0000000000000000000000000000000000000000..68e16a1823c54d925fde5b924c98a62d16407b07 --- /dev/null +++ b/ICL/LV/code/core/metrics/roscoe_shim.py @@ -0,0 +1,278 @@ +"""Thin shim to call ROSCOE scoring code without installing it as a package. + +Usage (from eval_order_reasoning_ras.py via --roscoe-module core.metrics.roscoe_shim --roscoe-func evaluate): + evaluate(preds: List[str], refs: List[str], model_path: Optional[str]) -> float + +Behavior: + - Tries to import ROSCOE from a filesystem path given by env ROSCOE_PY_PATH + (and also ROSCOE_PY_PATH/src), so you can point at a git checkout like + /z_data/syxin/roscoe/roscoe or /z_data/syxin/roscoe without pip installing. + - Supports both legacy layout (roscoe/score.py) and newer layout (roscoe/roscoe.py). + - Builds minimal Chain objects (one step per text) and computes the average + 'reasoning_alignment' over all pairs using ROSCOE's Evaluator. + - model_path should be a SIMCSE-compatible model id or local directory, e.g., + '/z_data/pretrained/syxin/roscoe-512-roberta-base' or 'facebook/roscoe-512-roberta-base'. +""" + +from __future__ import annotations + +from typing import List, Optional + + +def _ensure_roscoe_on_path(): + import os, sys + rp = os.environ.get('ROSCOE_PY_PATH') or '' + # Try to infer repo root automatically from this file location: + # /code/core/metrics/roscoe_shim.py -> + here = os.path.abspath(os.path.dirname(__file__)) + code_dir = os.path.dirname(os.path.dirname(here)) # .../code + repo_root = os.path.dirname(code_dir) # repo root + # Optional external override (kept for backward compatibility) + code_root = os.environ.get('CODE_ROOT') or repo_root + + # Build a list of plausible roots for the roscoe code + cand: list[str] = [] + def _add(p: str): + try: + if p and os.path.isdir(p) and (p not in sys.path): + sys.path.insert(0, p) + cand.append(p) + except Exception: + pass + + roots: list[str] = [] + if rp: + roots.extend([ + rp, + os.path.join(rp, 'src'), + os.path.join(rp, 'roscoe'), + os.path.join(rp, 'src', 'roscoe'), + os.path.join(rp, 'projects'), + os.path.join(rp, 'projects', 'roscoe'), + ]) + + # Common defaults if env not set or wrong + defaults = [ + # Typical locations under a shared /z_data tree + '/z_data/syxin/roscoe/roscoe', + '/z_data/syxin/roscoe', + '/z_data/syxin/roscoe/projects/roscoe', + '/z_data/syxin/roscoe/projects', + # Auto-detected from this repo checkout + os.path.join(repo_root, 'roscoe'), + os.path.join(repo_root, 'roscoe', 'roscoe'), + # Compatibility with legacy CODE_ROOT env + os.path.join(code_root, 'roscoe', 'roscoe'), + os.path.join(code_root, 'roscoe'), + ] + + # Allow comma-separated override list via ROSCOE_ROOTS + extra = (os.environ.get('ROSCOE_ROOTS') or '').strip() + if extra: + for part in extra.split(','): + p = part.strip() + if p and p not in roots and p not in defaults: + defaults.append(p) + + for d in defaults: + if d not in roots: + roots.append(d) + for r in roots: + _add(r) + if os.environ.get('ROSCOE_SHIM_DEBUG'): + print('[roscoe_shim] candidate roots:', roots) + print('[roscoe_shim] sys.path head after insert:', sys.path[:5]) + # Return the full list of candidate roots (not only newly-added), + # so downstream file-based probing can still work even if paths were + # pre-inserted into sys.path by the caller. + return roots + + +def _get_attr(mod, names): + """Return the first present attribute from names on module mod, else None.""" + for n in names: + try: + if hasattr(mod, n): + return getattr(mod, n) + except Exception: + pass + return None + + +def _load_roscoe_symbols(): + """Load required symbols from roscoe.{score|roscoe}. + + Returns (Evaluator, Chain, CHAIN_ALIGNMENT, SIM_MODEL) or (None, ...) + """ + # First, try package imports directly + try: + import importlib + mod = importlib.import_module('roscoe.score') + Ev = _get_attr(mod, ['Evaluator']) + Ch = _get_attr(mod, ['Chain']) + CHAIN = _get_attr(mod, ['CHAIN_ALIGNMENT', 'REASONING_ALIGNMENT']) + SIM = _get_attr(mod, ['SIMSCE', 'SIMCSE']) + if Ev and Ch and CHAIN and SIM: + return Ev, Ch, CHAIN, SIM + except Exception: + pass + + try: + import importlib + mod = importlib.import_module('roscoe.roscoe') # newer layout + Ev = _get_attr(mod, ['Evaluator']) + Ch = _get_attr(mod, ['Chain']) + CHAIN = _get_attr(mod, ['CHAIN_ALIGNMENT', 'REASONING_ALIGNMENT']) + SIM = _get_attr(mod, ['SIMSCE', 'SIMCSE']) + if Ev and Ch and CHAIN and SIM: + return Ev, Ch, CHAIN, SIM + except Exception: + pass + + # If imports fail, try file-based loading under projects.roscoe.* + import importlib.util, sys, types, os + roots = _ensure_roscoe_on_path() + # Also consider sys.path entries that likely contain roscoe sources + try: + import sys, os as _os + extra_roots = [] + for p in list(sys.path): + try: + if not p or not _os.path.isdir(p): + continue + bn = _os.path.basename(p.rstrip(_os.sep)).lower() + if ('roscoe' in bn) or _os.path.isdir(_os.path.join(p, 'roscoe')): + if p not in roots: + extra_roots.append(p) + except Exception: + pass + roots = list(roots) + extra_roots + except Exception: + pass + + # Candidates: (filename, module_name) + candidates = [ + ('score.py', 'projects.roscoe.score'), + ('roscoe.py', 'projects.roscoe.roscoe'), # newer + ] + + sel_mod = None + # utils is expected alongside the chosen module + for r in roots: + base = r + # Allow both r == .../roscoe and r == repo root + for fname, mname in candidates: + if os.path.basename(base) == 'roscoe': + mod_path = os.path.join(base, fname) + utils_path = os.path.join(base, 'utils.py') + else: + mod_path = os.path.join(base, 'roscoe', fname) + utils_path = os.path.join(base, 'roscoe', 'utils.py') + if os.path.isfile(mod_path) and os.path.isfile(utils_path): + # Prepare projects.roscoe.* namespace + if 'projects' not in sys.modules: + sys.modules['projects'] = types.ModuleType('projects') + if 'projects.roscoe' not in sys.modules: + sys.modules['projects.roscoe'] = types.ModuleType('projects.roscoe') + # Load utils first + spec_u = importlib.util.spec_from_file_location('projects.roscoe.utils', utils_path) + if spec_u and spec_u.loader: + mod_u = importlib.util.module_from_spec(spec_u) # type: ignore + sys.modules['projects.roscoe.utils'] = mod_u + spec_u.loader.exec_module(mod_u) # type: ignore + # Load the module (score.py or roscoe.py) + spec_s = importlib.util.spec_from_file_location(mname, mod_path) + if spec_s and spec_s.loader: + sel_mod = importlib.util.module_from_spec(spec_s) # type: ignore + sys.modules[mname] = sel_mod + spec_s.loader.exec_module(sel_mod) # type: ignore + break + if sel_mod: + break + + if sel_mod is None: + if __import__('os').environ.get('ROSCOE_SHIM_DEBUG'): + try: + import glob + probe = [] + for r in (roots or []): + probe.extend(glob.glob(os.path.join(r, '**', 'score.py'), recursive=True)) + probe.extend(glob.glob(os.path.join(r, '**', 'roscoe.py'), recursive=True)) + print('[roscoe_shim] failed to locate roscoe/{score.py|roscoe.py} and utils.py under ROSCOE_PY_PATH; probe sample:', probe[:6]) + except Exception: + pass + else: + print('[roscoe_shim] failed to locate roscoe/{score.py|roscoe.py} and utils.py under ROSCOE_PY_PATH') + return None, None, None, None + + Ev = _get_attr(sel_mod, ['Evaluator']) + Ch = _get_attr(sel_mod, ['Chain']) + CHAIN = _get_attr(sel_mod, ['CHAIN_ALIGNMENT', 'REASONING_ALIGNMENT']) + SIM = _get_attr(sel_mod, ['SIMSCE', 'SIMCSE']) + return Ev, Ch, CHAIN, SIM + + +def evaluate_list(preds: List[str], refs: List[str], model_path: Optional[str] = None) -> Optional[List[float]]: + _ensure_roscoe_on_path() + # Provide a minimal nltk stub if missing to satisfy imports in roscoe/utils.py + try: + import nltk # type: ignore + except Exception: + import types, sys + if 'nltk' not in sys.modules: + nltk_stub = types.ModuleType('nltk') + tok_stub = types.ModuleType('nltk.tokenize') + def _sent_tokenize(s: str): + # Very simple fallback: split on newlines and periods; not used by our path + import re + parts = [] + for line in (s or '').splitlines(): + parts.extend([p for p in re.split(r"(?<=[.!?])\s+", line) if p]) + return parts if parts else [(s or '').strip()] + tok_stub.sent_tokenize = _sent_tokenize # type: ignore[attr-defined] + sys.modules['nltk'] = nltk_stub + sys.modules['nltk.tokenize'] = tok_stub + + # Load symbols from roscoe (robust to both legacy/new layout) + Evaluator, Chain, CHAIN_ALIGNMENT, SIM_MODEL = _load_roscoe_symbols() + if not (Evaluator and Chain and CHAIN_ALIGNMENT and SIM_MODEL): + return None + + # Minimal chains: one step per text keeps interface simple and avoids nltk dependency + hypos = [Chain([str(p or '').strip()]) for p in preds] + references = [Chain([str(r or '').strip()]) for r in refs] + # Evaluator expects a context list; we pass empty chains + context = [Chain([]) for _ in hypos] + + # Initialize evaluator for embedding-based scores only + # SIMCSE path is recommended for roscoe-512-roberta-base + transformer_model = model_path or 'facebook/roscoe-512-roberta-base' + ev = Evaluator( + hypos=hypos, + context=context, + references=references, + score_types=[CHAIN_ALIGNMENT], + model_type=SIM_MODEL, + transformer_model=transformer_model, + # keep small batch defaults to avoid OOM on long lists + discourse_batch=32, + coherence_batch=16, + ) + try: + # Compute just reasoning_alignment (CHAIN_ALIGNMENT) + scores = ev.evaluate(score_types=[CHAIN_ALIGNMENT]) + arr = scores.get(CHAIN_ALIGNMENT) or [] + vals = [float(x) for x in arr if isinstance(x, (int, float))] + if not vals: + return None + return vals + except Exception as e: + print('[roscoe_shim] evaluation failed:', e) + return None + + +def evaluate(preds: List[str], refs: List[str], model_path: Optional[str] = None) -> float: + vals = evaluate_list(preds, refs, model_path=model_path) + if not vals: + return None # type: ignore[return-value] + return sum(vals) / len(vals) diff --git a/ICL/LV/code/core/prompting/__init__.py b/ICL/LV/code/core/prompting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ICL/LV/code/core/prompting/openai_segments.py b/ICL/LV/code/core/prompting/openai_segments.py new file mode 100644 index 0000000000000000000000000000000000000000..67c1d4200f9fc9a5a1c6b0b3bd7d28c19bcb6fdd --- /dev/null +++ b/ICL/LV/code/core/prompting/openai_segments.py @@ -0,0 +1,128 @@ +import base64 +import io +import os +from dataclasses import dataclass +from pathlib import Path +from typing import List, Dict, Optional + +from PIL import Image + + +@dataclass +class OAItem: + """OpenAI-like content item. + + Only two variants are supported here: + - {"type": "text", "text": "..."} + - {"type": "image_url", "image_url": ""} + """ + type: str + text: Optional[str] = None + image_url: Optional[str] = None + + +def _win_to_wsl_path(p: str) -> str: + """Best-effort conversion of a Windows path like C:\\foo\\bar to /mnt/c/foo/bar. + + If the input already looks like a POSIX path, it is returned unchanged. + """ + if not isinstance(p, str) or not p: + return p + if p.startswith('/') or p.startswith('~'): + return p + if len(p) >= 2 and p[1] == ':' and (p[0].isalpha()): + drive = p[0].lower() + rest = p[2:].replace('\\', '/').lstrip('/') + return f"/mnt/{drive}/{rest}" + return p + + +def _is_data_url(s: str) -> bool: + return isinstance(s, str) and s.startswith('data:image') and ';base64,' in s + + +def _is_probably_b64(s: str) -> bool: + # Heuristic: long-ish string without path separators or scheme + return isinstance(s, str) and len(s) > 100 and '://' not in s and ('\\' not in s and '/' not in s) + + +def ensure_image_path(img: str, cache_dir: Path) -> str: + """Return a local filesystem path for an image. + + - If `img` is a Windows path, convert to WSL. + - If it's a data URL or bare base64, decode and write to cache_dir. + - Otherwise, return as-is. + """ + if not isinstance(img, str) or not img.strip(): + return '' + img = img.strip() + # Convert Windows path to WSL + p = _win_to_wsl_path(img) + # Decode data URL or raw base64 + if _is_data_url(p): + head, b64 = p.split(',', 1) + ext = 'jpg' + if 'image/png' in head: + ext = 'png' + cache_dir.mkdir(parents=True, exist_ok=True) + out = cache_dir / f"instr_{abs(hash(p))}.{ext}" + if not out.exists(): + img_bytes = base64.b64decode(b64) + Image.open(io.BytesIO(img_bytes)).convert('RGB').save(out, format='JPEG', quality=90) + return str(out) + if _is_probably_b64(p): + cache_dir.mkdir(parents=True, exist_ok=True) + out = cache_dir / f"instr_{abs(hash(p))}.jpg" + if not out.exists(): + Image.open(io.BytesIO(base64.b64decode(p))).convert('RGB').save(out, format='JPEG', quality=90) + return str(out) + return p + + +def build_openai_sequence( + instruction_text: Optional[str], + instruction_image: Optional[str], + demos: List[Dict], + query_image_path: str, + query_text: str, +) -> List[Dict[str, str]]: + """Assemble the OpenAI-like flat content list following the recommended order. + + - instruction_text: optional preamble text + - instruction_image: optional image path or base64 (will be resolved by adapter) + - demos: list of {image_path, text_in, text_out} + - final query with image + text and empty RESPONSE + """ + items: List[Dict[str, str]] = [] + if instruction_text and instruction_text.strip(): + items.append({'type': 'text', 'text': instruction_text.strip()}) + if instruction_image and str(instruction_image).strip(): + items.append({'type': 'image_url', 'image_url': str(instruction_image).strip()}) + for d in demos: + items.append({'type': 'image_url', 'image_url': d['image_path']}) + items.append({'type': 'text', 'text': f"[REQUEST]\n{(d.get('text_in') or '').strip()}\n[RESPONSE]\n{(d.get('text_out') or '').strip()}"}) + items.append({'type': 'image_url', 'image_url': query_image_path}) + items.append({'type': 'text', 'text': f"[REQUEST]\n{(query_text or '').strip()}\n[RESPONSE]"}) + return items + + +def openai_to_list_format(items: List[Dict[str, str]], cache_dir: Path) -> List[Dict[str, str]]: + """Convert OpenAI-like items into the minimal list-format used by our runner. + + - image_url -> {'image': } + - text -> {'text': '...'} + """ + out: List[Dict[str, str]] = [] + for it in (items or []): + t = it.get('type') + if t == 'text': + s = (it.get('text') or '').strip() + if s: + out.append({'text': s}) + elif t == 'image_url': + raw = it.get('image_url') or '' + lp = ensure_image_path(raw, cache_dir) + if lp: + out.append({'image': lp}) + return out + diff --git a/ICL/LV/code/quick_r_acc_k1.log b/ICL/LV/code/quick_r_acc_k1.log new file mode 100644 index 0000000000000000000000000000000000000000..bd42cf11d4a10ce3141fa2ea3882642cbb5db85d --- /dev/null +++ b/ICL/LV/code/quick_r_acc_k1.log @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "", line 198, in _run_module_as_main + File "", line 88, in _run_code + File "/mnt/e/xiaobin/code/core/eval/eval_order_reasoning_accuracy.py", line 11, in + from core.eval.order_eval_core import run_predictions + File "/mnt/e/xiaobin/code/core/eval/order_eval_core.py", line 17, in + import torch +ModuleNotFoundError: No module named 'torch' diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_qwen3vl.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_qwen3vl.sh new file mode 100644 index 0000000000000000000000000000000000000000..b0e61e655c115d1771ce4675fd7bd4d0e702fd06 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_qwen3vl.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" + +ADAPTER="qwen3-vl" +MODEL_PATH="${MODEL_PATH:-/workspace/Qwen3-VL-8B-Instruct}" +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_qwen3-vl}" + +# Swap GPU0 and GPU2: GPU0->k2, GPU2->k0 (others identity) +export GPU_K_MAP="0:2,1:1,2:0,3:3,4:4,5:5,6:6,7:7" +exec bash "$ROOT_DIR/run_all_metrics_shot_sweep_0_7.sh" \ + --adapter "$ADAPTER" \ + --model-path "$MODEL_PATH" \ + --dataset-root "$DATASET_ROOT" \ + --retriever-model "$RETRIEVER_MODEL" \ + --output-base "$OUTPUT_BASE" \ + "$@" diff --git a/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh new file mode 100644 index 0000000000000000000000000000000000000000..5dce30685a78af324a7191b1c45204382f756c90 --- /dev/null +++ b/ICL/LV/code/run_all_metrics_shot_sweep_0_7_two_waves_gpu4_7.sh @@ -0,0 +1,194 @@ +#!/usr/bin/env bash +# Run 8 metrics (modal-order evaluation) for k-shots = 0..7 using GPUs 4..7. +# - Wave 1: shots 0..3 on GPUs 4,5,6,7 respectively (in parallel) +# - Wave 2: shots 4..7 on GPUs 4,5,6,7 respectively (after wave 1 completes) +# Metrics: vqa_tokenf1, vqa_bertscore, captioning_bertscore, captioning_cider, +# classification_accuracy, classification_f1, reasoning_accuracy, reasoning_ras. +# - Strict ROSCOE backend (no fallback) for reasoning_ras. +# - total-samples defaults to 4000. +# - Prints a final summary per k-shot with all 8 scores converted to 0..100 and their average. + +set -euo pipefail +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$ROOT_DIR" +export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}" +# Always use the current Python interpreter (preserves active conda env) +PYTHON_BIN="$(command -v python)" + +# Defaults (override by CLI or env) +ADAPTER="${ADAPTER:-qwen3-vl}" +MODEL_PATH="${MODEL_PATH:-}" # must be provided or via --model-path +DATASET_ROOT="${DATASET_ROOT:-/workspace/M3IT}" +RETRIEVER_MODEL="${RETRIEVER_MODEL:-/z_data/pretrained/syxin/bridgetower-large-itm-mlm-itc}" +# Default to a single fixed modal order to align with eval_random_k_shot_vqa.py +# (image first, then text). Override via --orders to sweep multiple orders if needed. +ORDERS="${ORDERS:-image-text}" +TOTAL_SAMPLES="${TOTAL_SAMPLES:-4000}" +SPLIT="${SPLIT:-val}" +TEMP="${TEMP:-0.6}" +TOPP="${TOPP:-1.0}" +MAX_NEW="${MAX_NEW:-128}" +RAS_BACKEND="${RAS_BACKEND:-roscoe}" +BERTSCORE_MODEL="${BERTSCORE_MODEL:-/z_data/pretrained/syxin/roberta-large}" +BERTSCORE_LANG="${BERTSCORE_LANG:-en}" +OUTPUT_BASE="${OUTPUT_BASE:-runs/shot_sweep_allmetrics_${ADAPTER}}" +REUSE="${REUSE:---reuse-cache}" + +# ROSCOE config (auto-detect repo-local path; strict by default) +export ROSCOE_MODEL_PATH="${ROSCOE_MODEL_PATH:-/z_data/pretrained/syxin/roscoe-512-roberta-base}" +if [[ -z "${ROSCOE_PY_PATH:-}" ]]; then + if [[ -d "$ROOT_DIR/../roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/../roscoe"; fi + if [[ -z "${ROSCOE_PY_PATH:-}" && -d "$ROOT_DIR/roscoe" ]]; then export ROSCOE_PY_PATH="$ROOT_DIR/roscoe"; fi +fi +ROSCOE_PATH_ARG=""; if [[ -n "${ROSCOE_PY_PATH:-}" && -d "$ROSCOE_PY_PATH" ]]; then ROSCOE_PATH_ARG="--roscoe-path \"$ROSCOE_PY_PATH\""; fi +RAS_STRICT="${RAS_STRICT:-1}"; STRICT_FLAG=""; if [[ "$RAS_STRICT" == "1" ]]; then STRICT_FLAG="--ras-strict"; fi + +# CLI overrides +while [[ $# -gt 0 ]]; do + case "$1" in + --adapter) ADAPTER="$2"; shift 2;; + --model-path) MODEL_PATH="$2"; shift 2;; + --dataset-root) DATASET_ROOT="$2"; shift 2;; + --retriever-model) RETRIEVER_MODEL="$2"; shift 2;; + --orders) ORDERS="$2"; shift 2;; + --total-samples) TOTAL_SAMPLES="$2"; shift 2;; + --split) SPLIT="$2"; shift 2;; + --temp) TEMP="$2"; shift 2;; + --top-p) TOPP="$2"; shift 2;; + --max-new) MAX_NEW="$2"; shift 2;; + --bertscore-model) BERTSCORE_MODEL="$2"; shift 2;; + --bertscore-lang) BERTSCORE_LANG="$2"; shift 2;; + --output-base) OUTPUT_BASE="$2"; shift 2;; + --ras-backend) RAS_BACKEND="$2"; shift 2;; + *) echo "Unknown arg: $1" >&2; exit 2;; + esac +done + +if [[ -z "$MODEL_PATH" ]]; then echo "[ERR] MODEL_PATH empty. Pass --model-path" >&2; exit 2; fi +mkdir -p "$OUTPUT_BASE" runs/logs + +echo "[INFO] Adapter=$ADAPTER | Model=$MODEL_PATH | Dataset=$DATASET_ROOT | Orders=$ORDERS | total-samples=$TOTAL_SAMPLES | split=$SPLIT" + +# Build the per-shot command list (identical to the 0..7 base runner) +build_cmds_for_k() { + local kshot="$1"; local outdir="$2"; + echo "[build] k-shot=$kshot -> $outdir" >&2 + cat <&2 echo "[GPU$gpu] k-shot=$kshot -> $outdir" + { + # Each command runs sequentially on the assigned GPU + while IFS= read -r cmd; do + echo "[GPU$gpu][k=$kshot] $cmd" + CUDA_VISIBLE_DEVICES="$gpu" eval "$cmd" + done <<< "$(build_cmds_for_k "$kshot" "$outdir")" + } > "runs/logs/shot_allmetrics_g${gpu}_k${kshot}.log" 2>&1 +} + +# Wave launcher for a list of (gpu,k) pairs +launch_wave() { + local -a pairs=("$@") + local pids=() + local desc="" + for pair in "${pairs[@]}"; do + local gpu="${pair%%:*}" + local k="${pair##*:}" + desc+=" (GPU${gpu}->k${k})" + run_one_gpu "$gpu" "$k" "$OUTPUT_BASE/shot${k}" & + pids+=("$!") + sleep 1 + done + echo "[INFO] Launched wave:${desc}" + local fail=0; local idx=0 + for pid in "${pids[@]}"; do + if ! wait "$pid"; then echo "[WAVE] job $idx failed (pid=$pid)" >&2; fail=$((fail+1)); else echo "[WAVE] job $idx done (pid=$pid)"; fi + idx=$((idx+1)) + done + if [[ $fail -gt 0 ]]; then echo "[WARN] Some jobs failed in this wave: $fail" >&2; fi +} + +# Two waves on GPUs 4..7 +# Wave 1: place 0..3-shot on GPUs 0..3 (each GPU runs exactly one k at a time) +# Wave 1: shots 0..3 on GPUs 4..7 +launch_wave 4:0 5:1 6:2 7:3 +# Wave 2: shots 4..7 on GPUs 4..7 +launch_wave 4:4 5:5 6:6 7:7 + +# Summarize: read each metric's summary.json, scale to 0..100 if needed, then average per shot +echo "\n[SUMMARY] All 8 metrics per k-shot (0..7), 0..100 scale with AVG" +python - "$OUTPUT_BASE" <<'PY' +import json, os, sys + +base = sys.argv[1] + +metrics = [ + ("vqa_tokenf1", 1.0), + ("vqa_bertscore", 1.0), + ("captioning_bertscore", 1.0), + # Already scaled to 0..100 in eval_order_caption_cider.py + ("captioning_cider", 1.0), + ("classification_accuracy", 1.0), + ("classification_f1", 1.0), + ("reasoning_accuracy", 1.0), + # Already scaled to 0..100 in eval_order_reasoning_ras.py when needed + ("reasoning_ras", 1.0), +] + +def readj(p): + try: + with open(p,'r',encoding='utf-8') as f: + return json.load(f) + except Exception: + return None + +rows = [] +for k in range(8): + shot_dir = os.path.join(base, f'shot{k}') + row = {"k": k, "metrics": {}, "avg": None} + vals = [] + for m, scale in metrics: + summ = readj(os.path.join(shot_dir, m, 'summary.json')) + if isinstance(summ, dict) and summ: + # average over orders + scores = [v for v in summ.values() if isinstance(v, (int, float))] + mv = (sum(scores)/len(scores)) if scores else None + else: + mv = None + # Special-case fallback: for k=0, VQA Token-F1 may come from zero_shot_vqa output + if k == 0 and m == 'vqa_tokenf1': + z = readj(os.path.join(shot_dir, 'vqa_zero_shot.json')) + try: + mv = float(z.get('metrics',{}).get('token_f1')) if isinstance(z, dict) else None + except Exception: + mv = None + if mv is not None: + mv = float(mv) * float(scale) + vals.append(mv) + row["metrics"][m] = mv + row["avg"] = (sum(vals)/len(vals)) if vals else None + rows.append(row) + +for r in rows: + k = r["k"] + parts = [] + for m, _ in [(m,s) for m,s in metrics]: + v = r["metrics"].get(m) + parts.append(f"{m}={('NA' if v is None else f'{v:.2f}')}") + avg = 'NA' if r['avg'] is None else f"{r['avg']:.2f}" + print(f" k={k} " + ", ".join(parts) + f", AVG={avg}") +PY + +echo "\n[INFO] Outputs under $OUTPUT_BASE/shot{0..7}/{metric}/" diff --git a/ICL/LV/code/select_random_samples.py b/ICL/LV/code/select_random_samples.py new file mode 100644 index 0000000000000000000000000000000000000000..82a6829a0d6ac0eb7276615854fcdb1ccabcd51c --- /dev/null +++ b/ICL/LV/code/select_random_samples.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +""" +Select random examples per category (captioning, reasoning, classification, vqa) +from cached eval outputs, copy images, and write per-run summary files. + +Supports both: +- shot-sweep runs (k=0..7) → 8 files per adapter +- modal-order runs (orders image-text, text-image, text-image-text) → 3 files per adapter + +Defaults assume runs under /z_data/syxin/code/runs, but you can override via CLI. + +Example: + python select_random_samples.py \ + --adapters gemma3 idefics2 qwen3vl \ + --runs-root /z_data/syxin/code/runs \ + --per-category 3 --seed 2024 + +This will write outputs into each run directory under a subfolder `selected_samples/`. +""" + +import argparse +import json +import os +import random +import shutil +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +CATEGORIES = ["captioning", "reasoning", "classification", "vqa"] +ORDERS = ["image-text", "text-image", "text-image-text"] + +# Prefer one metric dir per category; fall back to alternates when missing. +PREFER_METRIC_DIR: Dict[str, List[str]] = { + "vqa": ["vqa_tokenf1", "vqa_bertscore"], + "captioning": ["captioning_cider", "captioning_bertscore"], + "classification": ["classification_accuracy", "classification_f1"], + "reasoning": ["reasoning_accuracy", "reasoning_ras"], +} + +# Adapter → canonical directory name variants to try +ADAPTER_DIR_NAME: Dict[str, List[str]] = { + "gemma3": ["gemma3"], + "idefics2": ["idefics2"], + # qwen3-vl can appear with or without hyphen; try both + "qwen3vl": ["qwen3-vl", "qwen3vl"], +} + + +def _find_existing_path(base: Path, candidates: List[str]) -> Optional[Path]: + for name in candidates: + p = base / name + if p.exists(): + return p + return None + + +def _pick_metric_dir(base: Path, category: str) -> Optional[Path]: + for d in PREFER_METRIC_DIR.get(category, []): + p = base / d + if p.exists(): + return p + # fallback: pick any metric dir that has a matching cache file + for sub in sorted([x for x in base.iterdir() if x.is_dir()]): + cache = sub / "_cache" + if cache.exists(): + for jf in cache.glob(f"{category}__*.jsonl"): + return sub + return None + + +def _read_jsonl(path: Path) -> List[Dict]: + out: List[Dict] = [] + try: + with path.open("r", encoding="utf-8") as f: + for line in f: + s = line.strip() + if not s: + continue + try: + out.append(json.loads(s)) + except Exception: + continue + except Exception: + return [] + return out + + +def _choose_n(items: List[int], n: int, rng: random.Random) -> List[int]: + if not items: + return [] + k = min(n, len(items)) + return rng.sample(items, k) + + +def _safe_copy(src: Path, dst: Path) -> Optional[Path]: + try: + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(str(src), str(dst)) + return dst + except Exception: + return None + + +def _pick_from_cache(cache_file: Path, n: int, rng: random.Random) -> Tuple[List[Dict], List[int]]: + rows = _read_jsonl(cache_file) + idxs = _choose_n(list(range(len(rows))), n, rng) + return ([rows[i] for i in idxs], idxs) + + +def _normalize_order(order: str) -> str: + return (order or "").strip().lower() + + +def export_shot_sweep(adapter: str, shot_base: Path, out_root: Path, per_category: int, seed: int) -> None: + rng = random.Random(seed) + # k=0..7 → 8 files + for k in range(8): + shot_dir = shot_base / f"shot{k}" + if not shot_dir.exists(): + continue + result: Dict[str, List[Dict]] = {} + copy_dir = out_root / f"shot_k{k}" / "images" + for cat in CATEGORIES: + metric_base = _pick_metric_dir(shot_dir, cat) + if metric_base is None: + continue + cache_dir = metric_base / "_cache" + # Prefer image-text; else any available order + order = None + cfile = cache_dir / f"{cat}__image-text.jsonl" + if not cfile.exists(): + cand = list(cache_dir.glob(f"{cat}__*.jsonl")) + if cand: + cfile = cand[0] + if not cfile.exists(): + continue + order = _normalize_order(cfile.stem.split("__", 1)[-1]) + picked, idxs = _pick_from_cache(cfile, per_category, rng) + out_rows: List[Dict] = [] + for i, (row, idx) in enumerate(zip(picked, idxs)): + meta = row.get("meta") or {} + img_src = Path(str(meta.get("image_path") or "")) + # Copy image + img_name = f"{cat}_{i:02d}_{img_src.name or 'img'}.jpg" + img_dst = copy_dir / cat / img_name + copied = _safe_copy(img_src, img_dst) if img_src.exists() else None + out_rows.append({ + "k_shots": k, + "category": cat, + "order": order, + "task": meta.get("task"), + "image_src": str(img_src), + "image_file": (str(copied.relative_to(out_root)) if copied and copied.exists() else None), + "prompt_text": meta.get("text"), # final request text + "inputs": meta.get("inputs"), + "model_answer": row.get("pred"), + "references": row.get("answers"), + "ref_text": row.get("ref_text"), + "cache_index": idx, + "cache_file": str(cfile), + }) + result[cat] = out_rows + # Write one file per k + out_file = out_root / f"selected_shot_k{k}.json" + summary = { + "adapter": adapter, + "type": "shot_sweep", + "shot_dir": str(shot_dir), + "result": result, + } + out_file.parent.mkdir(parents=True, exist_ok=True) + out_file.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + +def export_modal_orders(adapter: str, order_base: Path, out_root: Path, per_category: int, seed: int) -> None: + rng = random.Random(seed) + # 3 files: one per order + for ord_name in ORDERS: + ord_key = _normalize_order(ord_name) + result: Dict[str, List[Dict]] = {} + copy_dir = out_root / f"order_{ord_key}" / "images" + for cat in CATEGORIES: + metric_base = _pick_metric_dir(order_base, cat) + if metric_base is None: + continue + cache_dir = metric_base / "_cache" + cfile = cache_dir / f"{cat}__{ord_key}.jsonl" + if not cfile.exists(): + # Skip this category for this order if no cache + continue + picked, idxs = _pick_from_cache(cfile, per_category, rng) + out_rows: List[Dict] = [] + for i, (row, idx) in enumerate(zip(picked, idxs)): + meta = row.get("meta") or {} + img_src = Path(str(meta.get("image_path") or "")) + img_name = f"{cat}_{i:02d}_{img_src.name or 'img'}.jpg" + img_dst = copy_dir / cat / img_name + copied = _safe_copy(img_src, img_dst) if img_src.exists() else None + out_rows.append({ + "order": ord_key, + "category": cat, + "task": meta.get("task"), + "image_src": str(img_src), + "image_file": (str(copied.relative_to(out_root)) if copied and copied.exists() else None), + "prompt_text": meta.get("text"), + "inputs": meta.get("inputs"), + "model_answer": row.get("pred"), + "references": row.get("answers"), + "ref_text": row.get("ref_text"), + "cache_index": idx, + "cache_file": str(cfile), + }) + result[cat] = out_rows + out_file = out_root / f"selected_order_{ord_key}.json" + summary = { + "adapter": adapter, + "type": "modal_order", + "order_dir": str(order_base), + "result": result, + } + out_file.parent.mkdir(parents=True, exist_ok=True) + out_file.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") + + +def resolve_adapter_paths(runs_root: Path, adapter_key: str) -> Tuple[Optional[Path], Optional[Path]]: + """Return (shot_base, order_base) for the adapter if found. + The function tries known variants like qwen3-vl vs qwen3vl. + shot_base points to the top shot_sweep_allmetrics_* dir. + order_base points to order_* dir. + """ + shot_base = None + order_base = None + variants = ADAPTER_DIR_NAME.get(adapter_key, [adapter_key]) + # shot_sweep + for v in variants: + # prefer exact folder name + p = runs_root / f"shot_sweep_allmetrics_{v}" + if p.exists(): + shot_base = p + break + if shot_base is None: + # try scan by prefix + for sub in runs_root.glob("shot_sweep_allmetrics_*"): + for v in variants: + if v in sub.name: + shot_base = sub; break + if shot_base is not None: + break + # order + for v in variants: + p = runs_root / f"order_{v}" + if p.exists(): + order_base = p + break + if order_base is None: + for sub in runs_root.glob("order_*"): + for v in variants: + if v in sub.name: + order_base = sub; break + if order_base is not None: + break + return shot_base, order_base + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--adapters", nargs="*", default=["gemma3", "idefics2", "qwen3vl"], help="Adapters to process") + ap.add_argument("--runs-root", default="/z_data/syxin/code/runs", help="Root folder that contains run outputs") + ap.add_argument("--per-category", type=int, default=3) + ap.add_argument("--seed", type=int, default=0) + args = ap.parse_args() + + runs_root = Path(args.runs_root) + + for adapter in args.adapters: + shot_base, order_base = resolve_adapter_paths(runs_root, adapter) + # Output root defaults inside the shot_base dir if it exists; else in runs_root + out_root_base = shot_base if shot_base is not None else (order_base if order_base is not None else runs_root) + out_root = out_root_base / f"selected_samples_{adapter}" + + print(f"[Adapter] {adapter}") + if shot_base and shot_base.exists(): + print(f" shot-sweep: {shot_base}") + export_shot_sweep(adapter, shot_base, out_root, args.per_category, args.seed) + else: + print(" shot-sweep: NOT FOUND") + if order_base and order_base.exists(): + print(f" modal-order: {order_base}") + export_modal_orders(adapter, order_base, out_root, args.per_category, args.seed) + else: + print(" modal-order: NOT FOUND") + print(f" outputs → {out_root}") + + +if __name__ == "__main__": + main() diff --git a/ICL/SFT_new/.claude/settings.local.json b/ICL/SFT_new/.claude/settings.local.json new file mode 100644 index 0000000000000000000000000000000000000000..d945a95142194b67cabfd358f85f030d349810fb --- /dev/null +++ b/ICL/SFT_new/.claude/settings.local.json @@ -0,0 +1,20 @@ +{ + "permissions": { + "allow": [ + "Bash(python3 -m json.tool)", + "Bash(find /workspace/xiaobin/ICL -type f \\\\\\(-name eval*.py -o -name inference*.py -o -name test*.py -o -name predict*.py -o -name generate*.py \\\\\\) ! -name generate_captions*.py)", + "Bash(python3:*)", + "Bash(chmod:*)", + "WebSearch", + "Bash(find /workspace/miniconda3/lib/python3.13/site-packages -name *qwen*processor* -type f)", + "Bash(find /workspace/xiaobin/dataset -name *.jpg -o -name *.png)", + "Read(//workspace/xiaobin/ICL/sft_model/**)", + "Read(//workspace/xiaobin/ICL/sft_model/epoch3_step1406_fp32/**)", + "Read(//workspace/xiaobin/ICL/sft_model/epoch3_step1406/**)", + "Bash(conda run:*)", + "Bash(ls /workspace/xiaobin/ICL/sft_model/epoch3_step1406_fp32/*_optim_states.pt)", + "Read(//workspace/xiaobin/ICL/**)", + "Read(//workspace/xiaobin/**)" + ] + } +} diff --git a/ICL/SFT_new/__pycache__/build_sft.cpython-313.pyc b/ICL/SFT_new/__pycache__/build_sft.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcc6b0143057bb6299a80e5e72228c5b6d851741 Binary files /dev/null and b/ICL/SFT_new/__pycache__/build_sft.cpython-313.pyc differ diff --git a/ICL/SFT_new/__pycache__/generate_captions.cpython-313.pyc b/ICL/SFT_new/__pycache__/generate_captions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d7ca67e9c7e5c20eaae5918f1021bef2e9fd88 Binary files /dev/null and b/ICL/SFT_new/__pycache__/generate_captions.cpython-313.pyc differ diff --git a/ICL/SFT_new/__pycache__/generate_captions_all.cpython-313.pyc b/ICL/SFT_new/__pycache__/generate_captions_all.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3f367b92e94f0058ce9aef39e34703a8574669e Binary files /dev/null and b/ICL/SFT_new/__pycache__/generate_captions_all.cpython-313.pyc differ diff --git a/ICL/SFT_new/__pycache__/train.cpython-311.pyc b/ICL/SFT_new/__pycache__/train.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c31ba1bae27d61011d2a95e85e1702e731fd700 Binary files /dev/null and b/ICL/SFT_new/__pycache__/train.cpython-311.pyc differ