"""LangGraph decision flow — the core intelligence engine with streaming support.

Two entry points share the same building blocks:

* ``run_chat`` drives the compiled LangGraph (``chatbot_graph``): LLM intent
  analysis -> optional vector retrieval -> confidence routing -> one of the
  answer / reason / clarify nodes.
* ``run_chat_streaming`` uses a cheap local intent heuristic plus LLM query
  normalization, then streams the final answer as Server-Sent Events.
"""
import json
import logging
from typing import TypedDict, Optional, Generator

from langgraph.graph import StateGraph, END
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage

from app.config import GROQ_API_KEY, LLM_MODEL, LLM_MODEL_FAST, CONFIDENCE_THRESHOLD
from app.prompts import (
    SYSTEM_PROMPT,
    INTENT_ANALYSIS_PROMPT,
    ANSWER_PROMPT,
    REASON_PROMPT,
    CLARIFY_PROMPT,
    QUERY_NORMALIZE_PROMPT,
)
from app.retriever import get_retriever
from app.session import Session

logger = logging.getLogger(__name__)


# --- Shared constants ---

# Number of top retrieved docs whose scores are averaged into "confidence".
_CONFIDENCE_TOP_K = 3
# Mean score at or below which retrieval is considered to have found nothing
# useful. Both routers currently fall back to "reason" either way; the graph
# router only uses it for a defensive "unclear" -> clarify check.
_LOW_CONFIDENCE_FLOOR = 0.2
# Context text handed to the LLM when no documents were retrieved.
_NO_CONTEXT_TEXT = "No relevant documents found."
# Fixed lead-in shown before a clarifying follow-up question.
_CLARIFY_PREAMBLE = "I'd like to help you better. Let me ask a quick question:"
# Returned by run_chat when the graph produced an empty answer.
_FALLBACK_ANSWER = "I'm sorry, I couldn't process your request."

# Messages (or their first word) that the local classifier treats as small talk.
_GREETINGS = frozenset({
    "hi", "hello", "hey", "howdy", "greetings",
    "good morning", "good afternoon", "good evening",
    "thanks", "thank you", "bye", "goodbye",
})
# Substrings that mark even a very short message as an on-topic question.
# NOTE: matched as plain substrings of the message (so "ai" also matches
# "email"), exactly as the original heuristic did.
_DOMAIN_KEYWORDS = (
    "stacklogix", "feature", "report", "dashboard", "ai", "ml", "purchase",
    "jewellery", "jewelry", "gold", "diamond", "master", "retail", "wholesale",
    "ecommerce", "manufacturer", "supply", "price", "inventory", "model",
    "train", "monitoring",
)
# Punctuation stripped from the edges of a message/word before greeting lookup,
# so "hi!" or "thank you." are still recognised as greetings.
_EDGE_PUNCTUATION = "!?.,;:"


# --- State Schema ---
class GraphState(TypedDict):
    """State threaded through the LangGraph nodes.

    Every node returns a shallow copy of the incoming state (``{**state, ...}``)
    with its own keys updated.
    """

    user_message: str
    session: Session
    intent: str  # lowercase label, e.g. "unclear", "greeting", "general"
    topic: str  # focused search topic; retrieval falls back to user_message if empty
    needs_retrieval: bool
    retrieved_docs: list[dict]  # retriever hits with "text", "score", "metadata" keys
    confidence: float  # mean score of the top retrieved docs (0.0 when none)
    answer: str
    follow_up_question: Optional[str]  # only set by the clarify node
    action_taken: str  # "retrieve" | "reason" | "clarify"


# --- LLM Instances ---
def get_llm():
    """Main LLM for answer generation (high quality)."""
    return ChatGroq(
        api_key=GROQ_API_KEY,
        model_name=LLM_MODEL,
        temperature=0.3,
        max_tokens=1024,
    )


