# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import verl.single_controller.base.decorator as decorator_module
from verl.single_controller.base.decorator import (
    DISPATCH_MODE_FN_REGISTRY,
    Dispatch,
    _check_dispatch_mode,
    get_predefined_dispatch_fn,
    register_dispatch_mode,
    update_dispatch_mode,
)


@pytest.fixture
def reset_dispatch_registry():
    """Snapshot the dispatch-fn registry and restore it after the test.

    Only ``DISPATCH_MODE_FN_REGISTRY`` is restored. Members added to the
    ``Dispatch`` enum are NOT removed here; tests that register new modes
    must clean those up themselves.
    """
    # Shallow copy: tests replace registry values rather than mutating them.
    original_registry = decorator_module.DISPATCH_MODE_FN_REGISTRY.copy()
    yield
    # Restore in place so every module holding a reference to the same dict
    # (including the name imported above) sees the original contents.
    decorator_module.DISPATCH_MODE_FN_REGISTRY.clear()
    decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry)


def test_register_new_dispatch_mode(reset_dispatch_registry):
    """Registering a new mode extends ``Dispatch`` and records its functions."""

    def dummy_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def dummy_collect(worker_group, output):
        return output

    register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect)
    # The fixture does not undo the enum extension. Remove the member in a
    # `finally` so a failing assertion cannot leak TEST_MODE into later tests
    # (a leaked member would also make re-registration fail on rerun).
    try:
        # Verify the enum was extended with the new member.
        _check_dispatch_mode(Dispatch.TEST_MODE)

        # Verify the registry entry holds exactly the supplied functions.
        assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {
            "dispatch_fn": dummy_dispatch,
            "collect_fn": dummy_collect,
        }
    finally:
        Dispatch.remove("TEST_MODE")


def test_update_existing_dispatch_mode(reset_dispatch_registry):
    """Updating a built-in mode replaces its dispatch and collect functions."""
    # An existing built-in mode to override; the fixture restores its
    # original functions afterwards.
    original_mode = Dispatch.ONE_TO_ALL

    def new_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def new_collect(worker_group, output):
        return output

    # Perform the update.
    update_dispatch_mode(original_mode, new_dispatch, new_collect)

    # Verify the registry now points at the new function objects.
    entry = get_predefined_dispatch_fn(original_mode)
    assert entry["dispatch_fn"] is new_dispatch
    assert entry["collect_fn"] is new_collect