| | """Chain that runs an arbitrary python function.""" |
| |
|
| | import functools |
| | import logging |
| | from typing import Any, Awaitable, Callable, Dict, List, Optional |
| |
|
| | from langchain_core.callbacks import ( |
| | AsyncCallbackManagerForChainRun, |
| | CallbackManagerForChainRun, |
| | ) |
| | from pydantic import Field |
| |
|
| | from langchain.chains.base import Chain |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| | class TransformChain(Chain): |
| | """Chain that transforms the chain output. |
| | |
| | Example: |
| | .. code-block:: python |
| | |
| | from langchain.chains import TransformChain |
| | transform_chain = TransformChain(input_variables=["text"], |
| | output_variables["entities"], transform=func()) |
| | """ |
| |
|
| | input_variables: List[str] |
| | """The keys expected by the transform's input dictionary.""" |
| | output_variables: List[str] |
| | """The keys returned by the transform's output dictionary.""" |
| | transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform") |
| | """The transform function.""" |
| | atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = ( |
| | Field(None, alias="atransform") |
| | ) |
| | """The async coroutine transform function.""" |
| |
|
| | @staticmethod |
| | @functools.lru_cache |
| | def _log_once(msg: str) -> None: |
| | """Log a message once. |
| | |
| | :meta private: |
| | """ |
| | logger.warning(msg) |
| |
|
| | @property |
| | def input_keys(self) -> List[str]: |
| | """Expect input keys. |
| | |
| | :meta private: |
| | """ |
| | return self.input_variables |
| |
|
| | @property |
| | def output_keys(self) -> List[str]: |
| | """Return output keys. |
| | |
| | :meta private: |
| | """ |
| | return self.output_variables |
| |
|
| | def _call( |
| | self, |
| | inputs: Dict[str, str], |
| | run_manager: Optional[CallbackManagerForChainRun] = None, |
| | ) -> Dict[str, str]: |
| | return self.transform_cb(inputs) |
| |
|
| | async def _acall( |
| | self, |
| | inputs: Dict[str, Any], |
| | run_manager: Optional[AsyncCallbackManagerForChainRun] = None, |
| | ) -> Dict[str, Any]: |
| | if self.atransform_cb is not None: |
| | return await self.atransform_cb(inputs) |
| | else: |
| | self._log_once( |
| | "TransformChain's atransform is not provided, falling" |
| | " back to synchronous transform" |
| | ) |
| | return self.transform_cb(inputs) |
| |
|