Spaces:
Running
Running
| """LangGraph decision flow — the core intelligence engine with streaming support.""" | |
| import json | |
| from typing import TypedDict, Optional, Generator | |
| from langgraph.graph import StateGraph, END | |
| from langchain_groq import ChatGroq | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from app.config import GROQ_API_KEY, LLM_MODEL, LLM_MODEL_FAST, CONFIDENCE_THRESHOLD | |
| from app.prompts import ( | |
| SYSTEM_PROMPT, | |
| INTENT_ANALYSIS_PROMPT, | |
| ANSWER_PROMPT, | |
| REASON_PROMPT, | |
| CLARIFY_PROMPT, | |
| QUERY_NORMALIZE_PROMPT, | |
| ) | |
| from app.retriever import get_retriever | |
| from app.session import Session | |
| # --- State Schema --- | |
class GraphState(TypedDict):
    """State shared by all LangGraph nodes during a single chat turn."""
    user_message: str                  # raw user text for this turn
    session: Session                   # mutable per-conversation session, updated in place by nodes
    intent: str                        # classified intent, e.g. "greeting", "unclear", "general"
    topic: str                         # extracted topic; used as the retrieval query when present
    needs_retrieval: bool              # whether vector search should run for this turn
    retrieved_docs: list[dict]         # search hits; nodes read d["text"], d["score"], d["metadata"]
    confidence: float                  # mean similarity of the top retrieved hits (0.0 when none)
    answer: str                        # final assistant answer text
    follow_up_question: Optional[str]  # clarifying question, set only on the "clarify" path
    action_taken: str                  # path that produced the answer: "retrieve"/"reason"/"clarify"
| # --- LLM Instances --- | |
def get_llm():
    """Return the high-quality ChatGroq client used for final answer generation."""
    llm_config = dict(
        api_key=GROQ_API_KEY,
        model_name=LLM_MODEL,
        temperature=0.3,   # mild creativity for fluent answers
        max_tokens=1024,
    )
    return ChatGroq(**llm_config)
def get_llm_fast():
    """Return the low-latency ChatGroq client used for intent analysis."""
    # Near-deterministic, short outputs: this model only emits routing labels.
    return ChatGroq(
        api_key=GROQ_API_KEY,
        model_name=LLM_MODEL_FAST,
        temperature=0.1,
        max_tokens=256,
    )
| # --- Node Functions --- | |
def analyze_intent(state: GraphState) -> GraphState:
    """Classify the user's intent with the fast LLM and session context.

    Expects the model to answer with "INTENT:", "TOPIC:" and
    "NEEDS_RETRIEVAL:" lines; any field the model omits keeps a permissive
    default (general intent, raw message as topic, retrieval enabled).
    The extracted topic is also recorded on the session.
    """
    session = state["session"]
    message = state["user_message"]
    reply = get_llm_fast().invoke([  # fast model keeps this hop cheap
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=INTENT_ANALYSIS_PROMPT.format(
            history=session.get_history_text(),
            user_message=message,
        )),
    ])
    # Defaults used when the model omits a field.
    parsed = {"intent": "general", "topic": message, "needs_retrieval": True}
    for raw_line in reply.content.strip().split("\n"):
        key, sep, value = raw_line.strip().partition(":")
        if not sep:
            continue  # line carries no "KEY: value" structure
        value = value.strip()
        if key == "INTENT":
            parsed["intent"] = value.lower()
        elif key == "TOPIC":
            parsed["topic"] = value
        elif key == "NEEDS_RETRIEVAL":
            parsed["needs_retrieval"] = value.lower() in ("yes", "true")
    session.current_topic = parsed["topic"]
    return {**state, **parsed}
def retrieve_docs(state: GraphState) -> GraphState:
    """Run vector search against Qdrant and store the hits in state.

    Queries with the analyzed topic for focus, falling back to the raw
    user message when no topic was extracted.
    """
    search_query = state.get("topic") or state["user_message"]
    hits = get_retriever().search(search_query)
    return {**state, "retrieved_docs": hits}
def evaluate_confidence(state: GraphState) -> GraphState:
    """Score retrieval quality as the mean similarity of the top-3 hits.

    An empty result set scores 0.0. The score is also mirrored onto the
    session for downstream inspection.
    """
    top_scores = [hit["score"] for hit in state.get("retrieved_docs", [])[:3]]
    confidence = sum(top_scores) / len(top_scores) if top_scores else 0.0
    state["session"].last_confidence = confidence
    return {**state, "confidence": confidence}
