import hashlib
import io
import mimetypes
import os
import shutil
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union
from urllib.parse import quote, quote_plus

import jwt
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, File, UploadFile, Form, Response, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
from passlib.context import CryptContext
from pydantic import BaseModel, Field, EmailStr
from pymongo import MongoClient
from pymongo.errors import DuplicateKeyError

# Kept ahead of the remaining imports, as in the original ordering, in case
# any of them read environment variables at import time.
load_dotenv()

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from google import genai
from google.genai import types
def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set, so a misconfigured
            deployment fails with the variable's name instead of an opaque
            TypeError/AttributeError further down.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value


# --- MongoDB configuration ---------------------------------------------------
# The password is URL-escaped before being substituted into the connection
# string template at its "${PASSWORD}" placeholder.
MONGO_PASSWORD = quote_plus(_require_env("MONGO_PASSWORD"))
MONGO_DATABASE_NAME = _require_env("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
connection_string_template = _require_env("CONNECTION_STRING")
MONGO_CLUSTER_URL = connection_string_template.replace("${PASSWORD}", MONGO_PASSWORD)

# Collection names; the chat collection falls back to "chat_history" when
# COLLECTION_NAME is unset or empty.
CHAT_COLLECTION = MONGO_COLLECTION_NAME or "chat_history"
USER_COLLECTION = "users"
VIDEO_COLLECTION = "videos"
# --- Authentication configuration --------------------------------------------
# Signing key and algorithm used by create_access_token / get_current_user.
# NOTE(review): SECRET_KEY is None when the variable is unset; nothing in this
# file validates it before it is handed to jwt.encode / jwt.decode.
SECRET_KEY = os.environ.get("SECRET_KEY")
ALGORITHM = "HS256"
# Default token lifetime, in minutes.
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# bcrypt-based hashing context used for storing and checking passwords.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Bearer-token extractor; "token" is the relative URL of the login endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# --- Application object and middleware ----------------------------------------
app = FastAPI(
    title="RAG System API",
    description="An API for question answering based on video content with user authentication",
)

# NOTE(review): every origin, method and header is allowed while credentials
# are enabled; confirm this is intended outside of development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class TranscriptionRequest(BaseModel):
    """Request body for ``POST /transcribe``."""

    # URL passed to the Gemini request as the file URI of the video.
    youtube_url: str
class QueryRequest(BaseModel):
    """Request body for ``POST /query``."""

    # The user's question.
    query: str
    # Session id returned by /transcribe or /upload.
    session_id: str
class QueryResponse(BaseModel):
    """Response body of ``POST /query``."""

    answer: str
    session_id: str
    # Short text previews of the retrieved chunks, when available.
    source_documents: Optional[List[str]] = None
class User(BaseModel):
    """Public user profile (no credential fields)."""

    username: str
    email: EmailStr
    full_name: Optional[str] = None
class UserInDB(User):
    """User record as loaded from the users collection, including the password hash."""

    hashed_password: str
class UserCreate(User):
    """Registration payload: the public profile plus a plain-text password."""

    # Hashed by the /register handler before anything is stored.
    password: str
class Token(BaseModel):
    """Response body of ``POST /token``."""

    access_token: str
    token_type: str
class TokenData(BaseModel):
    """Claims extracted from a decoded access token."""

    # Taken from the token's "sub" claim.
    username: Optional[str] = None
class VideoData(BaseModel):
    """Schema of a video/session record.

    The field names match the keys of the document that
    ``process_transcription`` inserts into the videos collection.
    """

    video_id: str
    user_id: str
    title: str
    source_type: str
    source_url: Optional[str] = None
    # Naive UTC timestamp generated when the model is instantiated.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    transcription: str
    size: Optional[int] = None
class MongoDB:
    """Holds the Mongo client plus handles to the users and videos collections."""

    def __init__(self):
        """Connect using the module-level settings and create the indexes."""
        self.client = MongoClient(MONGO_CLUSTER_URL)
        self.db = self.client[MONGO_DATABASE_NAME]
        self.users = self.db[USER_COLLECTION]
        self.videos = self.db[VIDEO_COLLECTION]

        # Usernames, e-mails and video ids must be unique; user_id gets a
        # plain index.
        for unique_field in ("username", "email"):
            self.users.create_index(unique_field, unique=True)
        self.videos.create_index("video_id", unique=True)
        self.videos.create_index("user_id")

    def close(self):
        """Close the underlying client."""
        self.client.close()
class ChatManagement:
    """In-process registry of MongoDB-backed chat histories, keyed by chat id."""

    def __init__(self, cluster_url, database_name, collection_name):
        """Store the connection settings; no connection is opened here."""
        self.connection_string = cluster_url
        self.database_name = database_name
        self.collection_name = collection_name
        # chat_id -> MongoDBChatMessageHistory
        self.chat_sessions = {}

    def _open_history(self, chat_id):
        """Build a history object for *chat_id* from the stored settings."""
        return MongoDBChatMessageHistory(
            session_id=chat_id,
            connection_string=self.connection_string,
            database_name=self.database_name,
            collection_name=self.collection_name,
        )

    def create_new_chat(self):
        """Create a history under a fresh UUID, cache it, and return the id."""
        chat_id = str(uuid.uuid4())
        self.chat_sessions[chat_id] = self._open_history(chat_id)
        return chat_id

    def get_chat_history(self, chat_id):
        """Return the history for *chat_id*, or None when it has no messages.

        A cached history is returned as-is. Otherwise one is opened and only
        cached/returned if it already contains messages.
        """
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        history = self._open_history(chat_id)
        if history.messages:
            self.chat_sessions[chat_id] = history
            return history
        return None

    def initialize_chat_history(self, chat_id):
        """Return the cached history for *chat_id*, opening and caching one if needed."""
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        history = self._open_history(chat_id)
        self.chat_sessions[chat_id] = history
        return history
# --- Module-level state --------------------------------------------------------
# Shared database handle and chat-history registry.
mongodb = MongoDB()
chat_manager = ChatManagement(
    MONGO_CLUSTER_URL,
    MONGO_DATABASE_NAME,
    CHAT_COLLECTION,
)
# session_id -> {"retriever": ..., "chat_history": ...}; kept in memory only.
sessions = dict()

# Directory where uploaded video files are written.
VIDEOS_DIR = "temp_videos"
os.makedirs(VIDEOS_DIR, exist_ok=True)
def verify_password(plain_password, hashed_password):
    """Return whether *plain_password* matches *hashed_password* per ``pwd_context``."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def get_password_hash(password):
    """Return the ``pwd_context`` hash of *password*."""
    digest = pwd_context.hash(password)
    return digest
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Encode *data* as a signed JWT with an ``exp`` claim.

    Args:
        data: Claims to embed; the dict itself is not modified.
        expires_delta: Token lifetime. A missing (or zero) value falls back
            to ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` and ``ALGORITHM``.
    """
    lifetime = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {**data, "exp": datetime.utcnow() + lifetime}
    return jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM)
