| |
| |
import functools
import os
from typing import List

import chainlit as cl
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.vectorstores import Qdrant
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient
# Populate os.environ from a local .env file (if one exists) before the
# lookups below read from it.
load_dotenv()

# Credentials and endpoints for the hosted services; each is None when the
# corresponding environment variable is not set.
groq_api_key = os.environ.get("GROQ_API_KEY")
qdrant_url = os.environ.get("QDRANT_URL")
qdrant_api_key = os.environ.get("QDRANT_API_KEY")
|
|
# QA prompt for the "stuff" retrieval chain: {context} receives the retrieved
# text and {question} the user's query.
custom_prompt_template = (
    "Use the following pieces of information to answer the user's question.\n"
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
    "\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
|
|
def set_custom_prompt():
    """Return the PromptTemplate used for retrieval QA.

    The template text is ``custom_prompt_template``; it declares two input
    variables, ``context`` and ``question``.
    """
    variables = ["context", "question"]
    return PromptTemplate(template=custom_prompt_template, input_variables=variables)
|
|
|
|
# Shared chat model for every session (Groq-hosted Mixtral, temperature 0).
# The key read from GROQ_API_KEY is passed explicitly so the dependency on
# that variable is visible here rather than implicit.
# NOTE(review): Groq has announced deprecation of the "mixtral-8x7b-32768"
# model id — confirm it is still served before deploying.
chat_model = ChatGroq(
    temperature=0,
    model_name="mixtral-8x7b-32768",
    groq_api_key=groq_api_key,
)

# Qdrant client shared by all sessions; URL and API key come from the environment.
client = QdrantClient(api_key=qdrant_api_key, url=qdrant_url)
|
|
|
|
def retrieval_qa_chain(llm, prompt, vectorstore, k=2):
    """Build a "stuff"-type RetrievalQA chain over ``vectorstore``.

    Args:
        llm: Chat model that generates the answer.
        prompt: PromptTemplate with ``context`` and ``question`` variables,
            passed to the underlying combine-documents chain.
        vectorstore: LangChain vector store used as the retriever.
        k: Number of documents retrieved per query (defaults to 2, the
            previously hard-coded value).

    Returns:
        A RetrievalQA chain configured to also return its source documents.
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k": k})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        # Callers read res["source_documents"] to show citations.
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
    )
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_embeddings():
    """Return the process-wide FastEmbed embeddings object, created on first call."""
    return FastEmbedEmbeddings()


def qa_bot():
    """Build the retrieval QA chain used by a chat session.

    Combines the shared ``chat_model``, the custom QA prompt and a Qdrant
    vector store over the "rag" collection into a RetrievalQA chain.

    Returns:
        The chain produced by ``retrieval_qa_chain``.
    """
    # One embeddings object is reused across sessions instead of building a
    # new one per chat; FastEmbed presumably loads its model at construction
    # time (verify against the installed version), so this avoids repeated loads.
    vectorstore = Qdrant(
        client=client,
        embeddings=_get_embeddings(),
        collection_name="rag",
    )
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(chat_model, qa_prompt, vectorstore)
|
|
|
|
@cl.set_chat_profiles
async def chat_profile():
    """Declare the chat profiles offered in the Chainlit UI.

    Exactly one profile, "Virtual Tutor", is returned.
    """
    # NOTE(review): this is a google.com/url redirect link, not a direct image
    # URL, so the icon may not render — replace with a direct image link.
    tutor_icon = "https://www.google.com/url?sa=i&url=https%3A%2F%2Fwww.gptshunter.com%2Fgpt-store%2FNTQyNjE4MGMyMzU1MTcyNjU4&psig=AOvVaw3dz6CEyBeDM9iyj8gcEwNI&ust=1711780055423000&source=images&cd=vfe&opi=89978449&ved=0CBAQjRxqFwoTCNjprerrmIUDFQAAAAAdAAAAABAK"
    tutor_profile = cl.ChatProfile(
        name="Virtual Tutor",
        markdown_description="The underlying LLM model is **Mixtral**.",
        icon=tutor_icon,
    )
    return [tutor_profile]
|
|
@cl.on_chat_start
async def start():
    """Initialise a new chat session.

    Registers the "Tool 1" avatar, shows a "Starting the bot..." placeholder,
    builds the retrieval QA chain, stores it in the user session under
    "chain", and finally replaces the placeholder with the welcome text.
    """
    # NOTE(review): cl.Avatar is not available in later Chainlit releases —
    # confirm against the pinned Chainlit version.
    await cl.Avatar(
        name="Tool 1",
        url="https://www.google.com/url?sa=i&url=https%3A%2F%2Fwww.gptshunter.com%2Fgpt-store%2FNTQyNjE4MGMyMzU1MTcyNjU4&psig=AOvVaw3dz6CEyBeDM9iyj8gcEwNI&ust=1711780055423000&source=images&cd=vfe&opi=89978449&ved=0CBAQjRxqFwoTCNjprerrmIUDFQAAAAAdAAAAABAK",
    ).send()

    # Send the placeholder *before* the slow setup so the user actually sees it
    # while the chain is being built.
    welcome_message = cl.Message(content="Starting the bot...")
    await welcome_message.send()

    # qa_bot() is synchronous; run it in a worker thread so it does not block
    # the event loop (and other sessions) while it builds the chain.
    chain = await cl.make_async(qa_bot)()
    cl.user_session.set("chain", chain)

    welcome_message.content = "Welcome to Virtual Tutor."
    await welcome_message.update()
|
|
|
|
@cl.on_message
async def main(message: cl.Message):
    """Answer one user message with the session's retrieval QA chain.

    Fetches the chain stored by the chat-start handler, runs it
    asynchronously with a Chainlit LangChain callback handler, then sends the
    answer with every retrieved chunk attached as a named text element. The
    answer text ends with either the list of source names or
    "No sources found" when nothing was retrieved.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()
    # Kept from the original; its exact effect (streaming of the final
    # answer) is defined by Chainlit's handler — confirm for your version.
    cb.answer_reached = True

    # NOTE(review): Chain.acall is deprecated in newer LangChain releases in
    # favour of ainvoke — consider migrating.
    res = await chain.acall(message.content, callbacks=[cb])

    answer = res["result"]
    source_documents = res.get("source_documents") or []

    # One named text element per retrieved chunk; the names are listed in
    # the answer so the UI can link them.
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents)
    ]

    # Decide on the elements themselves so the "no sources" branch is
    # reachable and no name is left undefined when retrieval returns nothing.
    if text_elements:
        source_names = ", ".join(text_el.name for text_el in text_elements)
        answer += f"\nSources: {source_names}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements, author="Tool 1").send()