def generate_answer(state: GraphState) -> GraphState:
    """Produce a grounded answer from retrieved documents (non-streaming).

    Joins the retrieved chunks into a source-labelled context block, invokes
    the main LLM, and records action/confidence on the session.
    """
    session = state["session"]
    docs = state.get("retrieved_docs", [])
    # Source-labelled context so the model can cite where facts came from.
    context = "\n\n---\n\n".join(
        f"[Source: {doc['metadata'].get('source_file', 'unknown')}]\n{doc['text']}"
        for doc in docs
    )
    reply = get_llm().invoke([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=ANSWER_PROMPT.format(
            context=context,
            history=session.get_history_text(),
            user_message=state["user_message"],
        )),
    ])
    session.last_action = "retrieve"
    session.last_confidence = state.get("confidence", 0.0)
    return {
        **state,
        "answer": reply.content.strip(),
        "follow_up_question": None,
        "action_taken": "retrieve",
    }
def reason_answer(state: GraphState) -> GraphState:
    """Answer from general reasoning when retrieval was weak (non-streaming).

    Any partial retrieval context is still included; with no hits at all a
    placeholder string tells the model nothing relevant was found.
    """
    session = state["session"]
    snippets = [
        f"[Source: {doc['metadata'].get('source_file', 'unknown')}]\n{doc['text']}"
        for doc in state.get("retrieved_docs", [])
    ]
    context = "\n\n---\n\n".join(snippets) if snippets else "No relevant documents found."
    reply = get_llm().invoke([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=REASON_PROMPT.format(
            context=context,
            history=session.get_history_text(),
            user_message=state["user_message"],
        )),
    ])
    session.last_action = "reason"
    session.last_confidence = state.get("confidence", 0.0)
    return {
        **state,
        "answer": reply.content.strip(),
        "follow_up_question": None,
        "action_taken": "reason",
    }
def clarify(state: GraphState) -> GraphState:
    """Generate a short clarifying follow-up question for vague requests.

    Uses the fast model (short output) and stashes the pending question on
    the session so the next turn can recognize the user's reply as an answer.
    """
    session = state["session"]
    reply = get_llm_fast().invoke([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=CLARIFY_PROMPT.format(
            history=session.get_history_text(),
            user_message=state["user_message"],
        )),
    ])
    question = reply.content.strip()
    session.last_action = "clarify"
    session.pending_clarification = question
    return {
        **state,
        "answer": "I'd like to help you better. Let me ask a quick question:",
        "follow_up_question": question,
        "action_taken": "clarify",
    }
| # --- Routing Functions --- | |
def route_after_intent(state: GraphState) -> str:
    """Pick the next node after intent analysis.

    unclear → clarify; greeting → reason directly (no lookup needed);
    otherwise retrieve when the analyzer asked for it, else reason.
    """
    intent = state.get("intent", "general")
    if intent == "unclear":
        return "clarify"
    if intent == "greeting":
        return "reason_answer"
    return "retrieve_docs" if state.get("needs_retrieval", True) else "reason_answer"
def route_after_confidence(state: GraphState) -> str:
    """Pick the next node based on the retrieval confidence score.

    Confidence at or above the threshold yields a grounded answer. Below
    that, everything falls back to reasoning — except a very low score
    (<= 0.2) on an already-unclear intent, which asks for clarification.
    """
    confidence = state.get("confidence", 0.0)
    if confidence >= CONFIDENCE_THRESHOLD:
        return "generate_answer"
    if confidence <= 0.2 and state.get("intent", "general") == "unclear":
        return "clarify"
    return "reason_answer"
# --- Build the Graph ---
def build_graph() -> StateGraph:
    """Assemble and compile the LangGraph decision flow.

    Topology: analyze_intent fans out (via route_after_intent) to retrieval,
    reasoning, or clarification; retrieval always passes through a confidence
    check (route_after_confidence) before one of the three terminal nodes.
    """
    workflow = StateGraph(GraphState)
    # Register every node once, table-driven.
    node_table = {
        "analyze_intent": analyze_intent,
        "retrieve_docs": retrieve_docs,
        "evaluate_confidence": evaluate_confidence,
        "generate_answer": generate_answer,
        "reason_answer": reason_answer,
        "clarify": clarify,
    }
    for node_name, node_fn in node_table.items():
        workflow.add_node(node_name, node_fn)
    workflow.set_entry_point("analyze_intent")
    # Intent analysis routes to retrieval, reasoning, or clarification.
    workflow.add_conditional_edges(
        "analyze_intent",
        route_after_intent,
        {name: name for name in ("retrieve_docs", "reason_answer", "clarify")},
    )
    # Every retrieval is followed by a confidence evaluation.
    workflow.add_edge("retrieve_docs", "evaluate_confidence")
    workflow.add_conditional_edges(
        "evaluate_confidence",
        route_after_confidence,
        {name: name for name in ("generate_answer", "reason_answer", "clarify")},
    )
    # All answer-producing nodes terminate the graph.
    for terminal in ("generate_answer", "reason_answer", "clarify"):
        workflow.add_edge(terminal, END)
    return workflow.compile()