def get_user(username: str):
    """Look up *username* in the users collection.

    Returns:
        A ``UserInDB`` built from the stored document, or None when no
        document matches.
    """
    record = mongodb.users.find_one({"username": username})
    return UserInDB(**record) if record else None
def authenticate_user(username: str, password: str):
    """Return the stored user when the credentials are valid, otherwise False."""
    account = get_user(username)
    if account and verify_password(password, account.hashed_password):
        return account
    return False
async def get_current_user(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency resolving the bearer token to a stored user.

    Raises:
        HTTPException: 401 when the token cannot be decoded, has no ``sub``
            claim, or names a user that does not exist.
    """
    unauthorized = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        raise unauthorized
    token_data = TokenData(username=subject)

    account = get_user(username=token_data.username)
    if account is None:
        raise unauthorized
    return account
def init_google_client():
    """Create a ``genai.Client`` from the ``GOOGLE_API_KEY`` environment variable.

    Raises:
        ValueError: if the variable is unset or empty.
    """
    google_key = os.getenv("GOOGLE_API_KEY", "")
    if google_key:
        return genai.Client(api_key=google_key)
    raise ValueError("GOOGLE_API_KEY environment variable not set")
def get_llm():
    """Build the ChatGroq client used for answering questions.

    The model is ``llama-3.3-70b-versatile`` with temperature 0 and a
    1024-token output cap; the key is read from ``CHATGROQ_API_KEY``.

    Raises:
        ValueError: if ``CHATGROQ_API_KEY`` is unset or empty.
    """
    groq_key = os.getenv("CHATGROQ_API_KEY", "")
    if not groq_key:
        raise ValueError("CHATGROQ_API_KEY environment variable not set")

    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=groq_key,
    )
# Embedding model shared by all calls; created on first use.
_EMBEDDINGS = None


def get_embeddings():
    """Return the ``BAAI/bge-small-en`` embedding model (CPU, normalized vectors).

    The instance is created on the first call and reused afterwards, so the
    model is not re-instantiated for every transcription that gets indexed.
    """
    global _EMBEDDINGS
    if _EMBEDDINGS is None:
        _EMBEDDINGS = HuggingFaceEmbeddings(
            model_name="BAAI/bge-small-en",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
    return _EMBEDDINGS
# System prompt for the answer-generation step. {context} receives the
# retrieved chunks and {question} the user's question.
quiz_solving_prompt = '''
You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.
Use the following retrieved context to answer the user's question.
If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.

Guidelines:
1. Extract key information from the context to form a coherent response.
2. Maintain a clear and professional tone.
3. If the question requires clarification, specify it politely.

Retrieved context:
{context}

User's question:
{question}

Your response:
'''

# Chat prompt passed to the combine-documents step of the chain: the system
# instructions above followed by the human question.
_prompt_messages = [
    ("system", quiz_solving_prompt),
    ("human", "{question}"),
]
user_prompt = ChatPromptTemplate.from_messages(_prompt_messages)
def create_chain(retriever):
    """Build a ``ConversationalRetrievalChain`` over *retriever*.

    The chain uses the LLM from ``get_llm``, the "stuff" chain type with
    ``user_prompt`` as the combine-documents prompt, and returns the source
    documents alongside the answer.
    """
    chain_options = {
        "llm": get_llm(),
        "retriever": retriever,
        "return_source_documents": True,
        "chain_type": 'stuff',
        "combine_docs_chain_kwargs": {"prompt": user_prompt},
        "verbose": False,
    }
    return ConversationalRetrievalChain.from_llm(**chain_options)
def process_transcription(transcription, user_id, title, source_type, source_url=None, file_size=None):
    """Index a transcription, persist its video record, and open a session.

    Args:
        transcription: Full transcript text.
        user_id: Owner stored on the video record.
        title: Title stored on the video record.
        source_type: Origin tag stored on the record (callers in this file
            pass "youtube" or "upload").
        source_url: Original URL, if any.
        file_size: Size in bytes of the uploaded file, if any.

    Returns:
        The new session id, which is also stored as the record's ``video_id``.

    Raises:
        ValueError: if *transcription* is None, empty, or only whitespace.
    """
    # Reject empty input up front: there is nothing to split or embed, and
    # the failure further down would be much harder to interpret.
    if not transcription or not transcription.strip():
        raise ValueError("Transcription is empty; nothing to index")

    # Split into overlapping chunks and build an in-memory FAISS index.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=20)
    all_splits = text_splitter.split_text(transcription)
    vectorstore = FAISS.from_texts(all_splits, get_embeddings())
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    session_id = str(uuid.uuid4())

    mongodb.videos.insert_one({
        "video_id": session_id,
        "user_id": user_id,
        "title": title,
        "source_type": source_type,
        "source_url": source_url,
        "created_at": datetime.utcnow(),
        "transcription": transcription,
        "size": file_size,
    })

    # The retriever lives only in this process's memory.
    sessions[session_id] = {
        "retriever": retriever,
        "chat_history": chat_manager.initialize_chat_history(session_id),
    }

    return session_id
def save_video_file(video_id, file_path, contents):
    """Write *contents* to *file_path*, creating parent directories as needed.

    Args:
        video_id: Identifier of the video; not used when writing.
        file_path: Destination path.
        contents: Raw bytes to write.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, "wb") as f:
        f.write(contents)
