| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import os |
| import unittest |
| from unittest.mock import AsyncMock, MagicMock, patch |
|
|
| from verl.utils.profiler.config import ( |
| ProfilerConfig, |
| TorchProfilerToolConfig, |
| build_sglang_profiler_args, |
| build_vllm_profiler_args, |
| ) |
|
|
|
|
class TestServerProfilerArgs(unittest.TestCase):
    """Tests for the helpers that translate a ProfilerConfig into per-server profiler arguments."""

    def test_build_vllm_profiler_args(self):
        """vLLM builder exports VLLM_TORCH_PROFILER_* env vars and a matching JSON ``profiler_config``."""
        tool_cfg = TorchProfilerToolConfig(contents=["stack", "shapes", "memory"])
        profiler_cfg = ProfilerConfig(save_path="/tmp/test", tool_config=tool_cfg)
        expected_dir = "/tmp/test/agent_loop_rollout_replica_0"

        # Start from an empty environment so only variables written by the builder are visible.
        # patch.dict restores os.environ when the context exits, so the env checks must stay inside it.
        with patch.dict(os.environ, {}, clear=True):
            vllm_args = build_vllm_profiler_args(profiler_cfg, tool_cfg, rank=0)

            expected_env = {
                "VLLM_TORCH_PROFILER_DIR": expected_dir,
                "VLLM_TORCH_PROFILER_WITH_STACK": "1",
                "VLLM_TORCH_PROFILER_RECORD_SHAPES": "1",
                "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": "1",
            }
            for env_name, env_value in expected_env.items():
                self.assertEqual(os.environ.get(env_name), env_value)

        # The same settings are also passed to the server as a JSON-encoded string argument.
        self.assertIn("profiler_config", vllm_args)
        parsed = json.loads(vllm_args["profiler_config"])
        self.assertEqual(parsed["torch_profiler_dir"], expected_dir)
        for flag_name in (
            "torch_profiler_with_stack",
            "torch_profiler_record_shapes",
            "torch_profiler_with_memory",
        ):
            self.assertTrue(parsed[flag_name])

    def test_build_sglang_profiler_args(self):
        """SGLang builder emits a UserWarning and returns output_dir/with_stack/record_shapes kwargs."""
        tool_cfg = TorchProfilerToolConfig(contents=["stack", "shapes", "memory"])
        profiler_cfg = ProfilerConfig(save_path="/tmp/test", tool_config=tool_cfg)

        # NOTE(review): the warning is presumably triggered by a requested content the SGLang path
        # does not honour (likely "memory", which is not asserted below) — TODO confirm.
        with self.assertWarns(UserWarning):
            sglang_args = build_sglang_profiler_args(profiler_cfg, tool_cfg, rank=0)

        self.assertEqual(sglang_args["output_dir"], "/tmp/test/agent_loop_rollout_replica_0")
        self.assertTrue(sglang_args["with_stack"])
        self.assertTrue(sglang_args["record_shapes"])
|
|
|
|
class TestServerProfilerFunctionality(unittest.IsolatedAsyncioTestCase):
    """Check that the rollout HTTP servers forward start/stop profiling to their backends.

    The server classes are imported inside each test so the test is skipped, not
    errored, when the optional vLLM / SGLang dependencies are missing. The
    ``start_profile`` / ``stop_profile`` coroutines are called unbound with a
    ``MagicMock`` standing in for ``self``, so no real server is constructed.
    """

    @staticmethod
    def _make_profiler_controller():
        """Return a profiler-controller mock whose enable / rank / discrete-mode checks all pass."""
        controller = MagicMock()
        controller.check_enable.return_value = True
        controller.check_this_rank.return_value = True
        controller.is_discrete_mode.return_value = True
        return controller

    async def test_vllm_start_stop_profile(self):
        try:
            from verl.workers.rollout.vllm_rollout.vllm_async_server import vLLMHttpServer
        except ImportError:
            # skipTest raises SkipTest, so nothing below runs when vLLM is unavailable.
            self.skipTest("vllm or dependencies not installed")

        mock_engine = AsyncMock()

        mock_self = MagicMock()
        mock_self.profiler_controller = self._make_profiler_controller()
        mock_self.engine = mock_engine

        await vLLMHttpServer.start_profile(mock_self)
        # assert_awaited_* (rather than assert_called_*) also fails when the server creates
        # the engine coroutine but never awaits it.
        mock_engine.start_profile.assert_awaited_once()
        # Starting must not also stop the profiler.
        mock_engine.stop_profile.assert_not_awaited()

        await vLLMHttpServer.stop_profile(mock_self)
        mock_engine.stop_profile.assert_awaited_once()

    async def test_sglang_start_stop_profile(self):
        try:
            from verl.workers.rollout.sglang_rollout.async_sglang_server import SGLangHttpServer
        except ImportError:
            # skipTest raises SkipTest, so nothing below runs when SGLang is unavailable.
            self.skipTest("sglang or dependencies not installed")

        mock_profiler = self._make_profiler_controller()
        mock_profiler.config = MagicMock()
        mock_profiler.tool_config = MagicMock()

        mock_tokenizer_manager = AsyncMock()

        mock_self = MagicMock()
        mock_self.profiler_controller = mock_profiler
        mock_self.tokenizer_manager = mock_tokenizer_manager
        mock_self.replica_rank = 0

        # Replace the argument builder so the test only checks how its output is forwarded.
        with patch("verl.workers.rollout.sglang_rollout.async_sglang_server.build_sglang_profiler_args") as mock_build:
            mock_args = {"arg1": "val1"}
            mock_build.return_value = mock_args

            await SGLangHttpServer.start_profile(mock_self)

            mock_build.assert_called_once_with(mock_profiler.config, mock_profiler.tool_config, mock_self.replica_rank)
            # The builder's kwargs must be forwarded verbatim and the coroutine actually awaited.
            mock_tokenizer_manager.start_profile.assert_awaited_once_with(**mock_args)
            mock_tokenizer_manager.stop_profile.assert_not_awaited()

            await SGLangHttpServer.stop_profile(mock_self)
            mock_tokenizer_manager.stop_profile.assert_awaited_once()
|
|
|
|
if __name__ == "__main__":
    # Running the file directly executes every TestCase defined in this module.
    unittest.main(module="__main__")
|
|