# Compiled once at import time and shared by all requests.
chatbot_graph = build_graph()
# --- Non-streaming entry point (kept for backward compat) ---
def run_chat(session: Session, user_message: str) -> dict:
    """
    Run the full chatbot flow for a user message.
    Returns: {"answer": str, "follow_up_question": str | None}
    """
    session.add_user_message(user_message)
    # This turn consumes any clarification question that was pending.
    if session.pending_clarification:
        session.pending_clarification = None
    result = chatbot_graph.invoke({
        "user_message": user_message,
        "session": session,
        "intent": "",
        "topic": "",
        "needs_retrieval": True,
        "retrieved_docs": [],
        "confidence": 0.0,
        "answer": "",
        "follow_up_question": None,
        "action_taken": "",
    })
    answer = result.get("answer", "I'm sorry, I couldn't process your request.")
    follow_up = result.get("follow_up_question")
    # History stores answer plus follow-up as a single assistant turn.
    history_entry = f"{answer}\n\n{follow_up}" if follow_up else answer
    session.add_assistant_message(history_entry)
    return {"answer": answer, "follow_up_question": follow_up}
| # ============================================================================= | |
| # STREAMING ENTRY POINT | |
| # ============================================================================= | |
| def _extract_sources(docs: list[dict]) -> list[dict]: | |
| """Extract unique source metadata from retrieved docs.""" | |
| seen = set() | |
| sources = [] | |
| for doc in docs: | |
| meta = doc.get("metadata", {}) | |
| source_file = meta.get("source_file", "") | |
| if source_file and source_file not in seen: | |
| seen.add(source_file) | |
| sources.append({ | |
| "file": source_file, | |
| "folder": meta.get("folder", ""), | |
| "department": meta.get("department", ""), | |
| "score": round(doc.get("score", 0), 3), | |
| }) | |
| return sources | |
| def _classify_intent_fast(user_message: str, session: Session) -> tuple[str, bool]: | |
| """ | |
| Ultra-fast local intent classification (no LLM call). | |
| Returns (intent, needs_retrieval). | |
| """ | |
| msg = user_message.lower().strip() | |
| words = msg.split() | |
| # Greetings | |
| greetings = {"hi", "hello", "hey", "howdy", "greetings", "good morning", | |
| "good afternoon", "good evening", "thanks", "thank you", "bye", "goodbye"} | |
| if msg in greetings or (len(words) <= 2 and words[0] in greetings): | |
| return "greeting", False | |
| # Too vague (under 3 words, no real nouns) | |
| if len(words) <= 2 and not any(w in msg for w in [ | |
| "stacklogix", "feature", "report", "dashboard", "ai", "ml", | |
| "purchase", "jewellery", "jewelry", "gold", "diamond", "master", | |
| "retail", "wholesale", "ecommerce", "manufacturer", "supply", | |
| "price", "inventory", "model", "train", "monitoring" | |
| ]): | |
| return "unclear", False | |
| # Everything else → retrieve | |
| return "factual", True | |
def _run_decision_phase(session: Session, user_message: str) -> dict:
    """
    Fast decision phase: local intent classification + retrieval (no LLM call).
    Returns the prepared state with all info needed to stream the final answer.
    """
    session.add_user_message(user_message)
    # This turn consumes any clarification question that was pending.
    if session.pending_clarification:
        session.pending_clarification = None

    # Step 1: instant local intent classification.
    intent, needs_retrieval = _classify_intent_fast(user_message, session)
    topic = user_message
    session.current_topic = topic

    # Step 2: best-effort AI query normalization with the fast model.
    normalized_query = user_message
    if needs_retrieval:
        try:
            norm_reply = get_llm_fast().invoke([
                HumanMessage(content=QUERY_NORMALIZE_PROMPT.format(
                    history=session.get_history_text(),
                    user_message=user_message,
                )),
            ])
            cleaned = norm_reply.content.strip().strip('"').strip("'")
            if cleaned:
                normalized_query = cleaned
                print(f" Query normalized: '{user_message}' → '{normalized_query}'")
        except Exception as e:
            # Normalization is only an optimization — fall back to raw text.
            print(f" Query normalization failed: {e}, using original")
            normalized_query = user_message

    # Step 3: route decision.
    retrieved_docs: list = []
    confidence = 0.0
    if intent == "unclear":
        action = "clarify"
    elif intent == "greeting" or not needs_retrieval:
        action = "reason"
    else:
        # Retrieve with the normalized query, then score the top-3 hits.
        retrieved_docs = get_retriever().search(normalized_query)
        if retrieved_docs:
            top_scores = [hit["score"] for hit in retrieved_docs[:3]]
            confidence = sum(top_scores) / len(top_scores)
        session.last_confidence = confidence
        # Only high confidence earns the grounded-answer path.
        action = "retrieve" if confidence >= CONFIDENCE_THRESHOLD else "reason"

    return {
        "user_message": user_message,
        "intent": intent,
        "topic": topic,
        "retrieved_docs": retrieved_docs,
        "confidence": confidence,
        "action": action,
        "session": session,
    }