def get_llm_fast():
    """Fast LLM for intent analysis, clarification and query rewriting (low latency)."""
    return ChatGroq(
        api_key=GROQ_API_KEY,
        model_name=LLM_MODEL_FAST,
        temperature=0.1,
        max_tokens=256,
    )


# --- Shared helpers ---
def _build_messages(prompt: str) -> list:
    """Wrap *prompt* as ``[system prompt, user message]`` for a chat-model call."""
    return [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=prompt),
    ]


def _build_context(docs: list[dict]) -> str:
    """Render retrieved docs as ``[Source: <file>]\\n<text>`` sections.

    Sections are separated by ``---`` rules. Returns ``_NO_CONTEXT_TEXT`` when
    *docs* is empty so every prompt gets the same placeholder.
    """
    if not docs:
        return _NO_CONTEXT_TEXT
    return "\n\n---\n\n".join(
        f"[Source: {doc['metadata'].get('source_file', 'unknown')}]\n{doc['text']}"
        for doc in docs
    )


def _compute_confidence(docs: list[dict]) -> float:
    """Average the ``score`` of the first ``_CONFIDENCE_TOP_K`` docs; 0.0 if none.

    Presumably the retriever returns hits best-first, making these the top
    scores — TODO confirm against ``retriever.search``.
    """
    top_scores = [doc["score"] for doc in docs[:_CONFIDENCE_TOP_K]]
    if not top_scores:
        return 0.0
    return sum(top_scores) / len(top_scores)


def _begin_turn(session: Session, user_message: str) -> None:
    """Record the user's message and drop any clarification pending from the last turn."""
    session.add_user_message(user_message)
    if session.pending_clarification:
        session.pending_clarification = None


def _ask_clarifying_question(session: Session, user_message: str) -> str:
    """Ask the fast LLM for a follow-up question and record it on the session.

    Sets ``session.last_action`` to "clarify" and stores the question in
    ``session.pending_clarification``. Returns the stripped question text.
    """
    prompt = CLARIFY_PROMPT.format(
        history=session.get_history_text(),
        user_message=user_message,
    )
    follow_up = get_llm_fast().invoke(_build_messages(prompt)).content.strip()
    session.last_action = "clarify"
    session.pending_clarification = follow_up
    return follow_up


def _answer_node(state: GraphState, prompt_template: str, action: str) -> GraphState:
    """Shared body of the answer-producing nodes.

    Fills *prompt_template* (``context``/``history``/``user_message``) with the
    retrieved docs, asks the main LLM, and records *action* plus the current
    confidence on the session.
    """
    session = state["session"]
    prompt = prompt_template.format(
        context=_build_context(state.get("retrieved_docs", [])),
        history=session.get_history_text(),
        user_message=state["user_message"],
    )
    response = get_llm().invoke(_build_messages(prompt))

    session.last_action = action
    session.last_confidence = state.get("confidence", 0.0)

    return {
        **state,
        "answer": response.content.strip(),
        "follow_up_question": None,
        "action_taken": action,
    }


# --- Node Functions ---
def analyze_intent(state: GraphState) -> GraphState:
    """Classify the user's intent with the fast LLM and the session history.

    The model is expected to reply with ``INTENT:``, ``TOPIC:`` and
    ``NEEDS_RETRIEVAL:`` lines. Any line that is missing keeps its default:
    intent "general", topic = the raw message, and needs_retrieval True.
    The parsed topic is also stored as ``session.current_topic``.
    """
    session = state["session"]
    prompt = INTENT_ANALYSIS_PROMPT.format(
        history=session.get_history_text(),
        user_message=state["user_message"],
    )
    text = get_llm_fast().invoke(_build_messages(prompt)).content.strip()

    intent = "general"
    topic = state["user_message"]
    needs_retrieval = True
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if line.startswith("INTENT:"):
            intent = line.split(":", 1)[1].strip().lower()
        elif line.startswith("TOPIC:"):
            topic = line.split(":", 1)[1].strip()
        elif line.startswith("NEEDS_RETRIEVAL:"):
            needs_retrieval = line.split(":", 1)[1].strip().lower() in ("yes", "true")

    session.current_topic = topic
    return {
        **state,
        "intent": intent,
        "topic": topic,
        "needs_retrieval": needs_retrieval,
    }


