| | import re |
| | import time |
| | from typing import Union, List, Dict |
| |
|
| | from werkzeug.datastructures import FileStorage |
| |
|
| | from .. import BaseAgent |
| | from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException |
| | from ...schemas import ( |
| | AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse, |
| | BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile |
| | ) |
| | from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool |
| | from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \ |
| | OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \ |
| | TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN |
| |
|
# Fixed: this constant group (and the logger) was defined twice in a row;
# collapsed into a single definition set.

# Key under which the Python sandbox tool is registered in plugins_map.
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
# Substrings in LLM output that signal the agent has reached a final answer.
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
# Markdown fence tags used when normalizing generated code blocks.
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
# Stop sequences: LLM output is truncated at the first occurrence.
STOP_WORD = ['Observation:']


logger = get_logger()
| |
|
| |
|
class AsyncReactAgent(BaseAgent):
    """ReAct-style agent that loops: ask the LLM for a thought/action, run the
    action in a Python sandbox, feed the observation back — until the LLM emits
    a final answer or the iteration budget is exhausted.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name = self._name or "AsyncReactAgent"
        self._type = AgentType.react
        # Accumulated thought/observation steps of the current reasoning chain.
        # NOTE(review): never cleared between runs — confirm each agent
        # instance serves a single request, otherwise steps leak across runs.
        self.__intermediate_steps: List[BaseAgentResponse] = []

    @property
    def intermediate_steps(self) -> List[BaseAgentResponse]:
        """The reasoning steps recorded so far (mutable list)."""
        return self.__intermediate_steps

    def run(self, *args, **kwargs):
        """Synchronous execution is not supported; use :meth:`async_run`."""
        pass

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
        """Upload ``file`` to the sandbox so generated code can read it.

        Raises:
            InternalErrorException: if the sandbox plugin was never initialized.
        """
        sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
        # Fixed: the original isinstance check listed AsyncPythonSandBoxTool
        # twice in the tuple; a single class is equivalent and clearer.
        if not isinstance(sandbox_plugin, AsyncPythonSandBoxTool):
            raise InternalErrorException("SandBox client is not ready for agent, please check init logic.")
        return await sandbox_plugin.sync_to_sandbox(file)

    async def async_run(self, agent_req: AgentRequest):
        """Run the agent on ``agent_req``, yielding one AgentResponse per LLM
        thought and one per tool observation."""
        instruction = '\n'.join(message.content for message in agent_req.messages)
        async for response in self._chat(instruction, is_cn=agent_req.is_cn):
            yield response

    async def _chat(self, instruction: str, is_cn: bool = False, max_iterations: int = 10,
                    max_single_step_iterations: int = 3):
        """Drive the think -> act -> observe loop.

        Args:
            instruction: user instruction built from the request messages.
            is_cn: use Chinese prefixes/failure messages when True.
            max_iterations: maximum number of think/act rounds.
            max_single_step_iterations: LLM retry budget per thought.

        Yields:
            AgentResponse objects, interleaving LLM thoughts and observations.
        """
        current_iteration = 0

        for _ in range(max_iterations):
            current_iteration += 1
            llm_response = await self._single_round_thought(instruction,
                                                            max_llm_iteration=max_single_step_iterations,
                                                            is_cn=is_cn)
            logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, llm_response.raw_output,
                                llm_response.formatted_output))
            # A thought carries no output files yet.
            yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)

            if isinstance(llm_response, AgentFinish):
                logger.info("Find final answer, stop iteration.")
                break

            self.__intermediate_steps.append(llm_response)
            action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration,
                                                                                 max_iterations, is_cn)
            logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, action_response.raw_output,
                                action_response.formatted_output))
            self.__intermediate_steps.append(action_response)

            yield self.create_agent_response(action_response.formatted_output,
                                             cur_output_files,
                                             action_response.raw_output)

        logger.info(f"Finished iteration in {current_iteration}.")

    async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
        """Execute the parsed action in the sandbox and wrap the observation.

        Returns:
            Tuple of (AgentObservation, list of MediaFile generated by the code).

        Raises:
            SandboxException: if the tool invocation fails for any reason.
        """
        try:
            # Every action is routed to the sandbox regardless of the tool name
            # the LLM emitted; use the shared constant instead of a literal.
            response.tool = SAND_BOX_PLUGIN_NAME
            action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
            logger.info(
                f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
                f"{action_response.output_text}")

            # Trim oversized stderr dumps so the observation stays readable.
            if "STDERR" in action_response.output_text:
                formatted_output = self._process_sandbox_output(action_response.output_text)
            else:
                formatted_output = action_response.output_text

            formatted_output = replace_latex_format(formatted_output)
            observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
            formatted_output = f"{observation_prefix}\n{formatted_output}\n"

            action_observation = AgentObservation(tool=response.tool,
                                                  formatted_output=formatted_output,
                                                  raw_output=action_response.output_text)
            cur_output_files = self._get_output_files(action_response)
            return action_observation, cur_output_files

        except Exception as e:
            logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
                         f"Error: {str(e)}", exc_info=True)
            raise SandboxException("Error occurred while running the tool") from e

    def _compose_prompt(self, instruction) -> str:
        """Compose the prompt from template, tool description, scratchpad and
        instruction.

        Raises:
            InternalErrorException: if the prompt template is missing.
        """
        # Fixed: validate the template before using it — the original called
        # self.prompt_template.construct_scratchpad() before the None check,
        # so a missing template raised AttributeError instead of the intended
        # InternalErrorException.
        if self.prompt_template is None:
            raise InternalErrorException("Agent prompt is none, please check init process")

        agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
        tool_description = self._get_plugin_description()
        tool_names = ", ".join(list(self.plugins_map.keys()))

        return self.prompt_template.format(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names
        )

    async def _single_round_thought(self, instruction: str, max_llm_iteration: int = 3, is_cn: bool = False) -> \
            Union[AgentAction, AgentFinish]:
        """Query the LLM (with retries) and parse its output into an action.

        Returns the first successfully parsed AgentAction/AgentFinish; after
        exhausting retries, returns an AgentFinish carrying a failure message.
        """
        llm_iteration_count = 0
        llm_response = None
        # NOTE: the <= bound means up to max_llm_iteration + 1 attempts are
        # made, matching the original behavior.
        while llm_iteration_count <= max_llm_iteration:
            llm_iteration_count += 1
            try:
                llm_response = await self._get_llm_response(instruction)
                action_response = self._parse_output(llm_response.content, is_cn)
                return action_response
            except Exception as e:
                logger.error("LLM iteration {} out of {} failed. Error: {}".
                             format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True)

        # Always true after the loop; kept for clarity of the failure path.
        if llm_iteration_count > max_llm_iteration:
            logger.error("LLM iteration {} exceed max retry {}. Aborting".
                         format(llm_iteration_count, max_llm_iteration))
            return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
                               raw_output=str(llm_response))

    async def _get_llm_response(self, instruction: str):
        """Send the composed prompt to the LLM and return its response.

        Raises:
            LLMException: if the LLM reports an error state.
        """
        prompt = self._compose_prompt(instruction)
        logger.info("Send prompt to LLM:\n{}".format(prompt))
        response = await self.llm.async_completion(prompt)
        if response.state == "error":
            raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))

        logger.info("Got response from llm, raw response content: \n{}".format(response.content))
        return response

    def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM output into an AgentAction (code to run) or AgentFinish.

        Raises:
            LLMException: when the output matches none of the expected formats.
        """
        # Truncate at the first stop word: text after it is a hallucinated
        # observation, not model reasoning.
        for stop_word in STOP_WORD:
            if stop_word in llm_output:
                llm_output = llm_output.split(stop_word)[0].rstrip()
                break

        # Final-answer path: drop the indicator, keep the surrounding text.
        for indicator in FINAL_ANSWER_INDICATORS:
            if indicator in llm_output:
                parts = llm_output.split(indicator)
                formatted_output = ''.join(parts).strip()
                formatted_output = replace_latex_format(formatted_output)
                return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)

        # Two accepted action formats (alternation): "Action:/Action Input:"
        # with a ```python / ```py fence, or a bare '''lang ... ''' block.
        ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
        ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"

        action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or re.search(ACTION_REGEX_2, llm_output, re.DOTALL)

        if action_match:
            # Fixed: when the second alternative ('''lang ... ''') matched,
            # groups 1-3 are None and the original crashed on None.strip().
            # Read groups 5-7 in that case instead.
            if action_match.group(3) is not None:
                context = action_match.group(1).strip()
                action_tool_description = action_match.group(2).strip()
                action_input = action_match.group(3).strip()
            else:
                context = (action_match.group(5) or '').strip()
                action_tool_description = (action_match.group(6) or '').strip()
                action_input = (action_match.group(7) or '').strip()

            format_code_block = self._format_code_block(action_input)

            prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
            formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
            formatted_output = replace_latex_format(formatted_output)

            return AgentAction(tool=action_tool_description,
                               tool_input=format_code_block,
                               formatted_output=formatted_output,
                               raw_output=llm_output)

        # Diagnose which part of the expected format is missing.
        if not re.search(r"Action\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
        elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
        else:
            raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")

    def _format_code_block(self, tool_input: str) -> str:
        """Normalize ``tool_input`` into a ```python fenced code block."""
        stripped_tool_input = tool_input.strip()

        if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Already a ```python block; ensure a newline follows the opening tag.
            if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
                stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
                                      '\n'
            formatted_code = stripped_tool_input
        elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
                CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Plain ``` fence: rewrite the opening fence as ```python.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
        else:
            # Bare code: wrap it in a ```python fence.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'

        # Fixed: removed the original's no-op .encode("utf-8").decode("utf-8").
        return formatted_code

    def _process_sandbox_output(self, output: str) -> str:
        """Trim a long STDERR-bearing sandbox result to ~1000 chars, keeping the
        first and last ~500 characters (whole lines) around a '......' marker."""
        if len(output) <= 1000:
            return output

        logger.info("Output contains error, original message is over 1000, trim it for response. ori output: \n{}".
                    format(output))
        rows = output.split("\n")

        # Collect whole lines from the top until ~500 chars.
        top_segment = []
        length = 0
        for sub_p in rows:
            if length + len(sub_p) > 500:
                break
            top_segment.append(sub_p)
            length += len(sub_p)

        # Collect whole lines from the bottom until ~500 chars.
        bottom_segment = []
        length = 0
        for sub_p in reversed(rows):
            if length + len(sub_p) > 500:
                break
            bottom_segment.insert(0, sub_p)
            length += len(sub_p)

        trimmed_output = "\n".join(top_segment + ["......"] + bottom_segment)
        return trimmed_output

    def _get_output_files(self, tool_response) -> list[MediaFile]:
        """Extract generated files and image outputs from a sandbox response.

        Returns an empty list when the response is not a successful, complete
        sandbox run.
        """
        output_files = []

        if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
            raw_output = tool_response.raw_output

            # code == 0 means success; partial results carry no final files.
            if raw_output.code == 0 and not raw_output.data.is_partial:
                result_data = raw_output.data.result

                if len(result_data.new_generated_files) > 0:
                    output_files.extend([MediaFile(tos_path=file.download_link) for file in
                                         result_data.new_generated_files])

                if len(result_data.code_output_result) > 0:
                    output_files.extend(
                        [MediaFile(tos_path=image.content) for image in result_data.code_output_result
                         if image.type == 'image'])

        return output_files

    def _replace_csv_path(self, input_string: str) -> str:
        """Replace any pd.read_csv('<x>.csv') path with a placeholder path."""
        pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
        replacement = "pd.read_csv('/path/to/your/dataset')"
        updated_string = re.sub(pattern, replacement, input_string)
        return updated_string

    @staticmethod
    def create_agent_response(formatted_output, output_files, raw_output) -> AgentResponse:
        """Build an AgentResponse from formatted text, files and raw text."""
        return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)
| |
|
| |
|