| import re |
|
|
| import chainlit as cl |
| import tiktoken |
| from langchain.callbacks.base import BaseCallbackHandler |
|
|
|
|
def format_docs(documents, max_context_size=100000, separator="\n\n", encoder=None):
    """Concatenate retrieved documents into a single context string for an LLM.

    Each document is rendered as::

        Article: <title>
        <page_content>
        Source: <link><separator>

    Documents are taken in the order given. A document is added only while the
    running token count is still below ``max_context_size``. The document that
    reaches or crosses the budget is still included, so the result may exceed
    the budget by up to one document.

    Args:
        documents: Iterable of objects exposing ``page_content`` (a str) and a
            ``metadata`` mapping with ``"title"`` and ``"link"`` keys.
        max_context_size: Token budget checked before each document is added.
        separator: String appended after every rendered document.
        encoder: Object whose ``encode(str)`` returns a sized sequence of tokens.
            Defaults to tiktoken's ``cl100k_base`` encoding.

    Returns:
        The formatted context string ("" when no document is admitted).

    Raises:
        KeyError: If an admitted document's metadata lacks "link" or "title".
    """
    if encoder is None:
        encoder = tiktoken.get_encoding("cl100k_base")

    parts = []
    token_count = 0
    for doc in documents:
        # Once the budget is reached nothing else can be admitted, so stop
        # instead of walking the rest of the documents.
        if token_count >= max_context_size:
            break
        source = doc.metadata["link"]
        title = doc.metadata["title"]
        chunk = f"Article: {title}\n" + doc.page_content + f"\nSource: {source}" + separator
        parts.append(chunk)
        # Count each chunk once rather than re-encoding the whole accumulated
        # context every iteration (quadratic). The sum can differ marginally
        # from encoding the joined string when tokens merge across chunk edges.
        token_count += len(encoder.encode(chunk))
    return "".join(parts)
|
|
|
|
class PostMessageHandler(BaseCallbackHandler):
    """
    LangChain callback handler that attaches retrieved documents to a Chainlit message.

    Every document returned by a retriever is recorded as
    ``page_content + "\\nSource: " + metadata["link"]``. When an LLM run ends,
    the sources recorded since the last post are added to ``msg`` as
    ``cl.Text`` elements named ``source_1``, ``source_2``, ... Numbering
    continues across posts, so names stay unique. A ``Sources: ...`` line
    listing the new element names is appended to the message content.
    """

    def __init__(self, msg: cl.Message):
        super().__init__()
        # Message that receives the source elements; presumably sent/updated by
        # the caller once the chain finishes — confirm against the caller.
        self.msg = msg
        # Every source recorded so far, in retrieval order.
        self.sources = []
        # How many entries of self.sources have already been attached to msg.
        self._num_posted = 0

    def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
        """Record each retrieved document's content and link as a source."""
        for doc in documents:
            self.sources.append(doc.page_content + "\nSource: " + doc.metadata["link"])

    def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
        """Attach sources not yet posted to the message.

        This fires once per LLM call. Tracking what has already been posted
        prevents the same sources from being attached again (with duplicate
        names and repeated "Sources:" lines) when the chain makes several LLM
        calls after retrieval.
        """
        pending = self.sources[self._num_posted:]
        if not pending:
            return
        sources_element = [
            cl.Text(name=f"source_{idx}", content=content)
            for idx, content in enumerate(pending, start=self._num_posted + 1)
        ]
        self.msg.elements += sources_element
        self.msg.content += f"\nSources: {', '.join(el.name for el in sources_element)}"
        self._num_posted = len(self.sources)
|
|
def clean_text(text):
    """Normalize scraped article text.

    Applied in order:
      1. delete every "Tweet"/"tweet" occurrence, including inside longer
         words (e.g. "retweeted" becomes "reed");
      2. collapse runs of spaces into a single space;
      3. collapse runs of newlines into a single newline;
    then strip leading and trailing whitespace.
    """
    substitutions = (
        ("[Tt]weet", ""),
        (r"\ +", " "),
        (r"\n+", "\n"),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
|
|