import ast
import json
import regex as re
from collections.abc import Sequence
from typing import List, Any
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
# Module-level logger obtained through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
def _is_string_type(
    tool_name: str, arg_name: str, tools: List[ChatCompletionToolsParam] | None
):
    """Tell whether ``arg_name`` of tool ``tool_name`` is declared as a string.

    Only the first tool in ``tools`` whose function name equals
    ``tool_name`` is consulted. The argument counts as a string only when
    its JSON-schema entry under ``parameters["properties"]`` has exactly
    ``"type": "string"``. A list-valued ``type`` such as
    ``["string", "null"]`` does not count.

    Args:
        tool_name: Function name taken from the model output.
        arg_name: Argument name to look up in the tool's properties.
        tools: Tools declared on the request, or ``None`` when none were
            sent.

    Returns:
        True when the schema declares the argument as ``"string"``. False in
        every other case: no tools, unknown tool (logged at debug level), no
        parameter schema, unknown argument, or a different type.
    """
    if tools is None:
        return False

    found = next((t for t in tools if t.function.name == tool_name), None)
    if found is None:
        logger.debug("No tool named '%s'.", tool_name)
        return False

    schema = found.function.parameters
    if schema is None:
        return False

    declared = schema.get("properties", {}).get(arg_name, {})
    return declared.get("type", None) == "string"
