| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import asyncio |
| | from threading import Thread |
| | from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence |
| |
|
| | from ..extras.misc import torch_gc |
| | from ..hparams import get_infer_args |
| | from .hf_engine import HuggingfaceEngine |
| | from .vllm_engine import VllmEngine |
| |
|
| |
|
| | if TYPE_CHECKING: |
| | from numpy.typing import NDArray |
| |
|
| | from .base_engine import BaseEngine, Response |
| |
|
| |
|
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    """Install ``loop`` as the calling thread's event loop and run it.

    Written to be used as a thread target: the call only returns once the
    loop has been stopped.

    Args:
        loop: The event loop to bind to the current thread and drive.
    """
    # Bind first so code running on this thread resolves to the same loop.
    asyncio.set_event_loop(loop)
    loop.run_forever()
| |
|
| |
|
class ChatModel:
    """Blocking and asynchronous front end over a single inference engine.

    The constructor picks the engine from ``model_args.infer_backend`` and
    starts a daemon thread that runs a private asyncio event loop. The
    blocking methods (``chat``, ``stream_chat``, ``get_scores``) submit their
    ``a``-prefixed coroutine counterparts to that loop and wait for the
    outcome, so they can be called from ordinary synchronous code.
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        """Create the engine and start the background event loop.

        Args:
            args: Forwarded unchanged to ``get_infer_args``.

        Raises:
            NotImplementedError: If ``model_args.infer_backend`` is neither
                ``"huggingface"`` nor ``"vllm"``.
        """
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        # The attribute is annotated on its first assignment only; annotating
        # it in every branch is reported as a redefinition by type checkers.
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

        # Private loop driven by a daemon thread; the blocking wrappers below
        # hand their coroutines to it.
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        """Blocking counterpart of :meth:`achat`.

        Submits ``achat`` to the background loop and waits for its result.
        An exception raised by the coroutine is re-raised here.
        """
        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        """Await ``engine.chat`` with the given arguments and return its result."""
        return await self.engine.chat(messages, system, tools, image, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        """Blocking counterpart of :meth:`astream_chat`.

        Pulls one item at a time from the async generator by running each
        ``__anext__`` call on the background loop, and yields it to the caller.
        """
        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
        while True:
            task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
            # Only the wait is guarded: StopAsyncIteration signals that the
            # async generator is exhausted. The yield stays outside the try so
            # exceptions thrown in by the consumer are never swallowed.
            try:
                new_token = task.result()
            except StopAsyncIteration:
                break
            yield new_token

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        """Re-yield every item produced by ``engine.stream_chat``."""
        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        """Blocking counterpart of :meth:`aget_scores`.

        Submits ``aget_scores`` to the background loop and waits for its
        result. An exception raised by the coroutine is re-raised here.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        """Await ``engine.get_scores`` for ``batch_input`` and return its result."""
        return await self.engine.get_scores(batch_input, **input_kwargs)
| |
|
| |
|
def run_chat() -> None:
    """Run an interactive command-line chat session.

    Reads user turns from standard input and streams the model's reply to
    standard output as it is produced. Entering ``clear`` resets the
    conversation history; entering ``exit`` ends the session.
    """
    try:
        import platform

        if platform.system() != "Windows":
            # Imported only for its effect: once loaded, input() gains line
            # editing and history.
            import readline  # noqa: F401
    except ImportError:
        print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        # Only the decoding failure is handled here; any other error from
        # input() propagates to the caller unchanged.
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue

        command = query.strip()
        if command == "exit":
            break

        if command == "clear":
            messages = []
            torch_gc()  # NOTE(review): presumably frees cached memory -- confirm against extras.misc
            print("History has been removed.")
            continue

        # The untrimmed query is what gets sent to the model.
        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        # Echo each chunk as it arrives and keep it for the history entry.
        chunks = []
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            chunks.append(new_text)
        print()
        messages.append({"role": "assistant", "content": "".join(chunks)})
| |
|