| import os |
| import chromadb |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain_huggingface.embeddings import HuggingFaceEmbeddings |
| from langchain_chroma import Chroma |
|
|
| |
# Hugging Face chat model used to generate answers.
REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
# Chroma collection that stores the video-analysis documents.
COLLECTION_NAME = "video_analysis_data"
# Filesystem location of the persistent Chroma database.
DB_PATH = "./chroma_db"
|
|
def run_query(user_query, k=5):
    """Answer a question about the analyzed video via a RAG chain.

    Retrieves the ``k`` most relevant documents from the persisted Chroma
    collection and feeds them, together with the question, to the
    Hugging Face chat model identified by ``REPO_ID``.

    Args:
        user_query: Natural-language question about the video content.
        k: Number of context documents to retrieve (default 5).

    Returns:
        The model's answer as a string, or an error message if the
        ``HUGGINGFACEHUB_API_TOKEN`` environment variable is not set.
    """
    # Read the token once; reused below for the endpoint call.
    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not hf_token:
        return "Error: Please set HUGGINGFACEHUB_API_TOKEN in your environment variables."

    # Open the persisted vector store built by the ingestion step.
    client = chromadb.PersistentClient(path=DB_PATH)
    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    vectorstore = Chroma(
        client=client,
        collection_name=COLLECTION_NAME,
        embedding_function=embedding_function,
    )

    retriever = vectorstore.as_retriever(search_kwargs={"k": k})

    # Remote text-generation endpoint wrapped as a chat model.
    llm_endpoint = HuggingFaceEndpoint(
        repo_id=REPO_ID,
        task="text-generation",
        max_new_tokens=512,
        repetition_penalty=1.1,
        huggingfacehub_api_token=hf_token,
    )

    llm = ChatHuggingFace(llm=llm_endpoint)

    template = """
    You are an expert Video Content Analyst. Use the Context to answer the Question.
    If you don't know the answer, say you don't know.
    Infer activity based on detected objects (e.g., people + skateboards = skateboarding).

    Context:
    {context}

    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)

    # Standard LCEL RAG pipeline: retrieve context, fill the prompt, query the LLM.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
    )

    response = rag_chain.invoke(user_query)

    return response.content
|
|
| |
if __name__ == '__main__':
    # Demo questions exercised against the indexed video data.
    demo_queries = [
        "What kind of objects were frequently detected in the video?",
        "What activity was detected around the 15-second mark in the video?",
    ]

    for idx, question in enumerate(demo_queries, start=1):
        # Visual separator between consecutive query/answer sections.
        if idx > 1:
            print("\n" + "=" * 50 + "\n")

        answer = run_query(question)
        print(f"\n--- QUERY {idx} ---")
        print(f"Question: {question}")
        print(f"Answer:\n{answer}")
|