| import re |
| import time |
| from typing import Union, List, Dict |
|
|
| from werkzeug.datastructures import FileStorage |
|
|
| from .. import BaseAgent |
| from ...exceptions.exceptions import InternalErrorException, LLMException, SandboxException |
| from ...schemas import ( |
| AgentType, AgentRequest, AgentFinish, AgentAction, AgentResponse, |
| BaseAgentResponse, AgentObservation, RunCodeOutput, MediaFile |
| ) |
| from ...tools import PythonSandBoxToolResponse, AsyncPythonSandBoxTool |
| from ...utils import get_logger, replace_latex_format, extract_and_replace_url, \ |
| OBSERVATION_PREFIX_CN, OBSERVATION_PREFIX_EN, AGENT_FAILED_CN, AGENT_FAILED_EN, \ |
| TOOL_INPUT_PREFIX_CN, TOOL_INPUT_PREFIX_EN |
|
|
# BUG FIX: this constants block (and the get_logger() call) was defined twice
# in a row; the duplicate definitions have been merged into a single block,
# keeping STOP_WORD, which only the second copy declared.

# Name under which the Python sandbox tool is registered in the plugin map.
SAND_BOX_PLUGIN_NAME = 'python_code_sandbox'
# Markers in LLM output that signal the agent has reached its final answer.
FINAL_ANSWER_INDICATORS = ["Final Answer:", "[END]", "The final Answer", "final answer"]
# Fenced code-block delimiters used when normalizing sandbox tool input.
CODE_BLOCK_START_TAG = '```python'
CODE_BLOCK_TAG = '```'
# Stop words: LLM output is truncated at the first occurrence of any of these.
STOP_WORD = ['Observation:']