def run_chat_streaming(session: Session, user_message: str) -> Generator[str, None, None]:
    """
    Streaming chatbot flow. Yields SSE-formatted events:
        data: {"type": "token", "content": "..."}
        data: {"type": "sources", "sources": [...]}
        data: {"type": "follow_up", "content": "..."}
        data: {"type": "done"}

    Mutates `session` in place: appends user/assistant messages and updates
    last_action, last_confidence, and pending_clarification.
    """
    # Phase 1: Decision (fast, non-streaming)
    decision = _run_decision_phase(session, user_message)
    action = decision["action"]
    docs = decision["retrieved_docs"]
    session_obj = decision["session"]
    # Phase 2: Stream the response
    if action == "clarify":
        # Clarification — use fast model, no streaming needed (short response)
        llm_fast = get_llm_fast()
        prompt = CLARIFY_PROMPT.format(
            history=session_obj.get_history_text(),
            user_message=user_message,
        )
        response = llm_fast.invoke([
            SystemMessage(content=SYSTEM_PROMPT),
            HumanMessage(content=prompt),
        ])
        follow_up = response.content.strip()
        session_obj.last_action = "clarify"
        # Stash the question so the next turn recognizes the user's reply
        # as an answer to it.
        session_obj.pending_clarification = follow_up
        clarify_msg = "I'd like to help you better. Let me ask a quick question:"
        yield f"data: {json.dumps({'type': 'token', 'content': clarify_msg})}\n\n"
        yield f"data: {json.dumps({'type': 'follow_up', 'content': follow_up})}\n\n"
        # History stores the intro and the question as one assistant turn.
        session_obj.add_assistant_message(f"{clarify_msg}\n\n{follow_up}")
        yield f"data: {json.dumps({'type': 'done'})}\n\n"
        return
    # Build context for answer/reason
    context_parts = []
    for doc in docs:
        source = doc["metadata"].get("source_file", "unknown")
        context_parts.append(f"[Source: {source}]\n{doc['text']}")
    context = "\n\n---\n\n".join(context_parts) if context_parts else "No relevant documents found."
    if action == "retrieve":
        # High-confidence retrieval → grounded-answer prompt.
        prompt = ANSWER_PROMPT.format(
            context=context,
            history=session_obj.get_history_text(),
            user_message=user_message,
        )
        session_obj.last_action = "retrieve"
    else:
        # Weak or skipped retrieval → reasoning prompt (context may be the
        # "No relevant documents found." placeholder).
        prompt = REASON_PROMPT.format(
            context=context,
            history=session_obj.get_history_text(),
            user_message=user_message,
        )
        session_obj.last_action = "reason"
    session_obj.last_confidence = decision["confidence"]
    # Stream LLM response token-by-token
    llm = get_llm()
    full_answer = ""
    for chunk in llm.stream([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=prompt),
    ]):
        token = chunk.content
        if token:
            full_answer += token
            yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
    # Send sources
    if docs:
        sources = _extract_sources(docs)
        yield f"data: {json.dumps({'type': 'sources', 'sources': sources})}\n\n"
    # Save to session history
    session_obj.add_assistant_message(full_answer.strip())
    yield f"data: {json.dumps({'type': 'done'})}\n\n"