from __future__ import annotations

import json
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from pydantic import Field

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR


@deprecated(
    since="0.2.7",
    alternative=(
        "example in API reference with more detail: "
        "https://api.python.langchain.com/en/latest/chains/langchain.chains.qa_generation.base.QAGenerationChain.html"
    ),
    removal="1.0",
)
class QAGenerationChain(Chain):
    """Base class for question-answer generation chains.

    This class is deprecated. See below for an alternative implementation.

    Advantages of the alternative implementation include:

    - Supports async and streaming;
    - Surfaces the prompt and text splitter for easier customization;
    - Use of JsonOutputParser supports JSONPatch operations in streaming mode
      and adds robustness to markdown-formatted output.

    .. code-block:: python

        from langchain.chains.qa_generation.prompt import CHAT_PROMPT as prompt
        # Note: import PROMPT if using a legacy non-chat model.
        from langchain_core.output_parsers import JsonOutputParser
        from langchain_core.runnables import (
            RunnableLambda,
            RunnableParallel,
            RunnablePassthrough,
        )
        from langchain_core.runnables.base import RunnableEach
        from langchain_openai import ChatOpenAI
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        llm = ChatOpenAI()
        text_splitter = RecursiveCharacterTextSplitter(chunk_overlap=500)
        split_text = RunnableLambda(
            lambda x: text_splitter.create_documents([x])
        )

        chain = RunnableParallel(
            text=RunnablePassthrough(),
            questions=(
                split_text | RunnableEach(bound=prompt | llm | JsonOutputParser())
            )
        )
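
        # A minimal usage sketch (assumes API credentials are configured; the
        # input text is illustrative):
        result = chain.invoke("<a long document to generate QA pairs from>")
        # ``result`` is a dict shaped roughly like:
        # {"text": "<the original input>",
        #  "questions": [{"question": "...", "answer": "..."}, ...]}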
| """ |
|
|
| llm_chain: LLMChain |
| """LLM Chain that generates responses from user input and context.""" |
| text_splitter: TextSplitter = Field( |
| default=RecursiveCharacterTextSplitter(chunk_overlap=500) |
| ) |
| """Text splitter that splits the input into chunks.""" |
| input_key: str = "text" |
| """Key of the input to the chain.""" |
| output_key: str = "questions" |
| """Key of the output of the chain.""" |
| k: Optional[int] = None |
| """Number of questions to generate.""" |

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> QAGenerationChain:
        """
        Create a QAGenerationChain from a language model.

        Args:
            llm: a language model
            prompt: a prompt template; if not provided, a default is
                selected for the given model type
            **kwargs: additional arguments passed to the chain constructor

        Returns:
            a QAGenerationChain instance
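
        Example:
            A minimal usage sketch (the model and input text are illustrative):

            .. code-block:: python

                from langchain_openai import ChatOpenAI

                chain = QAGenerationChain.from_llm(ChatOpenAI())
                output = chain.invoke({"text": "<a long document>"})
                qa_pairs = output["questions"]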
| """ |
| _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm) |
| chain = LLMChain(llm=llm, prompt=_prompt) |
| return cls(llm_chain=chain, **kwargs) |

    @property
    def _chain_type(self) -> str:
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List]:
        # Split the input text into chunks and generate one QA pair per chunk.
        docs = self.text_splitter.create_documents([inputs[self.input_key]])
        results = self.llm_chain.generate(
            [{"text": d.page_content} for d in docs], run_manager=run_manager
        )
        # Each generation is expected to be a JSON-encoded question-answer pair.
        qa = [json.loads(res[0].text) for res in results.generations]
        return {self.output_key: qa}