Spaces:
Sleeping
Sleeping
| from langchain_core.prompts import PromptTemplate | |
| from langchain_chroma import Chroma | |
| from langchain_community.llms import Ollama | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| import sys | |
| def load_vectorstore(): | |
| embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") | |
| vectorstore = Chroma( | |
| persist_directory="chroma_db", | |
| embedding_function=embeddings, | |
| collection_name="rag_code_assistant" | |
| ) | |
| return vectorstore | |
| def load_prompt(): | |
| template = ( | |
| "You are an expert Python coding assistant. Use the following documentation " | |
| "excerpts to answer the user's question accurately. \n\n" | |
| "IMPORTANT INSTRUCTIONS:\n" | |
| "- Prioritize standard Python concepts over C-extensions or advanced typing unless specifically asked.\n" | |
| "- If multiple types of answers are in the context, synthesize them into a complete answer.\n" | |
| "- If the answer is not in the context, say 'I don't know'.\n\n" | |
| "Context:\n" | |
| "{context}\n\n" | |
| "Question: {question}\n\n" | |
| "Answer:" | |
| ) | |
| prompt = PromptTemplate( | |
| input_variables=["context", "question"], | |
| template=template | |
| ) | |
| return prompt | |
| def load_llm(): | |
| llm = Ollama(model="llama3", temperature=0.1) | |
| return llm | |
| def format_docs(docs): | |
| return "\n\n".join(doc.page_content for doc in docs) | |
| def load_retriever(vectorstore): | |
| retriever = vectorstore.as_retriever( | |
| search_type="mmr", # Use MMR instead of standard similarity | |
| search_kwargs={ | |
| "k": 4, # Return 4 diverse chunks to the LLM | |
| "fetch_k": 20 # Search top 20, then pick the 4 most diverse | |
| } | |
| ) | |
| return retriever | |
| def load_rag_chain(retriever, prompt, llm): | |
| rag_chain = ( | |
| {"context": retriever | format_docs, "question": RunnablePassthrough()} | |
| | prompt | |
| | llm | |
| | StrOutputParser() | |
| ) | |
| return rag_chain | |
| def load_answer(rag_chain, retriever, question): | |
| print("=== ASSISTANT'S ANSWER ===") | |
| # We use .stream() instead of .invoke() for the RAG chain | |
| # This yields words one by one as they are generated | |
| full_answer = "" | |
| for chunk in rag_chain.stream(question): | |
| print(chunk, end="", flush=True) # Print word immediately, without newlines | |
| full_answer += chunk | |
| print("\n") # Add a final newline when done | |
| # We still use .invoke() for the retriever to get the source documents | |
| # (Database retrieval happens instantly, no need to stream it) | |
| source_docs = retriever.invoke(question) | |
| return full_answer, source_docs | |
| vectorstore = load_vectorstore() | |
| prompt = load_prompt() | |
| llm = load_llm() | |
| retriever = load_retriever(vectorstore) | |
| rag_chain = load_rag_chain(retriever, prompt, llm) | |
| if __name__ == "__main__": | |
| vectorstore = load_vectorstore() | |
| prompt = load_prompt() | |
| llm = load_llm() | |
| retriever = load_retriever(vectorstore) | |
| rag_chain = load_rag_chain(retriever, prompt, llm) | |
| question = sys.argv[1] if len(sys.argv) > 1 else "What is Python?" | |
| print(f"\nUser Question: {question}\nThinking...\n") | |
| # load_answer now handles the streaming print | |
| answer, source_docs = load_answer(rag_chain, retriever, question) | |
| print("\n=== SOURCES USED ===") | |
| for idx, doc in enumerate(source_docs): | |
| source_file = doc.metadata.get("source", "Unknown file") | |
| print(f"{idx + 1}. {source_file}") | |