logger = get_logger()
|
|
|
|
class AsyncReactAgent(BaseAgent):
    """ReAct-style agent that runs Thought -> Action -> Observation loops asynchronously.

    Each round asks the LLM for the next step, executes the Python sandbox tool
    when an action is requested, and streams intermediate results to the caller
    as ``AgentResponse`` objects until a final answer is produced or the
    iteration budget is exhausted.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._name = self._name or "AsyncReactAgent"
        self._type = AgentType.react
        # Accumulated thought/observation history; feeds the prompt scratchpad.
        self.__intermediate_steps: List[BaseAgentResponse] = []

    @property
    def intermediate_steps(self) -> List[BaseAgentResponse]:
        """The thought/observation history accumulated during the current run."""
        return self.__intermediate_steps

    def run(self, *args, **kwargs):
        """Synchronous entry point is intentionally a no-op; use :meth:`async_run`."""
        pass

    async def sync_to_sandbox(self, file: Union[str, Dict, FileStorage]):
        """Upload ``file`` to the Python sandbox so generated code can access it.

        Args:
            file: Path, file descriptor dict, or uploaded ``FileStorage`` object.

        Raises:
            InternalErrorException: If the sandbox plugin was never registered.
        """
        sandbox_plugin = self.plugins_map.get(SAND_BOX_PLUGIN_NAME)
        # BUG FIX: the original isinstance tuple listed AsyncPythonSandBoxTool twice.
        if not isinstance(sandbox_plugin, AsyncPythonSandBoxTool):
            raise InternalErrorException("SandBox client is not ready for agent, please check init logic.")
        return await sandbox_plugin.sync_to_sandbox(file)

    async def async_run(self, agent_req: AgentRequest):
        """Run the ReAct loop for a request, yielding each intermediate response."""
        # The conversation history is flattened into a single instruction string.
        instruction = '\n'.join(message.content for message in agent_req.messages)
        async for response in self._chat(instruction, is_cn=agent_req.is_cn):
            yield response

    async def _chat(self, instruction: str, is_cn=False, max_iterations=10,
                    max_single_step_iterations=3):
        """Drive the Thought/Action/Observation loop, yielding one response per step.

        Args:
            instruction: Flattened user instruction.
            is_cn: Whether to use Chinese-language output prefixes.
            max_iterations: Maximum number of ReAct rounds.
            max_single_step_iterations: Max LLM retries within a single round.
        """
        current_iteration = 0

        for _ in range(max_iterations):
            current_iteration += 1
            llm_response = await self._single_round_thought(instruction,
                                                           max_llm_iteration=max_single_step_iterations,
                                                           is_cn=is_cn)
            logger.info("Round {} of {}, [LLM raw output]:\n{}\n\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, llm_response.raw_output,
                                llm_response.formatted_output))
            # Stream the LLM's thought (no files attached yet at this point).
            yield self.create_agent_response(llm_response.formatted_output, [], llm_response.raw_output)

            if isinstance(llm_response, AgentFinish):
                logger.info("Find final answer, stop iteration.")
                break

            # Record the action, execute it, then record the observation so the
            # next round's scratchpad contains the full exchange.
            self.intermediate_steps.append(llm_response)
            action_response, cur_output_files = await self._process_agent_action(llm_response, current_iteration,
                                                                                 max_iterations, is_cn)
            logger.info("Round {} of {}, [Plugin raw output]:\n{}\n[Formatted output]:\n{}\n"
                        .format(current_iteration, max_iterations, action_response.raw_output,
                                action_response.formatted_output))
            self.intermediate_steps.append(action_response)

            yield self.create_agent_response(action_response.formatted_output,
                                             cur_output_files,
                                             action_response.raw_output)

        logger.info(f"Finished iteration in {current_iteration}.")

    async def _process_agent_action(self, response, current_iteration, max_iterations, is_cn: bool = False):
        """Execute the sandbox tool for an ``AgentAction`` and format its observation.

        Returns:
            Tuple of (``AgentObservation``, list of ``MediaFile`` produced by the run).

        Raises:
            SandboxException: If tool execution fails for any reason.
        """
        try:
            # The sandbox is the only supported tool; force it regardless of
            # what tool name the LLM emitted (use the shared module constant).
            response.tool = SAND_BOX_PLUGIN_NAME
            action_response = await self.get_plugin_tool_async_function()[response.tool](response.tool_input)
            logger.info(
                f"Step {current_iteration} of {max_iterations}. Got agent observation raw output:\n"
                f"{action_response.output_text}")

            # Error output can be very long; trim it before sending downstream.
            if "STDERR" in action_response.output_text:
                formatted_output = self._process_sandbox_output(action_response.output_text)
            else:
                formatted_output = action_response.output_text

            formatted_output = replace_latex_format(formatted_output)
            observation_prefix = OBSERVATION_PREFIX_CN if is_cn else OBSERVATION_PREFIX_EN
            formatted_output = f"{observation_prefix}\n{formatted_output}\n"

            action_observation = AgentObservation(tool=response.tool,
                                                  formatted_output=formatted_output,
                                                  raw_output=action_response.output_text)
            cur_output_files = self._get_output_files(action_response)
            return action_observation, cur_output_files

        except Exception as e:
            logger.error(f"Error occurred while executing tool {response.tool} with input {response.tool_input}. "
                         f"Error: {str(e)}", exc_info=True)
            raise SandboxException("Error occurred while running the tool") from e

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        # BUG FIX: validate the template before using it. The original called
        # construct_scratchpad() on line one and only checked for None afterwards,
        # so a missing template raised AttributeError instead of this exception.
        if self.prompt_template is None:
            raise InternalErrorException("Agent prompt is none, please check init process")

        agent_scratchpad = self.prompt_template.construct_scratchpad(self.__intermediate_steps)
        tool_description = self._get_plugin_description()
        tool_names = ", ".join(list(self.plugins_map.keys()))

        return self.prompt_template.format(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names
        )

    async def _single_round_thought(self, instruction: str, max_llm_iteration=3, is_cn: bool = False) -> \
            Union[AgentAction, AgentFinish]:
        """Query the LLM (with retries) and parse its output into an action or finish.

        Returns:
            The parsed ``AgentAction``/``AgentFinish``, or a failure ``AgentFinish``
            once all retries are exhausted.
        """
        llm_iteration_count = 0

        llm_response = None
        while llm_iteration_count <= max_llm_iteration:
            llm_iteration_count += 1
            try:
                llm_response = await self._get_llm_response(instruction)
                action_response = self._parse_output(llm_response.content, is_cn)

                return action_response
            except Exception as e:
                logger.error("LLM iteration {} out of {} failed. Error: {}".
                             format(llm_iteration_count, max_llm_iteration, str(e)), exc_info=True)

        # BUG FIX: the original wrapped this fallback in an always-true
        # `if llm_iteration_count > max_llm_iteration`, leaving an implicit
        # `return None` path. The loop can only exit with the count exceeded,
        # so return the failure response unconditionally.
        logger.error("LLM iteration {} exceed max retry {}. Aborting".
                     format(llm_iteration_count, max_llm_iteration))
        return AgentFinish(formatted_output=AGENT_FAILED_CN if is_cn else AGENT_FAILED_EN,
                           raw_output=str(llm_response))

    async def _get_llm_response(self, instruction: str):
        """Send the composed prompt to the LLM and return its response.

        Raises:
            LLMException: If the LLM reports an error state.
        """
        prompt = self._compose_prompt(instruction)
        logger.info("Send prompt to LLM:\n{}".format(prompt))
        response = await self.llm.async_completion(prompt)
        if response.state == "error":
            raise LLMException("Failed to retrieve response from LLM, error: {}".format(str(response.content)))

        logger.info("Got response from llm, raw response content: \n{}".format(response.content))
        return response

    def _parse_output(self, llm_output: str, is_cn: bool = False) -> Union[AgentAction, AgentFinish]:
        """Parse raw LLM text into an ``AgentAction`` (tool call) or ``AgentFinish``.

        Raises:
            LLMException: If the output matches neither a final answer nor a
                recognizable Action / Action Input structure.
        """
        # Truncate anything the model hallucinated past its own stop word.
        for stop_word in STOP_WORD:
            if stop_word in llm_output:
                llm_output = llm_output.split(stop_word)[0].rstrip()
                break

        # Final answer: strip every occurrence of the indicator and return.
        for indicator in FINAL_ANSWER_INDICATORS:
            if indicator in llm_output:
                parts = llm_output.split(indicator)
                formatted_output = ''.join(parts).strip()
                formatted_output = replace_latex_format(formatted_output)
                return AgentFinish(raw_output=llm_output, formatted_output=formatted_output)

        # Two accepted layouts: "Action: ... Action Input: ```python ...```"
        # (or ```py ...```), with a '''-fenced fallback as the alternation's
        # second branch.
        ACTION_REGEX_1 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```python\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"
        ACTION_REGEX_2 = r"(.*?)\n?Action:\s*(.*?)\n?Action\s*Input:\s*```py\n(.*?)```(.*?)$|(.*?)\n?'''(\w+)\n?(.*?)\n?'''(.*?)$"

        action_match = re.search(ACTION_REGEX_1, llm_output, re.DOTALL) or re.search(ACTION_REGEX_2, llm_output, re.DOTALL)

        if action_match:
            # NOTE(review): if only the second alternation branch matched,
            # groups 1-3 are None and .strip() raises AttributeError, which the
            # caller's retry loop absorbs — confirm whether that is intended.
            context = action_match.group(1).strip()
            action_tool_description = action_match.group(2).strip()
            action_input = action_match.group(3).strip()

            format_code_block = self._format_code_block(action_input)

            prefix = TOOL_INPUT_PREFIX_CN if is_cn else TOOL_INPUT_PREFIX_EN
            formatted_output = "{}\n{}\n{}\n".format(context, prefix, format_code_block)
            formatted_output = replace_latex_format(formatted_output)

            return AgentAction(tool=action_tool_description,
                               tool_input=format_code_block,
                               formatted_output=formatted_output,
                               raw_output=llm_output)

        # Diagnose which part of the expected structure is missing.
        if not re.search(r"Action\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action' in LLM output: `{llm_output}`")
        elif not re.search(r"Action\s*Input\s*:", llm_output, re.DOTALL):
            raise LLMException(f"Missing 'Action Input' in LLM output: `{llm_output}`")
        else:
            raise LLMException(f"Unrecognized LLM output format: `{llm_output}`")

    def _format_code_block(self, tool_input: str) -> str:
        """Normalize ``tool_input`` into a well-formed ```` ```python ```` fenced block."""
        stripped_tool_input = tool_input.strip()

        if stripped_tool_input.startswith(CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Already a ```python block; just ensure a newline follows the tag.
            if not stripped_tool_input.startswith(CODE_BLOCK_START_TAG + '\n'):
                stripped_tool_input = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_START_TAG):] + \
                                      '\n'
            formatted_code = stripped_tool_input
        elif stripped_tool_input.startswith(CODE_BLOCK_TAG) and not stripped_tool_input.startswith(
                CODE_BLOCK_START_TAG) and stripped_tool_input.endswith(CODE_BLOCK_TAG):
            # Fenced but untagged (bare ```); re-tag the opening fence as python.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input[len(CODE_BLOCK_TAG):] + '\n'
        else:
            # No fences at all: wrap the raw code.
            formatted_code = CODE_BLOCK_START_TAG + '\n' + stripped_tool_input + '\n' + CODE_BLOCK_TAG + '\n'

        # Round-trip through UTF-8; a no-op for well-formed text, but raises on
        # lone surrogates so malformed output fails here rather than downstream.
        return formatted_code.encode("utf-8").decode("utf-8")

    def _process_sandbox_output(self, output: str) -> str:
        """Function to process the result containing STDERR.

        Outputs over 1000 characters are trimmed to roughly the first and last
        500 characters (split on line boundaries), joined by an ellipsis line.
        """
        if len(output) <= 1000:
            return output

        logger.info("Output contains error, original message is over 1000, trim it for response. ori output: \n{}".
                    format(output))
        rows = output.split("\n")
        # Collect whole lines from the top until ~500 characters are gathered.
        top_segment = []
        length = 0
        for sub_p in rows:
            if length + len(sub_p) > 500:
                break
            top_segment.append(sub_p)
            length += len(sub_p)

        # Same from the bottom, preserving the original line order.
        bottom_segment = []
        length = 0
        for sub_p in reversed(rows):
            if length + len(sub_p) > 500:
                break
            bottom_segment.insert(0, sub_p)
            length += len(sub_p)

        trimmed_output = "\n".join(top_segment + ["......"] + bottom_segment)

        return trimmed_output

    def _get_output_files(self, tool_response) -> list[MediaFile]:
        """Extract files/images produced by a successful, complete sandbox run."""
        output_files = []

        if isinstance(tool_response, PythonSandBoxToolResponse) and isinstance(tool_response.raw_output, RunCodeOutput):
            raw_output = tool_response.raw_output

            # Only collect artifacts when the run succeeded and fully completed.
            if raw_output.code == 0 and not raw_output.data.is_partial:
                result_data = raw_output.data.result

                if len(result_data.new_generated_files) > 0:
                    output_files.extend([MediaFile(tos_path=file.download_link) for file in
                                         result_data.new_generated_files])

                if len(result_data.code_output_result) > 0:
                    output_files.extend(
                        [MediaFile(tos_path=image.content) for image in result_data.code_output_result
                         if image.type == 'image'])

        return output_files

    def _replace_csv_path(self, input_string: str) -> str:
        """Replace any literal CSV path in ``pd.read_csv(...)`` calls with a placeholder."""
        pattern = r'pd\.read_csv\(["\'](.*\.csv)["\']\)'
        replacement = "pd.read_csv('/path/to/your/dataset')"
        updated_string = re.sub(pattern, replacement, input_string)
        return updated_string

    @staticmethod
    def create_agent_response(formatted_output, output_files, raw_output) -> AgentResponse:
        """Bundle formatted text, files, and raw text into an ``AgentResponse``."""
        return AgentResponse(output_text=formatted_output, output_files=output_files, raw_output_text=raw_output)
|
|
|
|