@app.post("/register", response_model=User)
async def register_user(user: UserCreate):
    """Create a new user account.

    Raises:
        HTTPException: 400 when the username or e-mail is already taken.
    """
    if mongodb.users.find_one({"username": user.username}):
        raise HTTPException(status_code=400, detail="Username already registered")

    if mongodb.users.find_one({"email": user.email}):
        raise HTTPException(status_code=400, detail="Email already registered")

    # Store only the hash, never the plain-text password.
    hashed_password = get_password_hash(user.password)
    user_dict = user.dict()
    del user_dict["password"]
    user_dict["hashed_password"] = hashed_password

    try:
        mongodb.users.insert_one(user_dict)
    except DuplicateKeyError as exc:
        # A concurrent registration can pass both checks above; the unique
        # indexes on username/email then reject the insert.
        raise HTTPException(
            status_code=400, detail="Username or email already registered"
        ) from exc

    return User(**user_dict)
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange a username/password form for a bearer access token.

    Raises:
        HTTPException: 401 when the credentials are rejected.
    """
    account = authenticate_user(form_data.username, form_data.password)
    if not account:
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
@app.post("/transcribe", response_model=Dict[str, str])
async def transcribe_video(
    request: TranscriptionRequest,
    current_user: User = Depends(get_current_user)
):
    """
    Transcribe a YouTube video and prepare the RAG system

    The URL is sent to Gemini together with a transcription instruction; the
    returned text is indexed via ``process_transcription`` under a generated
    title. Any failure is reported as a 500.
    """
    try:
        gemini = init_google_client()

        instruction = types.Part(text='Transcribe the Video. Write all the things described in the video')
        video_part = types.Part(file_data=types.FileData(file_uri=request.youtube_url))
        response = gemini.models.generate_content(
            model='models/gemini-2.0-flash',
            contents=types.Content(parts=[instruction, video_part]),
        )

        # Only the first part of the first candidate is used.
        transcription = response.candidates[0].content.parts[0].text

        timestamp = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
        session_id = process_transcription(
            transcription,
            current_user.username,
            f"YouTube Video - {timestamp}",
            "youtube",
            request.youtube_url,
        )

        return {"session_id": session_id, "message": "YouTube video transcribed and RAG system prepared"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error transcribing video: {str(e)}")
@app.post("/upload", response_model=Dict[str, str])
async def upload_video(
    background_tasks: BackgroundTasks,
    title: str = Form(...),
    file: UploadFile = File(...),
    prompt: str = Form("Transcribe the Video. Write all the things described in the video"),
    current_user: User = Depends(get_current_user)
):
    """
    Upload a video file (max 20MB), transcribe it and prepare the RAG system

    Raises:
        HTTPException: 400 when the file is larger than 20MB or is not a
            video; 500 for any other failure while processing.
    """
    try:
        contents = await file.read()
        file_size = len(contents)
        if file_size > 20 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="File size exceeds 20MB limit")

        # content_type may be missing on the upload; treat that as "not a video".
        content_type = file.content_type or ""
        if not content_type.startswith('video/'):
            raise HTTPException(status_code=400, detail="File must be a video")

        client = init_google_client()

        # The video bytes are sent inline with the prompt.
        response = client.models.generate_content(
            model='models/gemini-2.0-flash',
            contents=types.Content(
                parts=[
                    types.Part(text=prompt),
                    types.Part(
                        inline_data=types.Blob(data=contents, mime_type=content_type)
                    )
                ]
            )
        )

        # Only the first part of the first candidate is used.
        transcription = response.candidates[0].content.parts[0].text

        session_id = process_transcription(
            transcription,
            current_user.username,
            title,
            "upload",
            None,
            file_size
        )

        # Persist the raw file after the response is sent, named after the
        # session id so it can be located again.
        file_extension = os.path.splitext(file.filename or "")[1]
        file_path = os.path.join(VIDEOS_DIR, f"{session_id}{file_extension}")
        background_tasks.add_task(save_video_file, session_id, file_path, contents)

        return {"session_id": session_id, "message": "Uploaded video transcribed and RAG system prepared"}

    except HTTPException:
        # Let the deliberate 4xx responses above through instead of
        # rewrapping them as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing uploaded video: {str(e)}") from e
    finally:
        await file.seek(0)
@app.get("/download/{video_id}")
async def download_video(
    video_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Download a previously uploaded video

    For records whose source type is "youtube" a JSON message with the
    original URL is returned instead of a file.

    Raises:
        HTTPException: 404 when the record or its file is missing; 403 when
            the record belongs to another user.
    """
    video_data = mongodb.videos.find_one({"video_id": video_id})

    if not video_data:
        raise HTTPException(status_code=404, detail="Video not found")

    if video_data["user_id"] != current_user.username:
        raise HTTPException(status_code=403, detail="Not authorized to access this video")

    if video_data["source_type"] == "youtube":
        return {"message": "This is a YouTube video. Please use the original URL to access the video.", "url": video_data["source_url"]}

    # Uploaded files are stored as "<video_id><extension>".
    try:
        video_files = [f for f in os.listdir(VIDEOS_DIR) if f.startswith(video_id)]
    except FileNotFoundError:
        video_files = []

    if not video_files:
        raise HTTPException(status_code=404, detail="Video file not found")

    stored_name = video_files[0]
    file_path = os.path.join(VIDEOS_DIR, stored_name)
    file_extension = os.path.splitext(stored_name)[1]

    # Prefer a registered video MIME type; otherwise fall back to
    # "video/<extension>", or "video/mp4" when there is no extension.
    guessed_type, _ = mimetypes.guess_type(stored_name)
    if guessed_type and guessed_type.startswith("video/"):
        mime_type = guessed_type
    elif file_extension:
        mime_type = f"video/{file_extension[1:]}"
    else:
        mime_type = "video/mp4"

    def iterfile():
        with open(file_path, "rb") as f:
            while chunk := f.read(8192):
                yield chunk

    # The title is user-supplied: percent-encode it so quotes, separators,
    # line breaks or non-latin-1 characters cannot break the header.
    download_name = quote(f"{video_data['title']}{file_extension}", safe="")

    return StreamingResponse(
        iterfile(),
        media_type=mime_type,
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{download_name}"}
    )
