| import copy |
| import os |
| import types |
| import uuid |
| from typing import Any, Dict, List, Union, Optional, Tuple, Mapping |
| import time |
| import queue |
| import pathlib |
| from datetime import datetime |
|
|
| from langchain.schema import BasePromptTemplate |
| from langchain.chains import LLMChain |
| from langchain.chains import MapReduceDocumentsChain, StuffDocumentsChain, ReduceDocumentsChain |
| from langchain.chains.combine_documents.base import BaseCombineDocumentsChain |
| from langchain.chains.summarize import map_reduce_prompt, LoadingCallable, _load_stuff_chain, _load_map_reduce_chain, \ |
| _load_refine_chain |
| from langchain.schema.language_model import BaseLanguageModel |
|
|
| from src.utils import hash_file, get_sha |
|
|
| from langchain.callbacks.base import BaseCallbackHandler, Callbacks |
| from langchain.schema import LLMResult |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain.docstore.document import Document |
|
|
|
|
| class StreamingGradioCallbackHandler(BaseCallbackHandler): |
| """ |
    Token streamer for the LangChain backend, analogous to the
    H2OTextIteratorStreamer used for the HF backend: callback methods push
    tokens onto a queue that a consumer iterates over.
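
    A minimal usage sketch (the background-thread wiring is an assumption for
    illustration; any LangChain LLM that emits on_llm_new_token callbacks works):

    .. code-block:: python

        handler = StreamingGradioCallbackHandler(timeout=0.1, max_time=60)
        # run generation in a background thread, passing callbacks=[handler]
        for token in handler:
            ...  # yield each token into a Gradio streaming output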
| """ |
|
|
| def __init__(self, timeout: Optional[float] = None, block=True, max_time=None, verbose=False): |
| super().__init__() |
        self.text_queue = queue.SimpleQueue()
        self.stop_signal = None  # sentinel placed on the queue to end iteration
        self.do_stop = False  # set externally to abort streaming
        self.timeout = timeout
        self.block = block
        self.max_time = max_time  # optional wall-clock budget for a single generation
        self.tgen0 = None  # generation start time, set in on_llm_start
        self.verbose = verbose
|
|
| def on_llm_start( |
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any |
| ) -> None: |
        """Run when LLM starts running. Clean the queue."""
        self.tgen0 = time.time()
| while not self.text_queue.empty(): |
| try: |
| self.text_queue.get(block=False) |
| except queue.Empty: |
| continue |
|
|
| def on_llm_new_token(self, token: str, **kwargs: Any) -> None: |
| """Run on new LLM token. Only available when streaming is enabled.""" |
        if self.tgen0 is not None and self.max_time is not None and (time.time() - self.tgen0) > self.max_time:
            if self.verbose:
                print("Took too long in StreamingGradioCallbackHandler: %s" % (time.time() - self.tgen0), flush=True)
            # over the time budget: signal the consumer to stop instead of queuing the token
            self.text_queue.put(self.stop_signal)
| else: |
| self.text_queue.put(token) |
|
|
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: |
| """Run when LLM ends running.""" |
| self.text_queue.put(self.stop_signal) |
|
|
| def on_llm_error( |
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any |
| ) -> None: |
| """Run when LLM errors.""" |
| self.text_queue.put(self.stop_signal) |
|
|
| def __iter__(self): |
| return self |
|
|
| def __next__(self): |
        while True:
            try:
                value = self.stop_signal  # defensive default; overwritten below unless stopping
                if self.do_stop:
                    print("hit stop", flush=True)
                    # raise so the consuming generator terminates immediately
                    raise StopIteration()
                value = self.text_queue.get(block=self.block, timeout=self.timeout)
                break
            except queue.Empty:
                time.sleep(0.01)
| if value == self.stop_signal: |
| raise StopIteration() |
| else: |
| return value |
|
|
|
|
| def _chunk_sources(sources, chunk=True, chunk_size=512, language=None, db_type=None): |
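    """Optionally split documents into chunks, recording chunk ids in metadata.

    A minimal usage sketch (the 'chroma' db_type and 512-char chunk size are
    illustrative choices, not requirements):

    .. code-block:: python

        from langchain.docstore.document import Document

        docs = [Document(page_content="long text ...", metadata={})]
        out = _chunk_sources(docs, chunk=True, chunk_size=512, db_type='chroma')
        # for chroma, originals (chunk_id=-1) precede the chunks (chunk_id >= 0)
    """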
| assert db_type is not None |
|
|
    if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
        # handle the single-document case
        sources = [sources]
    if not chunk:
        for x in sources:
            x.metadata.update(dict(chunk_id=0))
        if db_type in ['chroma', 'chroma_old']:
            # copy so the originals can be modified independently below
            source_chunks = [Document(page_content=x.page_content,
                                      metadata=copy.deepcopy(x.metadata) or {})
                             for x in sources]
        else:
            source_chunks = sources
| else: |
        if language and False:
            # language-aware splitting is deliberately disabled via `and False`;
            # the branch is kept for reference
            keep_separator = True
            separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
        else:
            separators = ["\n\n", "\n", " ", ""]
            keep_separator = False
        splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0,
                                                  keep_separator=keep_separator,
                                                  separators=separators)
        source_chunks = splitter.split_documents(sources)
|
|
| |
    # chunks are generated in order, but ordering is lost once pulled from a db,
    # so record position as chunk_id in metadata
    for chunk_id, x in enumerate(source_chunks):
        x.metadata.update(dict(chunk_id=chunk_id))
|
|
    if db_type in ['chroma', 'chroma_old']:
        # for chroma, also keep the original unsplit documents (useful for
        # summarization and similar tasks), marked with chunk_id=-1 so they can
        # be distinguished from real chunks; this relies on split_documents()
        # having produced new Document objects with copied metadata
        for x in sources:
            x.metadata.update(dict(chunk_id=-1))
        # sources may be a generator, so materialize it before concatenating
        return list(sources) + source_chunks
    else:
        return source_chunks
|
|
|
|
| def add_parser(docs1, parser): |
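    """Tag each document with the parser that produced it, preserving any existing value."""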
    for x in docs1:
        x.metadata.update(dict(parser=x.metadata.get('parser', parser)))
|
|
|
|
| def _add_meta(docs1, file, headsize=50, filei=0, parser='NotSet'): |
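    """Attach provenance metadata to each document: input type, content hashes,
    timestamps, per-load order id, and the head of the page content."""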
    if os.path.isfile(file):
        file_extension = pathlib.Path(file).suffix
        hashid = hash_file(file)
    else:
        # not a file on disk (e.g. a URL or raw text), so hash the string itself
        file_extension = str(file)
        hashid = get_sha(file)
    doc_hash = str(uuid.uuid4())[:10]
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    for order_id, x in enumerate(docs1):
        x.metadata.update(dict(input_type=file_extension,
                               parser=x.metadata.get('parser', parser),
                               date=str(datetime.now()),
                               time=time.time(),
                               order_id=order_id,
                               hashid=hashid,
                               doc_hash=doc_hash,
                               file_id=filei,
                               head=x.page_content[:headsize].strip()))
|
|
|
|
| def fix_json_meta(docs1): |
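    """Coerce JSON-chat metadata to strings; chroma-style stores reject None values."""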
    if not isinstance(docs1, (list, tuple, types.GeneratorType)):
        docs1 = [docs1]
    for x in docs1:
        x.metadata.update(dict(sender_name=x.metadata.get('sender_name') or ''))
        x.metadata.update(dict(timestamp_ms=x.metadata.get('timestamp_ms') or ''))
|
|
|
|
| class H2OMapReduceDocumentsChain(MapReduceDocumentsChain): |
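    """MapReduceDocumentsChain variant that stops after the map step.

    The per-document map outputs are returned directly as a list of strings
    instead of being reduced into a single combined result.
    """
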
| def combine_docs( |
| self, |
| docs: List[Document], |
| token_max: Optional[int] = None, |
| callbacks: Callbacks = None, |
| **kwargs: Any, |
| ) -> Tuple[List, dict]: |
| """Combine documents in a map reduce manner. |
| |
| Combine by mapping first chain over all documents, then reducing the results. |
| This reducing can be done recursively if needed (if there are many documents). |
| """ |
        map_results = self.llm_chain.apply(
            # apply the map prompt to each document's content
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
| question_result_key = self.llm_chain.output_key |
        result_docs = [
            # pair each mapped output with the metadata of its source document
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            for i, r in enumerate(map_results)
        ]
| extra_return_dict = {} |
| if self.return_intermediate_steps: |
| intermediate_steps = [r[question_result_key] for r in map_results] |
| extra_return_dict["intermediate_steps"] = intermediate_steps |
| result_docs_content = [x.page_content for x in result_docs] |
| return result_docs_content, extra_return_dict |
|
|
| async def acombine_docs( |
| self, |
| docs: List[Document], |
| token_max: Optional[int] = None, |
| callbacks: Callbacks = None, |
| **kwargs: Any, |
| ) -> Tuple[List, dict]: |
| """Combine documents in a map reduce manner. |
| |
| Combine by mapping first chain over all documents, then reducing the results. |
| This reducing can be done recursively if needed (if there are many documents). |
| """ |
        map_results = await self.llm_chain.aapply(
            # apply the map prompt to each document's content
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
| question_result_key = self.llm_chain.output_key |
        result_docs = [
            # pair each mapped output with the metadata of its source document
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            for i, r in enumerate(map_results)
        ]
| extra_return_dict = {} |
| if self.return_intermediate_steps: |
| intermediate_steps = [r[question_result_key] for r in map_results] |
| extra_return_dict["intermediate_steps"] = intermediate_steps |
| result_docs_content = [x.page_content for x in result_docs] |
| return result_docs_content, extra_return_dict |
|
|
| @property |
| def _chain_type(self) -> str: |
| return "map_documents_chain" |
|
|
|
|
| def _load_map_chain( |
| llm: BaseLanguageModel, |
| map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, |
| combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT, |
| combine_document_variable_name: str = "text", |
| map_reduce_document_variable_name: str = "text", |
| collapse_prompt: Optional[BasePromptTemplate] = None, |
| reduce_llm: Optional[BaseLanguageModel] = None, |
| collapse_llm: Optional[BaseLanguageModel] = None, |
| verbose: Optional[bool] = None, |
| token_max: int = 3000, |
| callbacks: Callbacks = None, |
| **kwargs: Any, |
| ) -> H2OMapReduceDocumentsChain: |
| map_chain = LLMChain( |
| llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks |
| ) |
| _reduce_llm = reduce_llm or llm |
| reduce_chain = LLMChain( |
| llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks |
| ) |
| |
    # MapReduceDocumentsChain requires a reduce-side chain, so build one even
    # though H2OMapReduceDocumentsChain returns the map results directly
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,
        callbacks=callbacks,
    )
| if collapse_prompt is None: |
| collapse_chain = None |
| if collapse_llm is not None: |
| raise ValueError( |
| "collapse_llm provided, but collapse_prompt was not: please " |
| "provide one or stop providing collapse_llm." |
| ) |
| else: |
| _collapse_llm = collapse_llm or llm |
| collapse_chain = StuffDocumentsChain( |
| llm_chain=LLMChain( |
| llm=_collapse_llm, |
| prompt=collapse_prompt, |
| verbose=verbose, |
| callbacks=callbacks, |
| ), |
| document_variable_name=combine_document_variable_name, |
| ) |
| reduce_documents_chain = ReduceDocumentsChain( |
| combine_documents_chain=combine_documents_chain, |
| collapse_documents_chain=collapse_chain, |
| token_max=token_max, |
| verbose=verbose, |
| callbacks=callbacks, |
| ) |
| return H2OMapReduceDocumentsChain( |
| llm_chain=map_chain, |
| reduce_documents_chain=reduce_documents_chain, |
| document_variable_name=map_reduce_document_variable_name, |
| verbose=verbose, |
| callbacks=callbacks, |
| **kwargs, |
| ) |
|
|
|
|
| def load_general_summarization_chain( |
| llm: BaseLanguageModel, |
| chain_type: str = "stuff", |
| verbose: Optional[bool] = None, |
| **kwargs: Any, |
| ) -> BaseCombineDocumentsChain: |
| """Load summarizing chain. |
| |
| Args: |
| llm: Language Model to use in the chain. |
        chain_type: Type of document combining chain to use. Should be one of
            "stuff", "map_reduce", "refine", and "map".
| verbose: Whether chains should be run in verbose mode or not. Note that this |
| applies to all chains that make up the final chain. |
| |
| Returns: |
| A chain to use for summarizing. |
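
    Example (a minimal sketch; ``llm`` is whatever LangChain-compatible model the
    caller already constructed, and ``docs`` a list of Document objects):

    .. code-block:: python

        chain = load_general_summarization_chain(llm, chain_type="map")
        texts, extra = chain.combine_docs(docs)  # "map" returns per-document strings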
| """ |
| loader_mapping: Mapping[str, LoadingCallable] = { |
| "stuff": _load_stuff_chain, |
| "map_reduce": _load_map_reduce_chain, |
| "refine": _load_refine_chain, |
| "map": _load_map_chain, |
| } |
| if chain_type not in loader_mapping: |
| raise ValueError( |
| f"Got unsupported chain type: {chain_type}. " |
| f"Should be one of {loader_mapping.keys()}" |
| ) |
| return loader_mapping[chain_type](llm, verbose=verbose, **kwargs) |
|
|
|
|
| """Utils for interacting with the Semantic Scholar API.""" |
import logging
|
|
| from langchain_core.pydantic_v1 import BaseModel, root_validator |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| class H2OSemanticScholarAPIWrapper(BaseModel): |
| """Wrapper around semanticscholar.org API. |
| https://github.com/danielnsilva/semanticscholar |
| |
    You should have this library installed:

    `pip install semanticscholar`

    The Semantic Scholar API can run searches and fetch document metadata
    such as title, abstract, and authors.

    Attributes:
        top_k_results: number of top-scored documents returned by the tool
        load_max_docs: maximum number of documents fetched per query
| |
| Example: |
| .. code-block:: python |
| |
            ss = H2OSemanticScholarAPIWrapper(
                top_k_results=3,
                load_max_docs=3,
            )
            ss.run("biases in large language models")
| """ |
|
|
| semanticscholar_search: Any |
| top_k_results: int = 5 |
| S2_MAX_QUERY_LENGTH: int = 300 |
| load_max_docs: int = 100 |
| doc_content_chars_max: Optional[int] = 4000 |
    returned_fields: List[str] = [
| "title", |
| "abstract", |
| "venue", |
| "year", |
| "paperId", |
| "citationCount", |
| "openAccessPdf", |
| "authors", |
| "externalIds", |
| ] |
|
|
| @root_validator() |
| def validate_environment(cls, values: Dict) -> Dict: |
| """Validate that the python package exists in environment.""" |
| try: |
| from semanticscholar import SemanticScholar |
|
|
            # S2_API_KEY is optional and read from the environment if set
            sch = SemanticScholar(api_key=os.getenv('S2_API_KEY'))
            values["semanticscholar_search"] = sch.search_paper
| except ImportError: |
| raise ImportError( |
| "Could not import Semanticscholar python package. " |
| "Please install it with `pip install semanticscholar`." |
| ) |
| return values |
|
|
| def run(self, query: str) -> str: |
| """Run the Semantic Scholar API.""" |
| results = self.semanticscholar_search( |
| query, limit=self.load_max_docs, fields=self.returned_fields |
| ) |
| documents = [] |
| for item in results[: self.top_k_results]: |
| authors = ", ".join( |
| author["name"] for author in getattr(item, "authors", []) |
| ) |
| documents.append( |
| f"Published year: {getattr(item, 'year', None)}\n" |
| f"Title: {getattr(item, 'title', None)}\n" |
| f"Authors: {authors}\n" |
| f"Astract: {getattr(item, 'abstract', None)}\n" |
| ) |
|
|
| if documents: |
| return "\n\n".join(documents)[: self.doc_content_chars_max] |
| else: |
| return "No results found." |
|
|