from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from vllm import AsyncLLMEngine

    from ..data import Template
    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
|
|
|
@dataclass
class Response:
    r"""
    A single generation result produced by an inference engine.
    """

    response_text: str  # decoded text of the generated response
    response_length: int  # length of the generated response, in tokens
    prompt_length: int  # length of the input prompt, in tokens
    finish_reason: Literal["stop", "length"]  # "stop": generation ended naturally; "length": hit the length limit
|
|
|
|
class BaseEngine(ABC):
    r"""
    Base class for inference engine of chat models.

    Must implement async methods: chat(), stream_chat() and get_scores().
    """

    model: Union["PreTrainedModel", "AsyncLLMEngine"]  # backing model: a HF transformers model or a vLLM async engine
    tokenizer: "PreTrainedTokenizer"  # tokenizer paired with the model
    can_generate: bool  # whether the engine supports generation (presumably False for score-only engines — confirm)
    template: "Template"  # chat template used to render messages into model inputs
    generating_args: Dict[str, Any]  # default generation arguments, overridable per call via **input_kwargs

    @abstractmethod
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        r"""
        Initializes an inference engine.

        Implementations are expected to populate the class attributes declared
        above (model, tokenizer, can_generate, template, generating_args)
        from the given argument dataclasses.
        """
        ...

    @abstractmethod
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses of the chat model.

        Args:
            messages: conversation history, each message a str-to-str dict.
            system: optional system prompt.
            tools: optional tool description string.
            image: optional image input for multimodal models.
            video: optional video input for multimodal models.
            **input_kwargs: per-call overrides for generation arguments.

        Returns:
            One Response per generated candidate.
        """
        ...

    @abstractmethod
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["ImageInput"] = None,
        video: Optional["VideoInput"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Gets the response token-by-token of the chat model.

        Same parameters as chat(); yields incremental text chunks instead of
        returning complete Response objects.
        """
        ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores of the reward model.

        Args:
            batch_input: texts to score, one entry per example.
            **input_kwargs: extra implementation-specific arguments.

        Returns:
            One float score per input, in the same order as batch_input.
        """
        ...
|