| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import asyncio |
| from typing import Generator |
|
|
| import ray |
| import torch |
| from transformers import AutoModelForCausalLM |
|
|
| from verl.checkpoint_engine import CheckpointEngineRegistry, CheckpointEngineWorker |
| from verl.single_controller.base.decorator import Dispatch, register |
| from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup |
| from verl.utils.device import get_device_name |
| from verl.utils.fs import copy_to_local |
| from verl.workers.config import CheckpointEngineConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig |
| from verl.workers.engine_workers import TrainingWorker, TrainingWorkerConfig |
| from verl.workers.rollout import BaseRollout, RolloutReplica |
|
|
|
|
| class TrainingWorkerTest(TrainingWorker): |
| def __init__(self, config: TrainingWorkerConfig, checkpoint_engine_config: CheckpointEngineConfig) -> None: |
| super().__init__(config) |
|
|
| backend = checkpoint_engine_config.backend |
| bucket_size = checkpoint_engine_config.update_weights_bucket_megabytes << 20 |
| engine_kwargs = checkpoint_engine_config.engine_kwargs.get(backend, {}) |
| if torch.distributed.get_rank() == 0: |
| engine_kwargs["is_master"] = True |
| self.checkpoint_engine = CheckpointEngineRegistry.new(backend, bucket_size=bucket_size, **engine_kwargs) |
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) |
| async def update_weights(self, global_steps: int = None): |
| per_tensor_param, _ = self.engine.get_per_tensor_param() |
| await self.checkpoint_engine.send_weights(per_tensor_param) |
|
|
| @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) |
| def execute_checkpoint_engine(self, method: str, *args, **kwargs): |
| return getattr(self.checkpoint_engine, method)(*args, **kwargs) |
|
|
|
|
| class MockServerAdapter(BaseRollout): |
| def __init__(self, config: RolloutConfig, model_config: HFModelConfig, check_allclose: bool = True): |
| super().__init__(config, model_config, device_mesh=None) |
| self.check_allclose = check_allclose |
| self.model = None |
| self.received_weights: dict[str, torch.Tensor] = {} |
|
|
| async def resume(self, tags: list[str]): |
| raise NotImplementedError() |
|
|
| async def release(self): |
| raise NotImplementedError() |
|
|
| async def update_weights( |
| self, |
| weights: Generator[tuple[str, torch.Tensor], None, None], |
| **kwargs, |
| ): |
| async for name, weight in weights: |
| weight = weight.clone() |
| if self.check_allclose: |
| self.received_weights[name] = weight.clone() |
|
|
| def check_weights(self): |
| if not self.check_allclose: |
| return |
|
|
| if self.model is None: |
| local_path = copy_to_local(self.model_config.path) |
| self.model = AutoModelForCausalLM.from_pretrained(local_path, torch_dtype=torch.bfloat16, device_map="cpu") |
|
|
| for name, weight in self.model.state_dict().items(): |
| assert name in self.received_weights, f"weight {name} not received" |
| received = self.received_weights[name] |
| assert torch.allclose(weight.to(received.device), received), f"weight {name} not equal" |
| self.received_weights.clear() |
|
|
|
|
| class MockReplica(RolloutReplica): |
| async def init_hybrid(self, worker_group: RayWorkerGroup): |
| """Init hybrid rollout server, rollout engine and training engine(fsdp/megatron) fused in same process. |
| |
| Args: |
| worker_group: RayWorkerGroup, fused workers where training engine(fsdp/megatron) have been initialized. |
| """ |
| self.workers = worker_group.workers[ |
| self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1) |
| ] |
|
|
| def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: |
| """Get rollout worker actor class for colocated and standalone mode.""" |
| raise NotImplementedError |
|
|
| async def launch_servers(self): |
| """Launch http server in each node.""" |
| raise NotImplementedError |
|
|
|
|
| class CheckpointEngineWorkerTest(CheckpointEngineWorker): |
| def __init__( |
| self, rollout_config: RolloutConfig, model_config: HFModelConfig, check_allclose: bool = True, *args, **kwargs |
| ) -> None: |
| server_adapter = MockServerAdapter(rollout_config, model_config, check_allclose) |
| super().__init__(rollout_config, model_config, server_adapter, *args, **kwargs) |
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
| def check_weights(self): |
| self.server_adapter.check_weights() |
|
|
|
|
| def create_trainer_worker_group( |
| resource_pool: RayResourcePool, model_config: HFModelConfig, checkpoint_engine_config: CheckpointEngineConfig |
| ) -> RayWorkerGroup: |
| engine_config = FSDPEngineConfig(forward_only=True, fsdp_size=resource_pool.world_size, strategy="fsdp") |
| trainer_config = TrainingWorkerConfig( |
| model_type="language_model", |
| model_config=model_config, |
| engine_config=engine_config, |
| ) |
|
|
| ray_cls_with_init = RayClassWithInitArgs( |
| cls=ray.remote(TrainingWorkerTest), |
| config=trainer_config, |
| checkpoint_engine_config=checkpoint_engine_config, |
| ) |
| ray_cls_with_init.update_options( |
| { |
| "runtime_env": { |
| "env_vars": { |
| "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", |
| } |
| } |
| } |
| ) |
| wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name()) |
| return wg |
|
|
|
|
| async def create_rollout_worker_group( |
| resource_pool: RayResourcePool, |
| model_config: HFModelConfig, |
| rollout_config: RolloutConfig, |
| check_allclose: bool = True, |
| ) -> tuple[RayWorkerGroup, list[MockReplica]]: |
| |
| ray_cls_with_init = RayClassWithInitArgs( |
| cls=ray.remote(CheckpointEngineWorkerTest), |
| model_config=model_config, |
| rollout_config=rollout_config, |
| check_allclose=check_allclose, |
| ) |
| wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name=get_device_name()) |
|
|
| |
| rollout_world_size = ( |
| rollout_config.tensor_model_parallel_size |
| * rollout_config.data_parallel_size |
| * rollout_config.pipeline_model_parallel_size |
| ) |
| num_replicas = wg.world_size // rollout_world_size |
| replicas = [] |
| for replica_rank in range(num_replicas): |
| replica = MockReplica( |
| replica_rank=replica_rank, |
| config=rollout_config, |
| model_config=model_config, |
| ) |
| replicas.append(replica) |
| await asyncio.gather(*[replica.init_hybrid(wg) for replica in replicas]) |
|
|
| return wg, replicas |
|
|