def retrieve_docs(state: GraphState) -> GraphState:
    """Perform vector search (Qdrant) on the topic, falling back to the raw message."""
    query = state.get("topic") or state["user_message"]
    return {
        **state,
        "retrieved_docs": get_retriever().search(query),
    }


def evaluate_confidence(state: GraphState) -> GraphState:
    """Score retrieval quality and mirror it to ``session.last_confidence``."""
    confidence = _compute_confidence(state.get("retrieved_docs", []))
    state["session"].last_confidence = confidence
    return {
        **state,
        "confidence": confidence,
    }


def generate_answer(state: GraphState) -> GraphState:
    """Generate an answer grounded in the retrieved documents (non-streaming)."""
    return _answer_node(state, ANSWER_PROMPT, "retrieve")


def reason_answer(state: GraphState) -> GraphState:
    """Generate a reasoned answer when retrieval is insufficient (non-streaming).

    Any partial context that was retrieved is still passed to the model.
    """
    return _answer_node(state, REASON_PROMPT, "reason")


def clarify(state: GraphState) -> GraphState:
    """Reply with a fixed preamble plus an LLM-generated clarifying question."""
    follow_up = _ask_clarifying_question(state["session"], state["user_message"])
    return {
        **state,
        "answer": _CLARIFY_PREAMBLE,
        "follow_up_question": follow_up,
        "action_taken": "clarify",
    }


# --- Routing Functions ---
def route_after_intent(state: GraphState) -> str:
    """Route based on intent analysis: retrieve, reason, or clarify."""
    intent = state.get("intent", "general")
    if intent == "unclear":
        return "clarify"
    if intent == "greeting":
        return "reason_answer"
    if state.get("needs_retrieval", True):
        return "retrieve_docs"
    return "reason_answer"


def route_after_confidence(state: GraphState) -> str:
    """Route based on the retrieval confidence score."""
    confidence = state.get("confidence", 0.0)
    if confidence >= CONFIDENCE_THRESHOLD:
        return "generate_answer"
    if confidence > _LOW_CONFIDENCE_FLOOR:
        return "reason_answer"
    # Very low confidence. "unclear" intents are normally routed to clarify
    # before retrieval even runs, so this check is defensive.
    if state.get("intent", "general") == "unclear":
        return "clarify"
    return "reason_answer"


# --- Build the Graph ---
def build_graph() -> StateGraph:
    """Build and compile the LangGraph decision flow.

    Flow: analyze_intent -> (retrieve_docs -> evaluate_confidence)? ->
    generate_answer | reason_answer | clarify -> END.

    NOTE: returns the *compiled* graph (``workflow.compile()``), not the
    ``StateGraph`` builder; the annotation is kept for backward compatibility.
    """
    workflow = StateGraph(GraphState)

    workflow.add_node("analyze_intent", analyze_intent)
    workflow.add_node("retrieve_docs", retrieve_docs)
    workflow.add_node("evaluate_confidence", evaluate_confidence)
    workflow.add_node("generate_answer", generate_answer)
    workflow.add_node("reason_answer", reason_answer)
    workflow.add_node("clarify", clarify)

    workflow.set_entry_point("analyze_intent")

    workflow.add_conditional_edges(
        "analyze_intent",
        route_after_intent,
        {
            "retrieve_docs": "retrieve_docs",
            "reason_answer": "reason_answer",
            "clarify": "clarify",
        },
    )

    workflow.add_edge("retrieve_docs", "evaluate_confidence")

    workflow.add_conditional_edges(
        "evaluate_confidence",
        route_after_confidence,
        {
            "generate_answer": "generate_answer",
            "reason_answer": "reason_answer",
            "clarify": "clarify",
        },
    )

    for terminal_node in ("generate_answer", "reason_answer", "clarify"):
        workflow.add_edge(terminal_node, END)

    return workflow.compile()


