| | """Chain pipeline where the outputs of one step feed directly into next.""" |
| |
|
| | from typing import Any, Dict, List, Optional |
| |
|
| | from langchain_core.callbacks import ( |
| | AsyncCallbackManagerForChainRun, |
| | CallbackManagerForChainRun, |
| | ) |
| | from langchain_core.utils.input import get_color_mapping |
| | from pydantic import ConfigDict, model_validator |
| | from typing_extensions import Self |
| |
|
| | from langchain.chains.base import Chain |
| |
|
| |
|
class SequentialChain(Chain):
    """Chain where the outputs of one chain feed directly into next."""

    # Sub-chains, executed in list order.
    chains: List[Chain]
    # Keys the caller must supply when invoking this chain.
    input_variables: List[str]
    # Keys returned by this chain; filled in by ``validate_chains`` when omitted.
    output_variables: List[str]
    # Only consulted when ``output_variables`` is omitted: if True, every known
    # variable other than the inputs is returned instead of just the last
    # chain's outputs.
    return_all: bool = False

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Return expected input keys to the chain.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return self.output_variables

    @model_validator(mode="before")
    @classmethod
    def validate_chains(cls, values: Dict) -> Any:
        """Validate that the correct inputs exist for all chains.

        Walks the chains in order, tracking which variables are available
        (declared inputs, memory variables, and outputs of earlier chains), and
        derives ``output_variables`` when the caller did not supply it.

        Args:
            values: Raw constructor values; must contain ``chains`` and
                ``input_variables``.

        Returns:
            The same ``values`` mapping, with ``output_variables`` populated
            if it was missing.

        Raises:
            ValueError: If an input key collides with a memory key, a chain
                needs a key that is not yet available, a chain outputs a key
                that already exists, or a requested output variable is never
                produced.
        """
        chains = values["chains"]
        input_variables = values["input_variables"]
        memory_keys: List[str] = []
        if values.get("memory") is not None:
            memory_keys = values["memory"].memory_variables
            overlapping_keys = set(input_variables) & set(memory_keys)
            if overlapping_keys:
                # Sort and comma-separate so several keys stay readable and the
                # message is deterministic.
                raise ValueError(
                    f"The input key(s) {', '.join(sorted(overlapping_keys))} "
                    f"are found "
                    f"in the Memory keys ({memory_keys}) - please use input and "
                    f"memory keys that don't overlap."
                )

        # Everything a chain may legitimately consume at this point.
        known_variables = set(input_variables) | set(memory_keys)

        for chain in chains:
            missing_vars = set(chain.input_keys).difference(known_variables)
            if chain.memory:
                # Keys served by the sub-chain's own memory are not missing.
                missing_vars = missing_vars.difference(chain.memory.memory_variables)

            if missing_vars:
                raise ValueError(
                    f"Missing required input keys: {missing_vars}, "
                    f"only had {known_variables}"
                )
            overlapping_keys = known_variables.intersection(chain.output_keys)
            if overlapping_keys:
                raise ValueError(
                    f"Chain returned keys that already exist: {overlapping_keys}"
                )

            known_variables |= set(chain.output_keys)

        if "output_variables" not in values:
            if values.get("return_all", False):
                output_keys = known_variables.difference(input_variables)
            else:
                output_keys = chains[-1].output_keys
            values["output_variables"] = output_keys
        else:
            missing_vars = set(values["output_variables"]).difference(known_variables)
            if missing_vars:
                raise ValueError(
                    f"Expected output variables that were not found: {missing_vars}."
                )

        return values

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run every chain in order, feeding accumulated values forward.

        Args:
            inputs: Initial values; copied, so the caller's dict is not mutated.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            The accumulated values restricted to ``output_variables``.
        """
        known_values = inputs.copy()
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        for chain in self.chains:
            callbacks = _run_manager.get_child()
            outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks)
            known_values.update(outputs)
        return {k: known_values[k] for k in self.output_variables}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronous counterpart of ``_call``.

        Args:
            inputs: Initial values; copied, so the caller's dict is not mutated.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            The accumulated values restricted to ``output_variables``.
        """
        known_values = inputs.copy()
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        for chain in self.chains:
            # Request a child manager per step, mirroring the synchronous path.
            callbacks = _run_manager.get_child()
            outputs = await chain.acall(
                known_values, return_only_outputs=True, callbacks=callbacks
            )
            known_values.update(outputs)
        return {k: known_values[k] for k in self.output_variables}
| |
|
| |
|
class SimpleSequentialChain(Chain):
    """Simple chain where the outputs of one step feed directly into next."""

    # Sub-chains, executed in list order; each must have one input and one output.
    chains: List[Chain]
    # When True, ``.strip()`` is applied to each step's output before it is
    # passed on.
    strip_outputs: bool = False
    # Key under which the initial input is read.
    input_key: str = "input"
    # Key under which the final result is returned.
    output_key: str = "output"

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @model_validator(mode="after")
    def validate_chains(self) -> Self:
        """Validate that chains are all single input/output.

        Raises:
            ValueError: If any chain declares a number of input keys or
                output keys other than one.
        """
        for step in self.chains:
            input_count = len(step.input_keys)
            if input_count != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one input, got "
                    f"{step} with {input_count} inputs."
                )
            output_count = len(step.output_keys)
            if output_count != 1:
                raise ValueError(
                    "Chains used in SimplePipeline should all have one output, got "
                    f"{step} with {output_count} outputs."
                )
        return self

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Pipe the value at ``input_key`` through every chain in order.

        Each step's result is reported through ``on_text`` and becomes the
        input of the next step.

        Returns:
            A single-entry dict mapping ``output_key`` to the last result.
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        current = inputs[self.input_key]
        step_labels = [str(index) for index in range(len(self.chains))]
        colors = get_color_mapping(step_labels)
        for step_number, step in enumerate(self.chains, start=1):
            child_callbacks = manager.get_child(f"step_{step_number}")
            current = step.run(current, callbacks=child_callbacks)
            if self.strip_outputs:
                current = current.strip()
            # Colors are keyed by the zero-based step index.
            manager.on_text(
                current,
                color=colors[str(step_number - 1)],
                end="\n",
                verbose=self.verbose,
            )
        return {self.output_key: current}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronous counterpart of ``_call``.

        Returns:
            A single-entry dict mapping ``output_key`` to the last result.
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        current = inputs[self.input_key]
        step_labels = [str(index) for index in range(len(self.chains))]
        colors = get_color_mapping(step_labels)
        for step_number, step in enumerate(self.chains, start=1):
            child_callbacks = manager.get_child(f"step_{step_number}")
            current = await step.arun(current, callbacks=child_callbacks)
            if self.strip_outputs:
                current = current.strip()
            # Colors are keyed by the zero-based step index.
            await manager.on_text(
                current,
                color=colors[str(step_number - 1)],
                end="\n",
                verbose=self.verbose,
            )
        return {self.output_key: current}
| |
|