| | """Chain that makes API calls and summarizes the responses to answer a question.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Any, Dict, List, Optional, Sequence, Tuple |
| | from urllib.parse import urlparse |
| |
|
| | from langchain_core._api import deprecated |
| | from langchain_core.callbacks import ( |
| | AsyncCallbackManagerForChainRun, |
| | CallbackManagerForChainRun, |
| | ) |
| | from langchain_core.language_models import BaseLanguageModel |
| | from langchain_core.prompts import BasePromptTemplate |
| | from pydantic import Field, model_validator |
| | from typing_extensions import Self |
| |
|
| | from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT |
| | from langchain.chains.base import Chain |
| | from langchain.chains.llm import LLMChain |
| |
|
| |
|
| | def _extract_scheme_and_domain(url: str) -> Tuple[str, str]: |
| | """Extract the scheme + domain from a given URL. |
| | |
| | Args: |
| | url (str): The input URL. |
| | |
| | Returns: |
| | return a 2-tuple of scheme and domain |
| | """ |
| | parsed_uri = urlparse(url) |
| | return parsed_uri.scheme, parsed_uri.netloc |
| |
|
| |
|
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
    """Report whether ``url`` targets one of the allowed domains.

    A match requires both the scheme and the network location of ``url`` to
    equal those of an entry in ``limit_to_domains``; paths, queries and
    fragments are ignored on both sides.

    Args:
        url: The URL about to be requested.
        limit_to_domains: Allowed origins, each given as a URL such as
            ``"https://www.example.com"``.

    Returns:
        True if some allowed entry has the same scheme and network location
        as ``url``, False otherwise (including when the sequence is empty).
    """
    target = _extract_scheme_and_domain(url)
    # Comparing the (scheme, netloc) tuples is equivalent to comparing the
    # two components individually.
    return any(
        _extract_scheme_and_domain(candidate) == target
        for candidate in limit_to_domains
    )
| |
|
| |
|
try:
    # ``langchain_community`` is an optional dependency. When it is missing the
    # import below raises ImportError and the stub class defined in the
    # ``except`` branch is exported under the same name instead.
    from langchain_community.utilities.requests import TextRequestsWrapper

    @deprecated(
        since="0.2.13",
        message=(
            "This class is deprecated and will be removed in langchain 1.0. "
            "See API reference for replacement: "
            "https://api.python.langchain.com/en/latest/chains/langchain.chains.api.base.APIChain.html"
        ),
        removal="1.0",
    )
    class APIChain(Chain):
        """Chain that makes API calls and summarizes the responses to answer a question.

        *Security Note*: This API chain uses the requests toolkit
            to make GET, POST, PATCH, PUT, and DELETE requests to an API.

            Exercise care in who is allowed to use this chain. If exposing
            to end users, consider that users will be able to make arbitrary
            requests on behalf of the server hosting the code. For example,
            users could ask the server to make a request to a private API
            that is only accessible from the server.

            Control who can submit requests using this toolkit and
            what network access it has.

            See https://python.langchain.com/docs/security for more information.

        Note: this class is deprecated. See below for a replacement implementation
            using LangGraph. The benefits of this implementation are:

        - Uses LLM tool calling features to encourage properly-formatted API requests;
        - Support for both token-by-token and step-by-step streaming;
        - Support for checkpointing and memory of chat history;
        - Easier to modify or extend (e.g., with additional tools, structured responses, etc.)

        Install LangGraph with:

        .. code-block:: bash

            pip install -U langgraph

        .. code-block:: python

            from typing import Annotated, Sequence
            from typing_extensions import TypedDict

            from langchain.chains.api.prompt import API_URL_PROMPT
            from langchain_community.agent_toolkits.openapi.toolkit import RequestsToolkit
            from langchain_community.utilities.requests import TextRequestsWrapper
            from langchain_core.messages import BaseMessage
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import ChatOpenAI
            from langchain_core.runnables import RunnableConfig
            from langgraph.graph import END, StateGraph
            from langgraph.graph.message import add_messages
            from langgraph.prebuilt.tool_node import ToolNode

            # NOTE: There are inherent risks in giving models discretion
            # to execute real-world actions. We must "opt-in" to these
            # risks by setting allow_dangerous_request=True to use these tools.
            # This can be dangerous for calling unwanted requests. Please make
            # sure your custom OpenAPI spec (yaml) is safe and that permissions
            # associated with the tools are narrowly-scoped.
            ALLOW_DANGEROUS_REQUESTS = True

            # Subset of spec for https://jsonplaceholder.typicode.com
            api_spec = \"\"\"
            openapi: 3.0.0
            info:
              title: JSONPlaceholder API
              version: 1.0.0
            servers:
              - url: https://jsonplaceholder.typicode.com
            paths:
              /posts:
                get:
                  summary: Get posts
                  parameters: &id001
                  - name: _limit
                    in: query
                    required: false
                    schema:
                      type: integer
                      example: 2
                    description: Limit the number of results
            \"\"\"

            llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
            toolkit = RequestsToolkit(
                requests_wrapper=TextRequestsWrapper(headers={}),  # no auth required
                allow_dangerous_requests=ALLOW_DANGEROUS_REQUESTS,
            )
            tools = toolkit.get_tools()

            api_request_chain = (
                API_URL_PROMPT.partial(api_docs=api_spec)
                | llm.bind_tools(tools, tool_choice="any")
            )

            class ChainState(TypedDict):
                \"\"\"LangGraph state.\"\"\"

                messages: Annotated[Sequence[BaseMessage], add_messages]


            async def acall_request_chain(state: ChainState, config: RunnableConfig):
                last_message = state["messages"][-1]
                response = await api_request_chain.ainvoke(
                    {"question": last_message.content}, config
                )
                return {"messages": [response]}

            async def acall_model(state: ChainState, config: RunnableConfig):
                response = await llm.ainvoke(state["messages"], config)
                return {"messages": [response]}

            graph_builder = StateGraph(ChainState)
            graph_builder.add_node("call_tool", acall_request_chain)
            graph_builder.add_node("execute_tool", ToolNode(tools))
            graph_builder.add_node("call_model", acall_model)
            graph_builder.set_entry_point("call_tool")
            graph_builder.add_edge("call_tool", "execute_tool")
            graph_builder.add_edge("execute_tool", "call_model")
            graph_builder.add_edge("call_model", END)
            chain = graph_builder.compile()

        .. code-block:: python

            example_query = "Fetch the top two posts. What are their titles?"

            events = chain.astream(
                {"messages": [("user", example_query)]},
                stream_mode="values",
            )
            async for event in events:
                event["messages"][-1].pretty_print()
        """

        api_request_chain: LLMChain
        api_answer_chain: LLMChain
        requests_wrapper: TextRequestsWrapper = Field(exclude=True)
        api_docs: str
        question_key: str = "question"  #: :meta private:
        output_key: str = "output"  #: :meta private:
        limit_to_domains: Optional[Sequence[str]] = Field(default_factory=list)
        """Use to limit the domains that can be accessed by the API chain.

        * For example, to limit to just the domain `https://www.example.com`, set
            `limit_to_domains=["https://www.example.com"]`.

        * The default value is an empty sequence, which means that no domains are
          allowed by default. By design this will raise an error on instantiation.
        * Use a None if you want to allow all domains by default -- this is not
          recommended for security reasons, as it would allow malicious users to
          make requests to arbitrary URLS including internal APIs accessible from
          the server.
        """

        @property
        def input_keys(self) -> List[str]:
            """Expect input key.

            :meta private:
            """
            return [self.question_key]

        @property
        def output_keys(self) -> List[str]:
            """Expect output key.

            :meta private:
            """
            return [self.output_key]

        @model_validator(mode="after")
        def validate_api_request_prompt(self) -> Self:
            """Check that api request prompt expects the right variables.

            Raises:
                ValueError: If the request prompt's input variables are not
                    exactly ``question`` and ``api_docs``.
            """
            input_vars = self.api_request_chain.prompt.input_variables
            expected_vars = {"question", "api_docs"}
            if set(input_vars) != expected_vars:
                raise ValueError(
                    f"Input variables should be {expected_vars}, got {input_vars}"
                )
            return self

        @model_validator(mode="before")
        @classmethod
        def validate_limit_to_domains(cls, values: Dict) -> Any:
            """Check that allowed domains are valid.

            Raises:
                ValueError: If ``limit_to_domains`` was not supplied at all, or
                    was supplied as an empty (but non-None) value.
            """
            # This runs in "before" mode on the raw input so that an omitted
            # ``limit_to_domains`` can be told apart from an explicit value;
            # once field defaults are applied that distinction is lost.
            if "limit_to_domains" not in values:
                raise ValueError(
                    "You must specify a list of domains to limit access using "
                    "`limit_to_domains`"
                )
            # ``None`` is the explicit opt-in to "allow every domain"; any other
            # falsy value (e.g. an empty list or tuple) is rejected.
            if (
                not values["limit_to_domains"]
                and values["limit_to_domains"] is not None
            ):
                raise ValueError(
                    "Please provide a list of domains to limit access using "
                    "`limit_to_domains`."
                )
            return values

        @model_validator(mode="after")
        def validate_api_answer_prompt(self) -> Self:
            """Check that api answer prompt expects the right variables.

            Raises:
                ValueError: If the answer prompt's input variables are not
                    exactly ``question``, ``api_docs``, ``api_url`` and
                    ``api_response``.
            """
            input_vars = self.api_answer_chain.prompt.input_variables
            expected_vars = {"question", "api_docs", "api_url", "api_response"}
            if set(input_vars) != expected_vars:
                raise ValueError(
                    f"Input variables should be {expected_vars}, got {input_vars}"
                )
            return self

        def _ensure_url_allowed(self, api_url: str) -> None:
            """Raise if ``api_url`` is outside the configured allow-list.

            Shared by the sync and async execution paths so the check cannot
            drift between them. A falsy ``limit_to_domains`` skips the check.

            Raises:
                ValueError: If ``limit_to_domains`` is set and ``api_url`` does
                    not match any of its entries.
            """
            if self.limit_to_domains and not _check_in_allowed_domain(
                api_url, self.limit_to_domains
            ):
                raise ValueError(
                    f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
                )

        def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            """Generate an API URL, fetch it with GET, and summarize the response.

            Args:
                inputs: Chain inputs; the question is read from ``question_key``.
                run_manager: Callback manager for this run; a no-op manager is
                    used when omitted.

            Returns:
                A dict mapping ``output_key`` to the answer text.

            Raises:
                ValueError: If the generated URL is not in ``limit_to_domains``.
            """
            _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
            question = inputs[self.question_key]
            api_url = self.api_request_chain.predict(
                question=question,
                api_docs=self.api_docs,
                callbacks=_run_manager.get_child(),
            )
            # The raw model output is reported before it is stripped.
            _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
            api_url = api_url.strip()
            # Must happen before any network access.
            self._ensure_url_allowed(api_url)
            api_response = self.requests_wrapper.get(api_url)
            _run_manager.on_text(
                str(api_response), color="yellow", end="\n", verbose=self.verbose
            )
            answer = self.api_answer_chain.predict(
                question=question,
                api_docs=self.api_docs,
                api_url=api_url,
                api_response=api_response,
                callbacks=_run_manager.get_child(),
            )
            return {self.output_key: answer}

        async def _acall(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
        ) -> Dict[str, str]:
            """Async counterpart of ``_call``.

            Args:
                inputs: Chain inputs; the question is read from ``question_key``.
                run_manager: Async callback manager for this run; a no-op
                    manager is used when omitted.

            Returns:
                A dict mapping ``output_key`` to the answer text.

            Raises:
                ValueError: If the generated URL is not in ``limit_to_domains``.
            """
            _run_manager = (
                run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
            )
            question = inputs[self.question_key]
            api_url = await self.api_request_chain.apredict(
                question=question,
                api_docs=self.api_docs,
                callbacks=_run_manager.get_child(),
            )
            # The raw model output is reported before it is stripped.
            await _run_manager.on_text(
                api_url, color="green", end="\n", verbose=self.verbose
            )
            api_url = api_url.strip()
            # Must happen before any network access.
            self._ensure_url_allowed(api_url)
            api_response = await self.requests_wrapper.aget(api_url)
            await _run_manager.on_text(
                str(api_response), color="yellow", end="\n", verbose=self.verbose
            )
            answer = await self.api_answer_chain.apredict(
                question=question,
                api_docs=self.api_docs,
                api_url=api_url,
                api_response=api_response,
                callbacks=_run_manager.get_child(),
            )
            return {self.output_key: answer}

        @classmethod
        def from_llm_and_api_docs(
            cls,
            llm: BaseLanguageModel,
            api_docs: str,
            headers: Optional[dict] = None,
            api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
            api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
            limit_to_domains: Optional[Sequence[str]] = (),
            **kwargs: Any,
        ) -> APIChain:
            """Load chain from just an LLM and the api docs.

            Args:
                llm: Model used for both URL generation and answer synthesis.
                api_docs: API documentation passed to both prompts.
                headers: Headers given to the ``TextRequestsWrapper``.
                api_url_prompt: Prompt used to produce the request URL.
                api_response_prompt: Prompt used to summarize the response.
                limit_to_domains: Allowed origins, or None to allow all. The
                    default empty tuple is rejected by validation, so callers
                    must make an explicit choice.
                **kwargs: Extra keyword arguments forwarded to the constructor.

            Returns:
                The constructed chain.
            """
            get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
            requests_wrapper = TextRequestsWrapper(headers=headers)
            get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
            return cls(
                api_request_chain=get_request_chain,
                api_answer_chain=get_answer_chain,
                requests_wrapper=requests_wrapper,
                api_docs=api_docs,
                limit_to_domains=limit_to_domains,
                **kwargs,
            )

        @property
        def _chain_type(self) -> str:
            """Identifier string for this chain type."""
            return "api_chain"

except ImportError:

    class APIChain:  # type: ignore[no-redef]
        """Placeholder exported when ``langchain_community`` is not installed.

        Importing the module still succeeds; the failure is deferred until
        someone tries to instantiate the chain.
        """

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            """Always raise ``ImportError`` with installation instructions."""
            raise ImportError(
                "To use the APIChain, you must install the langchain_community "
                "package. Install it with `pip install langchain_community`."
            )
| |
|