# Global compiled graph (built once at import time).
chatbot_graph = build_graph()


# --- Non-streaming entry point (kept for backward compat) ---
def run_chat(session: Session, user_message: str) -> dict:
    """
    Run the full chatbot flow for a user message.

    Appends both the user message and the assistant reply (answer plus any
    follow-up question) to the session history.

    Returns: {"answer": str, "follow_up_question": str | None}
    """
    _begin_turn(session, user_message)

    initial_state: GraphState = {
        "user_message": user_message,
        "session": session,
        "intent": "",
        "topic": "",
        "needs_retrieval": True,
        "retrieved_docs": [],
        "confidence": 0.0,
        "answer": "",
        "follow_up_question": None,
        "action_taken": "",
    }

    result = chatbot_graph.invoke(initial_state)

    # "answer" is always present (initialised to ""), so a .get() default would
    # never apply; `or` also covers an empty LLM reply.
    answer = result.get("answer") or _FALLBACK_ANSWER
    follow_up = result.get("follow_up_question")

    full_response = f"{answer}\n\n{follow_up}" if follow_up else answer
    session.add_assistant_message(full_response)

    return {
        "answer": answer,
        "follow_up_question": follow_up,
    }


# =============================================================================
# STREAMING ENTRY POINT
# =============================================================================
def _sse(payload: dict) -> str:
    """Format *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _extract_sources(docs: list[dict]) -> list[dict]:
    """Return one source entry per distinct, non-empty ``source_file``.

    Entries keep first-seen order, so each entry's score is that of the first
    doc from that file.
    """
    seen = set()
    sources = []
    for doc in docs:
        meta = doc.get("metadata", {})
        source_file = meta.get("source_file", "")
        if not source_file or source_file in seen:
            continue
        seen.add(source_file)
        sources.append({
            "file": source_file,
            "folder": meta.get("folder", ""),
            "department": meta.get("department", ""),
            "score": round(doc.get("score", 0), 3),
        })
    return sources


def _classify_intent_fast(user_message: str, session: Session) -> tuple[str, bool]:
    """
    Ultra-fast local intent classification (no LLM call).

    Returns (intent, needs_retrieval):
    * ("greeting", False) for small talk such as "hi", "hello there", "thanks!";
    * ("unclear", False) for messages of at most two words (including empty
      input) that mention none of ``_DOMAIN_KEYWORDS``;
    * ("factual", True) otherwise.

    ``session`` is currently unused; it is kept for interface compatibility.
    """
    msg = user_message.lower().strip()
    words = msg.split()
    # Guard: an empty/whitespace-only message has no first word.
    first_word = words[0].strip(_EDGE_PUNCTUATION) if words else ""
    bare_msg = msg.strip(_EDGE_PUNCTUATION).strip()

    if bare_msg in _GREETINGS or (len(words) <= 2 and first_word in _GREETINGS):
        return "greeting", False

    # Too vague: very short and no domain keyword anywhere in the text.
    if len(words) <= 2 and not any(keyword in msg for keyword in _DOMAIN_KEYWORDS):
        return "unclear", False

    return "factual", True


def _normalize_query(session: Session, user_message: str) -> str:
    """Rewrite *user_message* into a retrieval query using the fast LLM.

    Best-effort: any error, or an empty rewrite, falls back to the original
    message so retrieval can still run.
    """
    try:
        prompt = QUERY_NORMALIZE_PROMPT.format(
            history=session.get_history_text(),
            user_message=user_message,
        )
        response = get_llm_fast().invoke([HumanMessage(content=prompt)])
        normalized = response.content.strip().strip('"').strip("'")
    except Exception:  # deliberate best-effort boundary around an external LLM call
        logger.warning("Query normalization failed; using original query", exc_info=True)
        return user_message

    if not normalized:
        return user_message
    logger.info("Query normalized: %r -> %r", user_message, normalized)
    return normalized


def _run_decision_phase(session: Session, user_message: str) -> dict:
    """
    Fast decision phase for streaming.

    Runs the local intent classification, plus an LLM query rewrite and
    retrieval for factual questions.

    Returns everything needed to stream the final answer: user_message,
    intent, topic, retrieved_docs, confidence, action ("clarify" | "reason"
    | "retrieve") and session.
    """
    _begin_turn(session, user_message)

    # Step 1: local intent classification (no LLM call).
    intent, needs_retrieval = _classify_intent_fast(user_message, session)
    topic = user_message
    session.current_topic = topic

    # Step 2: LLM query normalization (fast model), only when we will retrieve.
    normalized_query = (
        _normalize_query(session, user_message) if needs_retrieval else user_message
    )

    # Step 3: route.
    retrieved_docs: list[dict] = []
    confidence = 0.0
    if intent == "unclear":
        action = "clarify"
    elif intent == "greeting":
        action = "reason"
    elif needs_retrieval:
        retrieved_docs = get_retriever().search(normalized_query)
        confidence = _compute_confidence(retrieved_docs)
        session.last_confidence = confidence
        action = "retrieve" if confidence >= CONFIDENCE_THRESHOLD else "reason"
    else:
        action = "reason"

    return {
        "user_message": user_message,
        "intent": intent,
        "topic": topic,
        "retrieved_docs": retrieved_docs,
        "confidence": confidence,
        "action": action,
        "session": session,
    }


def run_chat_streaming(session: Session, user_message: str) -> Generator[str, None, None]:
    """
    Streaming chatbot flow. Yields SSE-formatted events:
        data: {"type": "token", "content": "..."}
        data: {"type": "sources", "sources": [...]}   (only when docs were retrieved)
        data: {"type": "follow_up", "content": "..."}  (clarify path only)
        data: {"type": "done"}

    The assistant reply is appended to the session history before "done".
    NOTE(review): if the LLM call raises mid-stream, the exception propagates
    and no assistant message or "done" event is recorded.
    """
    # Phase 1: decision (fast, non-streaming).
    decision = _run_decision_phase(session, user_message)
    action = decision["action"]
    docs = decision["retrieved_docs"]
    session_obj = decision["session"]

    # Phase 2a: clarification is short, so it is sent in one piece.
    if action == "clarify":
        follow_up = _ask_clarifying_question(session_obj, user_message)
        yield _sse({"type": "token", "content": _CLARIFY_PREAMBLE})
        yield _sse({"type": "follow_up", "content": follow_up})
        session_obj.add_assistant_message(f"{_CLARIFY_PREAMBLE}\n\n{follow_up}")
        yield _sse({"type": "done"})
        return

    # Phase 2b: grounded ("retrieve") or reasoned answer, streamed token by token.
    grounded = action == "retrieve"
    prompt_template = ANSWER_PROMPT if grounded else REASON_PROMPT
    prompt = prompt_template.format(
        context=_build_context(docs),
        history=session_obj.get_history_text(),
        user_message=user_message,
    )
    session_obj.last_action = "retrieve" if grounded else "reason"
    session_obj.last_confidence = decision["confidence"]

    answer_parts: list[str] = []
    for chunk in get_llm().stream(_build_messages(prompt)):
        token = chunk.content
        if token:
            answer_parts.append(token)
            yield _sse({"type": "token", "content": token})

    if docs:
        yield _sse({"type": "sources", "sources": _extract_sources(docs)})

    session_obj.add_assistant_message("".join(answer_parts).strip())
    yield _sse({"type": "done"})