def _query_history_pairs(messages):
    """Pair a flat message list into (question, answer) tuples; an unpaired trailing message is dropped."""
    messages = messages or []
    return [
        (messages[i].content, messages[i + 1].content)
        for i in range(0, len(messages) - 1, 2)
    ]


def _query_preview(text, limit=100):
    """Return *text* cut to *limit* characters, with "..." appended when it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _query_source_previews(documents):
    """Build short previews from source documents given as objects, dicts or strings."""
    previews = []
    for doc in documents or []:
        try:
            if hasattr(doc, 'page_content'):
                previews.append(_query_preview(doc.page_content))
            elif isinstance(doc, dict) and 'page_content' in doc:
                previews.append(_query_preview(doc['page_content']))
            elif isinstance(doc, str):
                previews.append(_query_preview(doc))
        except Exception as doc_error:
            # A malformed document must not fail the whole answer.
            print(f"Error processing source document: {str(doc_error)}")
    return previews


@app.post("/query", response_model=QueryResponse)
async def query_system(
    request: QueryRequest,
    current_user: User = Depends(get_current_user)
):
    """
    Query the RAG system with a question

    Errors raised while invoking the chain are answered with a fallback
    message, which is also recorded in the chat history.

    Raises:
        HTTPException: 404 when the session is not loaded in this process;
            403 when it belongs to another user; 500 for other failures.
    """
    try:
        session_id = request.session_id

        if not session_id or session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found. Please transcribe a video first.")

        video_data = mongodb.videos.find_one({"video_id": session_id})
        if not video_data or video_data["user_id"] != current_user.username:
            raise HTTPException(status_code=403, detail="Not authorized to access this session")

        retriever = sessions[session_id]["retriever"]
        chat_history = chat_manager.initialize_chat_history(session_id)
        chain = create_chain(retriever)

        # The chain receives the history as (question, answer) tuples.
        langchain_chat_history = _query_history_pairs(chat_history.messages)

        print(f"Chat history length: {len(langchain_chat_history)}")
        print(f"Query: {request.query}")

        try:
            result = chain.invoke({
                "question": request.query,
                "chat_history": langchain_chat_history
            })

            answer = result.get("answer", "I couldn't find an answer to your question.")

            chat_history.add_user_message(request.query)
            chat_history.add_ai_message(answer)

            return {
                "answer": answer,
                "session_id": session_id,
                "source_documents": _query_source_previews(result.get("source_documents"))
            }

        except Exception as chain_error:
            print(f"Chain invocation error: {str(chain_error)}")

            fallback_answer = "I apologize, but I encountered an error while processing your question. Please try rephrasing your query or asking about a different topic."

            chat_history.add_user_message(request.query)
            chat_history.add_ai_message(fallback_answer)

            return {
                "answer": fallback_answer,
                "session_id": session_id,
                "source_documents": []
            }

    except HTTPException:
        # Keep the 404/403 raised above instead of rewrapping them as a 500.
        raise
    except Exception as e:
        print(f"Query system error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}") from e
@app.get("/sessions", response_model=List[Dict[str, Any]])
async def get_user_sessions(current_user: User = Depends(get_current_user)):
    """
    Get all video sessions for the current user

    Each entry carries the session id, title, source type, creation time and
    a transcription preview of at most 200 characters.
    """
    def clip(text):
        return text[:200] + "..." if len(text) > 200 else text

    owned_videos = mongodb.videos.find({"user_id": current_user.username})
    return [
        {
            "session_id": video["video_id"],
            "title": video["title"],
            "source_type": video["source_type"],
            "created_at": video["created_at"],
            "transcription_preview": clip(video["transcription"]),
        }
        for video in owned_videos
    ]
@app.get("/sessions/{session_id}", response_model=Dict[str, Any])
async def get_session_info(
    session_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Get information about a specific session

    Returns the stored record's metadata, the full transcription with a
    200-character preview, and the chat history as question/answer pairs.

    Raises:
        HTTPException: 404 when the session does not exist; 403 when it
            belongs to another user.
    """
    record = mongodb.videos.find_one({"video_id": session_id})
    if not record:
        raise HTTPException(status_code=404, detail="Session not found")
    if record["user_id"] != current_user.username:
        raise HTTPException(status_code=403, detail="Not authorized to access this session")

    # Messages are paired in order; an unpaired trailing message is skipped.
    exchanges = []
    history = chat_manager.get_chat_history(session_id)
    if history:
        stored = history.messages
        exchanges = [
            {"question": asked.content, "answer": replied.content}
            for asked, replied in zip(stored[0::2], stored[1::2])
        ]

    transcript = record["transcription"]
    return {
        "session_id": session_id,
        "title": record["title"],
        "source_type": record["source_type"],
        "source_url": record.get("source_url"),
        "created_at": record["created_at"],
        "transcription_preview": transcript[:200] + "..." if len(transcript) > 200 else transcript,
        "full_transcription": transcript,
        "chat_history": exchanges,
    }
