| import copy |
| import time |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import hydra |
| from pydantic import root_validator |
|
|
| from langchain import LLMChain, PromptTemplate |
| from langchain.agents import AgentExecutor, BaseMultiActionAgent, ZeroShotAgent |
| from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX |
| from langchain.chat_models import ChatOpenAI |
| from langchain.schema import ( |
| AgentAction, |
| AgentFinish, |
| OutputParserException, |
| ) |
|
|
| from flows.base_flows import Flow, CompositeFlow, GenericLCTool |
| from flows.messages import OutputMessage, UpdateMessage_Generic |
| from flows.utils.caching_utils import flow_run_cache |
|
|
|
|
class GenericZeroShotAgent(ZeroShotAgent):
    """ZeroShotAgent whose tools are Flow objects keyed by tool name."""

    @classmethod
    def create_prompt(
        cls,
        tools: Dict[str, Flow],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Build a zero-shot (MRKL-style) prompt from a mapping of named Flows.

        Args:
            tools: Mapping from tool name to Flow. Each Flow's
                ``flow_config['description']`` is listed next to its name.
            prefix: Text placed before the tool listing.
            suffix: Text placed after the format instructions.
            format_instructions: Template containing a ``{tool_names}``
                placeholder, filled with the comma-separated tool names.
            input_variables: Variables the resulting prompt expects; defaults to
                ``["input", "agent_scratchpad"]`` when omitted.

        Returns:
            A PromptTemplate made of prefix, tool listing, format instructions
            and suffix, separated by blank lines.
        """
        descriptions = []
        for name, flow in tools.items():
            descriptions.append(f"{name}: {flow.flow_config['description']}")
        instructions = format_instructions.format(tool_names=", ".join(tools))

        sections = (prefix, "\n".join(descriptions), instructions, suffix)
        variables = (
            ["input", "agent_scratchpad"] if input_variables is None else input_variables
        )
        return PromptTemplate(template="\n\n".join(sections), input_variables=variables)
|
|
|
|
class GenericAgentExecutor(AgentExecutor):
    """AgentExecutor whose tools are Flow objects keyed by tool name.

    The validators below reuse the names of AgentExecutor's own root validators
    so that they replace them; tool metadata is read from each Flow's
    ``flow_config`` rather than from langchain tool attributes.
    """

    # Tool name -> Flow. The names must match the agent's allowed tools.
    tools: Dict[str, Flow]

    @root_validator()
    def validate_tools(cls, values: Dict) -> Dict:
        """Check that the agent's allowed tool names equal the provided tool names.

        Raises:
            ValueError: if the agent declares allowed tools and that set differs
                from the keys of ``tools``.
        """
        agent = values["agent"]
        tools = values["tools"]
        allowed_tools = agent.get_allowed_tools()
        if allowed_tools is not None:
            if set(allowed_tools) != set(tools.keys()):
                raise ValueError(
                    f"Allowed tools ({allowed_tools}) different than "
                    f"provided tools ({tools.keys()})"
                )
        return values

    @root_validator()
    def validate_return_direct_tool(cls, values: Dict) -> Dict:
        """Reject ``return_direct`` tools when the agent is a multi-action agent.

        Raises:
            ValueError: if the agent is a BaseMultiActionAgent and any Flow's
                config sets ``return_direct`` to a truthy value.
        """
        agent = values["agent"]
        tools = values["tools"]
        if isinstance(agent, BaseMultiActionAgent):
            # ``tools`` is a dict: iterate the Flow objects, not the name keys.
            for tool in tools.values():
                # Missing key means "not return_direct", as for langchain tools.
                if tool.flow_config.get("return_direct", False):
                    raise ValueError(
                        "Tools that have `return_direct=True` are not allowed "
                        "in multi-action agents"
                    )
        return values

    def _get_tool_return(
        self, next_step_output: Tuple[AgentAction, str]
    ) -> Optional[AgentFinish]:
        """Turn an (action, observation) pair into an AgentFinish for return_direct tools.

        Args:
            next_step_output: The action taken and the observation it produced.

        Returns:
            An AgentFinish carrying the observation under the agent's first return
            key when the named tool is known and has ``return_direct`` set;
            otherwise None (including for tool names not in ``self.tools``).
        """
        agent_action, observation = next_step_output
        tool = self.tools.get(agent_action.tool)
        if tool is not None and tool.flow_config.get("return_direct", False):
            return AgentFinish(
                {self.agent.return_values[0]: observation},
                "",
            )
        return None
|
|
|
|
class ReActFlow(CompositeFlow):
    """Composite flow that drives a langchain zero-shot ReAct agent over its subflows.

    Each entry of ``self.subflows`` is exposed to the agent as a tool under its
    key. Planning, stopping criteria and the parsing-error policy come from the
    langchain ``backend`` built in ``run``. Every tool invocation, including the
    fallback "exception" and "invalid tool" flows, goes through
    ``_call_flow_from_state``.
    """

    # Fallback flow invoked when the agent's LLM output cannot be parsed.
    EXCEPTION_FLOW_CONFIG = {
        "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
        "config": {
            "name": "_Exception",
            "description": "Exception tool",

            "tool_type": "exception",
            "input_keys": ["query"],
            "output_keys": ["raw_response"],

            "verbose": False,
            "clear_flow_namespace_on_run_end": False,

            "input_data_transformations": [],
            "output_data_transformations": [],
            "keep_raw_response": True
        }
    }

    # Fallback flow invoked when the agent names a tool that is not a subflow.
    INVALID_FLOW_CONFIG = {
        "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
        "config": {
            "name": "invalid_tool",
            "description": "Called when tool name is invalid.",

            "tool_type": "invalid",
            "input_keys": ["tool_name"],
            "output_keys": ["raw_response"],

            "verbose": False,
            "clear_flow_namespace_on_run_end": False,

            "input_data_transformations": [],
            "output_data_transformations": [],
            "keep_raw_response": True
        }
    }

    SUPPORTS_CACHING: bool = True

    # Taken from ``input_data["api_keys"]`` at the start of every ``run``.
    api_keys: Dict[str, str]

    # langchain executor, rebuilt on every ``run``.
    backend: GenericAgentExecutor
    # Prompt built once in ``__init__`` from the subflows and ``prompt_config``.
    react_prompt_template: PromptTemplate

    exception_flow: GenericLCTool
    invalid_flow: GenericLCTool

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Populated lazily by ``run``.
        self.api_keys = None
        self.backend = None
        self.react_prompt_template = GenericZeroShotAgent.create_prompt(
            tools=self.subflows,
            **self.flow_config.get("prompt_config", {})
        )

        self._set_up_necessary_subflows()

    def set_up_flow_state(self):
        """Initialise flow state, adding an empty list of (AgentAction, observation) steps."""
        super().set_up_flow_state()
        self.flow_state["intermediate_steps"] = []

    def _set_up_necessary_subflows(self):
        """Instantiate the fallback exception and invalid-tool flows via hydra."""
        self.exception_flow = hydra.utils.instantiate(
            self.EXCEPTION_FLOW_CONFIG, _convert_="partial", _recursive_=False
        )
        self.invalid_flow = hydra.utils.instantiate(
            self.INVALID_FLOW_CONFIG, _convert_="partial", _recursive_=False
        )

    def _get_prompt_message(self, input_data: Dict[str, Any]) -> str:
        """Render the ReAct prompt for ``input_data``.

        The literal ``{agent_scratchpad}`` placeholder is left in the output.
        ``input_data`` is deep-copied, so the caller's dict is not modified.
        """
        data = copy.deepcopy(input_data)
        data["agent_scratchpad"] = "{agent_scratchpad}"

        return self.react_prompt_template.format(**data)

    @staticmethod
    def get_raw_response(output: OutputMessage) -> str:
        """Return the raw response stored under the message's first output key."""
        key = output.data["output_keys"][0]
        return output.data["output_data"]["raw_response"][key]

    def _call_tool_and_get_observation(
        self,
        tool: Flow,
        private_keys: List[str],
        keys_to_ignore_for_hash: List[str],
    ) -> str:
        """Run ``tool`` on the current flow state and return its raw response."""
        tool_output = self._call_flow_from_state(
            tool,
            private_keys=private_keys,
            keys_to_ignore_for_hash=keys_to_ignore_for_hash,
            search_class_namespace_for_inputs=False
        )
        return self.get_raw_response(tool_output)

    def _take_next_step(
        self,
        inputs: Dict[str, str],
        intermediate_steps: List[Tuple[AgentAction, str]],
        private_keys: Optional[List[str]] = None,
        keys_to_ignore_for_hash: Optional[List[str]] = None
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """Take a single step in the thought-action-observation loop.

        Args:
            inputs: Prompt variables forwarded to the agent's ``plan`` call.
            intermediate_steps: (action, observation) pairs produced so far.
            private_keys: Forwarded to every ``_call_flow_from_state`` call.
                None means an empty list.
            keys_to_ignore_for_hash: Forwarded to every ``_call_flow_from_state``
                call. None means an empty list.

        Returns:
            The agent's AgentFinish when it decides to stop; otherwise the list of
            (action, observation) pairs produced in this step.

        Raises:
            OutputParserException: if the agent output cannot be parsed and the
                backend's ``handle_parsing_errors`` is ``False``.
            ValueError: if ``handle_parsing_errors`` has an unsupported type.
        """
        # Fresh lists per call instead of shared mutable defaults.
        if private_keys is None:
            private_keys = []
        if keys_to_ignore_for_hash is None:
            keys_to_ignore_for_hash = []

        try:
            output = self.backend.agent.plan(
                intermediate_steps,
                **inputs,
            )
        except OutputParserException as e:
            handle_errors = self.backend.handle_parsing_errors
            # Only an explicit ``False`` re-raises; True/str/callable all recover.
            if isinstance(handle_errors, bool) and not handle_errors:
                raise
            text = str(e)

            if isinstance(handle_errors, bool):
                if e.send_to_llm:
                    observation = str(e.observation)
                    text = str(e.llm_output)
                else:
                    observation = "Invalid or incomplete response"
            elif isinstance(handle_errors, str):
                observation = handle_errors
            elif callable(handle_errors):
                observation = handle_errors(e)
            else:
                raise ValueError("Got unexpected type of `handle_parsing_errors`")

            output = AgentAction("_Exception", observation, text)

            # The exception flow takes the error observation as its "query" input.
            self._state_update_dict({"query": output.tool_input})
            observation = self._call_tool_and_get_observation(
                self.exception_flow, private_keys, keys_to_ignore_for_hash
            )
            return [(output, observation)]

        if isinstance(output, AgentFinish):
            return output

        # ``plan`` may return a single action or a list of actions.
        actions: List[AgentAction] = [output] if isinstance(output, AgentAction) else output

        result = []
        for agent_action in actions:
            if agent_action.tool in self.subflows:
                tool = self.subflows[agent_action.tool]
                if isinstance(agent_action.tool_input, dict):
                    self._state_update_dict(agent_action.tool_input)
                else:
                    # A non-dict tool input is bound to the tool's first input key.
                    self._state_update_dict(
                        {tool.flow_config["input_keys"][0]: agent_action.tool_input}
                    )
            else:
                # Unknown tool name: the invalid-tool flow receives the name it was given.
                tool = self.invalid_flow
                self._state_update_dict({"tool_name": agent_action.tool})

            observation = self._call_tool_and_get_observation(
                tool, private_keys, keys_to_ignore_for_hash
            )
            result.append((agent_action, observation))
        return result

    def _run(
        self,
        input_data: Dict[str, Any],
        private_keys: Optional[List[str]] = None,
        keys_to_ignore_for_hash: Optional[List[str]] = None
    ) -> str:
        """Run the agent loop on ``input_data`` and return the agent's final output.

        The loop ends when the agent returns an AgentFinish, or when a single action
        hits a tool whose config sets ``return_direct``. It also ends once
        ``backend._should_continue`` returns False; the agent's stopped response is
        then returned.
        """
        # Reset the stored history; the local alias is extended in place below.
        self.flow_state["intermediate_steps"] = []
        intermediate_steps = self.flow_state["intermediate_steps"]

        iterations = 0
        time_elapsed = 0.0
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start_time = time.monotonic()

        while self.backend._should_continue(iterations, time_elapsed):
            next_step_output = self._take_next_step(
                input_data,
                intermediate_steps,
                private_keys,
                keys_to_ignore_for_hash
            )
            if isinstance(next_step_output, AgentFinish):
                return next_step_output.return_values["output"]

            intermediate_steps.extend(next_step_output)

            # A lone action on a return_direct tool ends the run with its observation.
            if len(next_step_output) == 1:
                tool_return = self.backend._get_tool_return(next_step_output[0])
                if tool_return is not None:
                    return tool_return.return_values["output"]

            iterations += 1
            time_elapsed = time.monotonic() - start_time

        # Presumably the max_iterations / max_execution_time budget set in ``run``
        # was exhausted; let the agent produce its stopped response.
        output = self.backend.agent.return_stopped_response(
            self.backend.early_stopping_method, intermediate_steps, **input_data
        )
        return output.return_values["output"]

    @flow_run_cache()
    def run(
        self,
        input_data: Dict[str, Any],
        private_keys: Optional[List[str]] = [],
        keys_to_ignore_for_hash: Optional[List[str]] = []
    ) -> Dict[str, Any]:
        """Build the langchain backend for this call and run the ReAct loop.

        Args:
            input_data: Must contain ``"api_keys"`` (with an ``"openai"`` entry) and
                ``"output_keys"``. ``"api_keys"`` is removed from this dict in place.
            private_keys: Forwarded to every subflow call.
            keys_to_ignore_for_hash: Forwarded to every subflow call.

        Returns:
            A dict mapping ``input_data["output_keys"][0]`` to the agent's output.
        """
        # NOTE(review): the mutable defaults are kept so that this cache-decorated
        # public signature stays identical; they are only forwarded, never mutated here.
        self.api_keys = input_data.pop("api_keys")

        llm = ChatOpenAI(
            model_name=self.flow_config["model_name"],
            openai_api_key=self.api_keys["openai"],
            **self.flow_config["generation_parameters"],
        )
        llm_chain = LLMChain(llm=llm, prompt=self.react_prompt_template)
        agent = GenericZeroShotAgent(llm_chain=llm_chain, allowed_tools=list(self.subflows.keys()))

        self.backend = GenericAgentExecutor.from_agent_and_tools(
            agent=agent,
            tools=self.subflows,
            max_iterations=self.flow_config.get("max_iterations", 15),
            max_execution_time=self.flow_config.get("max_execution_time")
        )

        data = {k: input_data[k] for k in self.get_input_keys(input_data)}

        output = self._run(data, private_keys, keys_to_ignore_for_hash)

        return {input_data["output_keys"][0]: output}
|
|