import asyncio
import heapq
import importlib
import logging
import os
import socket
import threading
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import Any, Callable, Dict, List, Type
from uuid import uuid4

import aiohttp
import fastapi
import ray
import uvicorn
from cachetools import LRUCache
from omegaconf import DictConfig
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import ChatCompletion
from starlette.requests import Request

from verl.protocol import DataProto
from verl.single_controller.ray.base import RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local

logger = logging.getLogger(__file__)


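# NOTE: the OS may hand this port to another process between this probe and the
# moment uvicorn binds it; AsyncLLMServerManager compensates by killing and
# restarting any server actor that fails to come up (see the retry loop below).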
def _get_free_port():
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


class AsyncServerBase(ABC):
    """Base class for AsyncServer."""

    def __init__(self):
        self.address = ray._private.services.get_node_ip_address()
        self.port = None
        self.server_ready = asyncio.Event()
        asyncio.create_task(self._start_fastapi_server())

    async def _start_fastapi_server(self):
        @asynccontextmanager
        async def lifespan(app: fastapi.FastAPI):
            print("FastAPI startup")
            self.server_ready.set()
            yield

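            # uvicorn has no clean way to rebind if the chosen port was taken in the
            # meantime, so on shutdown exit the actor process immediately and let
            # AsyncLLMServerManager restart it on a fresh port.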
            print("FastAPI shutdown, maybe address already in use, exit process immediately.")
            os._exit(-1)

        app = fastapi.FastAPI(lifespan=lifespan)
        app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"])

        self.port = _get_free_port()
        config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning")
        server = uvicorn.Server(config)
        await server.serve()

    async def get_server_address(self) -> str:
        """Get FastAPI server address as "ip:port"."""
        await self.server_ready.wait()
        return f"{self.address}:{self.port}"

    @abstractmethod
    async def chat_completion(self, raw_request: Request):
        """OpenAI chat completion API.

        API reference: https://platform.openai.com/docs/api-reference/chat/create
        """
        raise NotImplementedError

    @abstractmethod
    async def init_engine(self):
        """Init async LLM engine."""
        raise NotImplementedError

    @abstractmethod
    async def wake_up(self):
        """Wake up engine to load model weights and build kv cache."""
        raise NotImplementedError

    @abstractmethod
    async def sleep(self):
        """Sleep engine to offload model weights and discard kv cache."""
        raise NotImplementedError


class ChatCompletionScheduler:
    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        server_addresses: List[str],
        max_cache_size: int = 10000,
    ):
        """
        Args:
            config: DictConfig, rollout config.
            model_path: str, model path.
            server_addresses: List[str], server addresses.
            max_cache_size: int, max cache size of request_id to address mapping.
        """
        self.config = config
        self.model_name = "/".join(model_path.split("/")[-2:])
        local_path = copy_to_local(model_path)
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)

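        # Least-requests load balancing: a min-heap of [request_count, address] keeps
        # the least-used server address at the root.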
        self.weighted_addresses = [[0, address] for address in server_addresses]
        heapq.heapify(self.weighted_addresses)

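        # LRU cache mapping request_id -> server address, so a follow-up request that
        # reuses a request_id is routed to the same server (sticky routing).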
        self.request_id_to_address = LRUCache(maxsize=max_cache_size)

    async def submit_chat_completions(
        self,
        callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None],
        callback_additional_info: Dict[str, Any],
        **chat_complete_request,
    ):
        """
        Submit a chat completion request to the server with the least number of requests.

        Args:
            callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function
                to handle the response. The callback function should have the following signature:

                ```python
                async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
                    ...
                ```
                - completions: chat completion response from server.
                - info: user-provided `callback_additional_info`.
                - exception: exception raised from the OpenAI client if the request failed, otherwise None.

                **CAUTION**: the callback function must be async and non-blocking. If you have any blocking
                operation, please move it to a separate thread or process pool to avoid blocking the event loop.

            callback_additional_info: Dict[str, Any], additional info to pass to the callback function.

            **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create.
                OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create
        """
        if "extra_headers" not in chat_complete_request:
            chat_complete_request["extra_headers"] = {}

        extra_headers = chat_complete_request["extra_headers"]
        request_id = extra_headers.get("x-request-id", None)
        if request_id:
            if request_id.startswith("chatcmpl-"):
                request_id = request_id[len("chatcmpl-") :]
                extra_headers["x-request-id"] = request_id

            address = self.request_id_to_address.pop(request_id)
        else:
            address = self.weighted_addresses[0][1]
            self.weighted_addresses[0][0] += 1
            heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0])

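        # Assign a fresh request_id to the outgoing request to avoid duplicate-id
        # problems, and remember which server address it was routed to.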
        request_id = uuid4().hex
        self.request_id_to_address[request_id] = address
        chat_complete_request["extra_headers"]["x-request-id"] = request_id

        completions, exception = None, None
        try:
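            # NOTE: the OpenAI client is built on httpx, which may become a bottleneck
            # under high request concurrency; `_chat_completions_aiohttp` below is an
            # aiohttp-based alternative with the same signature.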
            completions = await self._chat_completions_openai(address, **chat_complete_request)
        except Exception as e:
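            # Do not raise here; hand the exception to the user-provided callback instead.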
            exception = e

        await callback(completions, callback_additional_info, exception)

    async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion:
        client = AsyncOpenAI(
            base_url=f"http://{address}/v1",
            api_key="token-abc123",
            timeout=None,
            max_retries=0,
        )
        return await client.chat.completions.create(**chat_complete_request)

    async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion:
        # Create the session before the try block so `session` is always bound in `finally`.
        session = aiohttp.ClientSession()
        try:
            async with session.post(
                url=f"http://{address}/v1/chat/completions",
                headers={"Authorization": "Bearer token-abc123"},
                json=chat_complete_request,
            ) as resp:
                data = await resp.json()
                return ChatCompletion(**data)
        finally:
            await session.close()

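    # Subclasses (the user-provided chat scheduler configured via `rollout.chat_scheduler`)
    # implement batch generation on top of `submit_chat_completions`.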
    async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto:
        raise NotImplementedError


class AsyncLLMServerManager:
    """AsyncLLMServerManager manages a group of vLLM instances, i.e., AsyncvLLMServer."""

    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, *, scheduler_kwargs: Dict[str, Any] = None):
        """Initialize AsyncLLMServerManager.

        Args:
            config: DictConfig, actor_rollout_ref config.
            worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker.
            scheduler_kwargs: Dict[str, Any], kwargs for chat scheduler.
        """
        self.config = config
        self.worker_group = worker_group
        self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {}

        self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size
        self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size

        register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center")
        workers_info = ray.get(register_center.get_worker_info.remote())
        assert len(workers_info) == self.worker_group.world_size

        self.async_llm_servers = [None] * self.rollout_dp_size
        self.server_addresses = [None] * self.rollout_dp_size

        server_class = async_server_class(
            rollout_backend=self.config.rollout.name,
        )

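        # Launch one server actor per DP rank, pinned (soft=False) to the node of the
        # first worker in that rank's TP group. If a server fails to come up (e.g. its
        # port was taken in the meantime), kill it and retry until every rank is ready.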
        unready_dp_ranks = set(range(self.rollout_dp_size))
        while len(unready_dp_ranks) > 0:
            servers = {
                rollout_dp_rank: server_class.options(
                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                        node_id=workers_info[rollout_dp_rank * self.rollout_tp_size],
                        soft=False,
                    ),
                    name=f"async_llm_server_{rollout_dp_rank}",
                ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix)
                for rollout_dp_rank in unready_dp_ranks
            }

            for rollout_dp_rank, server in servers.items():
                try:
                    address = ray.get(server.get_server_address.remote())
                    self.server_addresses[rollout_dp_rank] = address
                    self.async_llm_servers[rollout_dp_rank] = server
                    unready_dp_ranks.remove(rollout_dp_rank)
                except Exception:
                    ray.kill(server)
                    print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...")

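        # All server instances are ready; initialize the async LLM engine on each.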
        ray.get([server.init_engine.remote() for server in self.async_llm_servers])

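        # Run the chat scheduler in a dedicated daemon thread with its own event loop,
        # so the synchronous methods below can submit coroutines to it via
        # asyncio.run_coroutine_threadsafe.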
        self.chat_scheduler: ChatCompletionScheduler = None
        self.chat_scheduler_loop = None
        self.chat_scheduler_ready = threading.Event()
        self.chat_scheduler_thread = threading.Thread(target=self._init_chat_scheduler, daemon=True)
        self.chat_scheduler_thread.start()
        self.chat_scheduler_ready.wait()

    def _init_chat_scheduler(self):
        self.chat_scheduler_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.chat_scheduler_loop)

        module_path, class_name = self.config.rollout.chat_scheduler.rsplit(".", 1)
        module = importlib.import_module(module_path)
        scheduler_cls = getattr(module, class_name)
        self.chat_scheduler = scheduler_cls(
            config=self.config.rollout,
            model_path=self.config.model.path,
            server_addresses=self.server_addresses,
            **self.scheduler_kwargs,
        )

        self.chat_scheduler_ready.set()
        self.chat_scheduler_loop.run_forever()

    def wake_up(self):
        """Wake up all vllm instances."""
        ray.get([server.wake_up.remote() for server in self.async_llm_servers])

    def sleep(self):
        """Sleep all vllm instances."""
        ray.get([server.sleep.remote() for server in self.async_llm_servers])

    def submit_chat_completions(
        self,
        callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None],
        callback_additional_info: Dict[str, Any],
        **chat_complete_request,
    ):
        """Submit a chat completion request to the chat scheduler and wait until it is done.
        To submit multiple requests in parallel, please use `generate_sequences` instead.

        Args: same as ChatCompletionScheduler.submit_chat_completions.
        """
        assert self.chat_scheduler is not None, "chat scheduler is not initialized."
        future = asyncio.run_coroutine_threadsafe(
            self.chat_scheduler.submit_chat_completions(
                callback=callback,
                callback_additional_info=callback_additional_info,
                **chat_complete_request,
            ),
            self.chat_scheduler_loop,
        )
        future.result()

    def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto:
        """Generate multiple sequences in parallel via the chat scheduler."""
        assert self.chat_scheduler is not None, "chat scheduler is not initialized."

        future = asyncio.run_coroutine_threadsafe(
            self.chat_scheduler.generate_sequences(prompts, **sampling_params),
            self.chat_scheduler_loop,
        )
        return future.result()


def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]:
    """Get async server class.

    Args:
        rollout_backend: str, rollout backend, should be "vllm" or "sglang".

    Returns:
        Type[AsyncServerBase]: async server class.
    """
    if rollout_backend == "vllm":
        from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer

        return AsyncvLLMServer
    elif rollout_backend == "sglang":
        raise NotImplementedError("sglang async server is not supported yet.")
    else:
        raise NotImplementedError(f"unknown rollout backend: {rollout_backend}")
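

# Illustrative usage sketch (an assumption for illustration, not part of this module's
# API): how a trainer might drive AsyncLLMServerManager. `config` is the
# actor_rollout_ref DictConfig with `rollout.chat_scheduler` naming a
# ChatCompletionScheduler subclass, and `worker_group` is an initialized RayWorkerGroup
# from trainer setup.
#
#   server_manager = AsyncLLMServerManager(config, worker_group)
#   server_manager.wake_up()                             # load weights, build kv cache
#   output = server_manager.generate_sequences(prompts)  # prompts: DataProto
#   server_manager.sleep()                               # offload weights, free kv cache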