@app.delete("/sessions/{session_id}")
async def delete_session(
    session_id: str,
    current_user: User = Depends(get_current_user)
):
    """
    Delete a session

    Removes the video record, its stored chat messages, the in-memory
    session state and any uploaded file saved for it.

    Raises:
        HTTPException: 404 when the session does not exist; 403 when it
            belongs to another user.
    """
    video_data = mongodb.videos.find_one({"video_id": session_id})

    if not video_data:
        raise HTTPException(status_code=404, detail="Session not found")

    if video_data["user_id"] != current_user.username:
        raise HTTPException(status_code=403, detail="Not authorized to access this session")

    mongodb.videos.delete_one({"video_id": session_id})

    # Let the history object delete its own documents: it knows the field
    # name it stores the session id under, which a hand-written filter on
    # "session_id" does not match.
    chat_history = chat_manager.get_chat_history(session_id)
    if chat_history:
        chat_history.clear()
    # Drop the cached history so a deleted session is not served from memory.
    chat_manager.chat_sessions.pop(session_id, None)

    sessions.pop(session_id, None)

    # Best-effort removal of the uploaded file(s) stored for this session.
    video_files = [f for f in os.listdir(VIDEOS_DIR) if f.startswith(session_id)]
    for file in video_files:
        try:
            os.remove(os.path.join(VIDEOS_DIR, file))
        except OSError:
            pass

    return {"message": f"Session {session_id} deleted successfully"}
@app.get("/")
async def root():
    """
    API root endpoint

    Returns a static description of the service and its endpoints.
    """
    endpoint_summaries = {
        "/register": "Register a new user",
        "/token": "Login and get access token",
        "/transcribe": "Transcribe YouTube videos",
        "/upload": "Upload and transcribe video files (max 20MB)",
        "/download/{video_id}": "Download an uploaded video",
        "/query": "Query the RAG system",
        "/sessions": "List all user sessions",
        "/sessions/{session_id}": "Get session information",
    }
    return {
        "message": "Video Transcription and QA API",
        "endpoints": endpoint_summaries,
    }
@app.on_event("shutdown")
def shutdown_event():
    """Close the Mongo client and remove the uploaded-videos directory."""
    mongodb.close()
    # Stored uploads are discarded on shutdown; removal errors are ignored.
    shutil.rmtree(VIDEOS_DIR, ignore_errors=True)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    # Set before the server starts; only applies when run as a script.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    uvicorn.run(app, host="0.0.0.0", port=8000)