def _deserialize(value: str) -> Any:
try:
return json.loads(value)
except Exception:
pass
try:
return ast.literal_eval(value)
except Exception:
pass
return value
@ToolParserManager.register_module("telechat3")
class TeleChat3ModelToolParser(ToolParser):
    """
    Tool call parser for TeleChat3-36B models.
    Used when --enable-auto-tool-choice --tool-call-parser telechat3

    Non-streaming (``extract_tool_calls``): each match of
    ``func_detail_regex`` becomes one tool call.
    - Group 1 is the function name.
    - Group 2 is the raw argument block, which ``func_arg_regex`` splits
      into (key, value) pairs.
    - Values of arguments not declared as ``"type": "string"`` in the
      request's tool schema are passed through ``_deserialize``.
    - The resulting dict is encoded as a JSON object string.

    Streaming (``extract_tool_calls_streaming``): deltas are accumulated in
    ``_buffer``.
    - Text outside a tool call is emitted as content.
    - Once a complete ``tool_start_token`` ... ``tool_end_token`` span is
      buffered, it is parsed with ``extract_tool_calls``.
    - Every call found in that span is emitted in one delta; names and
      arguments are not streamed incrementally.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # initialize properties used for state when parsing tool calls in
        # streaming mode
        # Index of the most recently emitted streaming tool call. The first
        # emitted call gets index 0.
        self.current_tool_id: int = -1
        # NOTE(review): in this copy both markers are empty strings, and the
        # regexes below contain no literal delimiters (the tags look
        # stripped). With an empty marker, ``str.find`` always returns 0, so
        # streaming cannot isolate a tool call. Confirm the intended
        # marker/tag text.
        self.tool_start_token = ""
        self.tool_end_token = ""
        # group(1): function name; group(2): raw argument block (optional).
        self.func_detail_regex = re.compile(
            r"(.*?)(.*?)?", re.DOTALL
        )
        # One (key, value) pair per match. Whitespace, or a literal "\n"
        # escape (backslash + n), between key and value is skipped.
        self.func_arg_regex = re.compile(
            r"(.*?)(?:\\n|\s)*(.*?)",
            re.DOTALL,
        )
        # Streamed text not emitted yet: a pending (incomplete) tool call
        # and/or text that followed the last end marker.
        self._buffer = ""

    def _content_before_tool_calls(self, model_output: str) -> str:
        """Return the text that precedes the tool-call markup."""
        start = model_output.find(self.tool_start_token)
        if start == -1:
            # Without a start marker, slicing at -1 would silently drop the
            # last character. Fall back to where the first function match
            # begins instead.
            first = self.func_detail_regex.search(model_output)
            start = first.start() if first is not None else len(model_output)
        return model_output[:start]

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ):
        """Parse every tool call contained in a complete model output.

        Args:
            model_output: Full generated text.
            request: The originating request. ``request.tools`` decides which
                arguments are declared as strings and are therefore kept
                verbatim instead of being deserialized.

        Returns:
            An ``ExtractedToolCallInformation``.
            - When at least one call is parsed: ``tools_called`` is True and
              ``content`` is the text before the tool-call markup.
            - When nothing matches, or parsing raises (the exception is
              logged): ``tools_called`` is False and ``content`` is the
              whole ``model_output``.
        """
        matched_tool_calls = self.func_detail_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)
        tool_calls = []
        try:
            for match in matched_tool_calls:
                tc_name = match[0].strip()
                arg_dict = {}
                if len(match) > 1:
                    for key, value in self.func_arg_regex.findall(match[1]):
                        arg_key = key.strip()
                        arg_val = value.strip()
                        # Use the *stripped* key for the schema lookup. The
                        # raw capture can carry surrounding whitespace, which
                        # would miss the property and wrongly deserialize a
                        # string-typed argument (e.g. "123" -> 123).
                        if not _is_string_type(tc_name, arg_key, request.tools):
                            arg_val = _deserialize(arg_val)
                        logger.debug(
                            "arg_key = %s, arg_val = %s", arg_key, arg_val
                        )
                        arg_dict[arg_key] = arg_val
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tc_name,
                            arguments=json.dumps(arg_dict, ensure_ascii=False),
                        ),
                    )
                )
        except Exception:
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        if not tool_calls:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=self._content_before_tool_calls(model_output),
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally parse a streamed response.

        Only ``delta_text`` and ``request`` are used. The other parameters
        are accepted but ignored, because the parser keeps its own
        ``_buffer``.

        Returns:
            A ``DeltaMessage`` with one of the following:
            - plain content, when no tool call is in progress;
            - the content preceding a tool call that is still incomplete;
            - the completed tool calls plus any content that preceded them;
            - the raw span as content, when a completed span could not be
              parsed.
            This implementation never returns ``None``; the annotation still
            allows it.

        NOTE(review): a marker split across two deltas is not detected (its
        prefix is emitted as content). This relies on each marker arriving
        within a single delta — confirm the markers are single special
        tokens.
        """
        self._buffer += delta_text
        cur_text = self._buffer

        start_idx = cur_text.find(self.tool_start_token)
        if start_idx == -1:
            # No tool call in progress: everything buffered is content.
            self._buffer = ""
            return DeltaMessage(content=cur_text)

        logger.debug("cur_text = %s", cur_text)
        # Look for the end marker only after the start marker, so a stray
        # earlier end marker is never paired with this start.
        end_idx = cur_text.find(
            self.tool_end_token, start_idx + len(self.tool_start_token)
        )
        if end_idx == -1:
            # Tool call still incomplete: emit what precedes it and keep the
            # partial call buffered.
            self._buffer = cur_text[start_idx:]
            return DeltaMessage(content=cur_text[:start_idx])

        span_end = end_idx + len(self.tool_end_token)
        # Always consume the completed span. Text after the end marker stays
        # buffered until the next delta; if no further delta arrives, it is
        # not emitted.
        self._buffer = cur_text[span_end:]
        extracted_tool_calls = self.extract_tool_calls(
            cur_text[:span_end], request
        )

        if not extracted_tool_calls.tool_calls:
            logger.warning("Failed to extract any tool calls.")
            # Surface the unparseable span as content (as the non-streaming
            # path does) instead of stalling on it forever.
            return DeltaMessage(content=extracted_tool_calls.content)

        # Emit every call found in the span, each with its own index.
        delta_tool_calls = []
        for tool_call in extracted_tool_calls.tool_calls:
            self.current_tool_id += 1
            delta_tool_calls.append(
                DeltaToolCall(
                    index=self.current_tool_id,
                    id=tool_call.id,
                    type=tool_call.type,
                    function=DeltaFunctionCall(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    ),
                )
            )
        return DeltaMessage(
            content=extracted_tool_calls.content,
            tool_calls=delta_tool_calls,
        )
def register_tool_parser():
    """No-op placeholder; takes no arguments and returns ``None``.

    NOTE(review): the ``telechat3`` parser appears to be registered by the
    ``ToolParserManager.register_module`` decorator on the parser class when
    this module is imported. Confirm whether this stub is still needed by
    any plugin entry point.
    """