import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity

        app_config = self.app_config

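        # convert tools into ModelRuntime Tool format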
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

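        # continue to run until there are no more tool calls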
        function_call_state = True
        llm_usage = {"usage": None}
        final_answer = ""

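        # get tracing instance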
        trace_manager = app_generate_entity.trace_manager

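        # accumulate token usage and prices across all iterations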
        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
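                # the last iteration: remove all tools so the model produces a final answer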
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

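            # organize prompt messages and recalculate the llm max tokens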
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)
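            # invoke the model (streaming or blocking, depending on stream_tool_call)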
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

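            # extracted tool calls, as (tool_call_id, tool_call_name, tool_call_args) tuples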
            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

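            # full text response accumulated from the model output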
            response = ""

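            # tool call names and serialized inputs recorded on the agent thought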
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

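                    # check if there is any tool call in the chunk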
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError:
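                            # fall back to ascii-escaped serialization to avoid encoding errors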
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
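                # check if there is any tool call in the blocking result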
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError:
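                        # fall back to ascii-escaped serialization to avoid encoding errors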
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )

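            # record the assistant turn (tool calls or plain text) for the next iteration's prompt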
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

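            # save the thought (response text and tool call inputs) for this iteration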
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

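            # invoke the requested tools and collect their responses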
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
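                    # invoke the tool via the tool engine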
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )
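                    # publish files produced by the tool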
                    for message_file_id, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

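                        # publish the message file event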
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response["tool_response"],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
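                # save the tool invocation metadata and observations on the agent thought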
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

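            # update prompt message tool definitions before the next iteration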
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
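        # publish the end event with the aggregated llm result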
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(
        self, llm_result_chunk: LLMResultChunk
    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(
        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))
            for file_obj in self.files:
                prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As of now, GPT supports both function calling and vision at the first iteration.
        We need to remove the image messages from the prompt messages after the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
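            # clear image contents after the first iteration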
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages