# stacklogix/app/graph.py
# Deploy Bot
# Deployment commit
# 6ca2339
"""LangGraph decision flow — the core intelligence engine with streaming support."""
import json
from typing import TypedDict, Optional, Generator
from langgraph.graph import StateGraph, END
from langchain_groq import ChatGroq
from langchain_core.messages import SystemMessage, HumanMessage
from app.config import GROQ_API_KEY, LLM_MODEL, LLM_MODEL_FAST, CONFIDENCE_THRESHOLD
from app.prompts import (
SYSTEM_PROMPT,
INTENT_ANALYSIS_PROMPT,
ANSWER_PROMPT,
REASON_PROMPT,
CLARIFY_PROMPT,
QUERY_NORMALIZE_PROMPT,
)
from app.retriever import get_retriever
from app.session import Session
# --- State Schema ---
class GraphState(TypedDict):
    """State dict passed between the LangGraph nodes for one user turn.

    Every node returns ``{**state, ...}`` with the keys it updates, so all
    keys below are carried through the whole flow.
    """

    # Raw text of the current user message.
    user_message: str
    # Conversation session; nodes read its history and set attributes such as
    # current_topic, last_action, last_confidence and pending_clarification.
    session: Session
    # Intent label parsed by analyze_intent ("general" by default; routing
    # special-cases "unclear" and "greeting").
    intent: str
    # Topic parsed by analyze_intent (falls back to user_message); used as the
    # retrieval query by retrieve_docs.
    topic: str
    # Whether analyze_intent decided a vector search is needed.
    needs_retrieval: bool
    # Search hits; nodes read the "text", "score" and "metadata" keys.
    retrieved_docs: list[dict]
    # Mean score of the top three hits (0.0 when there are none).
    confidence: float
    # Text reply produced by the terminal node.
    answer: str
    # Clarifying question; only set by the clarify node, otherwise None.
    follow_up_question: Optional[str]
    # Which terminal path ran: "retrieve", "reason" or "clarify".
    action_taken: str
# --- LLM Instances ---
def get_llm():
    """Build the primary chat model used to write answers.

    A new ``ChatGroq`` client is created on every call, configured with
    ``LLM_MODEL``, temperature 0.3 and a 1024-token output cap.
    """
    options = dict(
        api_key=GROQ_API_KEY,
        model_name=LLM_MODEL,
        temperature=0.3,
        max_tokens=1024,
    )
    return ChatGroq(**options)
def get_llm_fast():
    """Build the low-latency chat model used for short, classification-style calls.

    A new ``ChatGroq`` client is created on every call, configured with
    ``LLM_MODEL_FAST``, temperature 0.1 and a 256-token output cap.
    """
    options = dict(
        api_key=GROQ_API_KEY,
        model_name=LLM_MODEL_FAST,
        temperature=0.1,
        max_tokens=256,
    )
    return ChatGroq(**options)
# --- Node Functions ---
def _parse_intent_response(text: str, default_topic: str) -> tuple[str, str, bool]:
    """Parse the ``KEY: value`` lines of the intent-analysis LLM reply.

    Recognised keys are ``INTENT``, ``TOPIC`` and ``NEEDS_RETRIEVAL``, matched
    case-insensitively; leading list/markdown markers (``-``, ``*``, ``#``) and
    bold asterisks around keys or values are ignored. Missing or empty values
    keep the defaults: intent ``"general"``, topic ``default_topic`` and
    needs_retrieval ``True`` (an explicitly empty NEEDS_RETRIEVAL means False).

    Returns:
        ``(intent, topic, needs_retrieval)``; intent is lower-cased.
    """
    intent = "general"
    topic = default_topic
    needs_retrieval = True
    for raw_line in text.splitlines():
        key, sep, value = raw_line.partition(":")
        if not sep:
            continue
        # Models often decorate the format, e.g. "- **INTENT:** unclear".
        key = key.strip().strip("*-# ").upper()
        value = value.strip().strip("*").strip()
        if key == "INTENT":
            if value:
                # A trailing period ("unclear.") would defeat exact routing matches.
                intent = value.rstrip(".").lower()
        elif key == "TOPIC":
            if value:
                topic = value
        elif key == "NEEDS_RETRIEVAL":
            needs_retrieval = value.rstrip(".").lower() in ("yes", "true")
    return intent, topic, needs_retrieval


def analyze_intent(state: GraphState) -> GraphState:
    """Classify the user's intent with the fast LLM and the session history.

    The model is prompted with ``INTENT_ANALYSIS_PROMPT`` and its reply is
    parsed for INTENT / TOPIC / NEEDS_RETRIEVAL lines; unparseable replies fall
    back to intent ``"general"``, the raw user message as topic, and retrieval
    enabled.

    Side effect: sets ``session.current_topic`` to the parsed topic.

    Returns:
        The state updated with ``intent``, ``topic`` and ``needs_retrieval``.
    """
    llm = get_llm_fast()  # Fast model keeps this routing step low-latency.
    session = state["session"]
    prompt = INTENT_ANALYSIS_PROMPT.format(
        history=session.get_history_text(),
        user_message=state["user_message"],
    )
    response = llm.invoke([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=prompt),
    ])
    intent, topic, needs_retrieval = _parse_intent_response(
        response.content.strip(), state["user_message"]
    )
    session.current_topic = topic
    return {
        **state,
        "intent": intent,
        "topic": topic,
        "needs_retrieval": needs_retrieval,
    }
def retrieve_docs(state: GraphState) -> GraphState:
    """Run a vector search for this turn and attach the hits to the state.

    The query is the topic extracted by intent analysis; when that is empty or
    missing the raw user message is searched instead. The search itself is
    delegated to the object returned by ``get_retriever()``.
    """
    vector_store = get_retriever()
    search_text = state.get("topic") or state["user_message"]
    return {**state, "retrieved_docs": vector_store.search(search_text)}
def evaluate_confidence(state: GraphState) -> GraphState:
    """Score the retrieval as the mean similarity of its top three hits.

    With no hits the score is 0.0. The score is also written to the session
    as ``last_confidence`` and returned in the state as ``confidence``.
    """
    docs = state.get("retrieved_docs", [])
    scores = [doc["score"] for doc in docs[:3]] if docs else []
    confidence = sum(scores) / len(scores) if scores else 0.0
    state["session"].last_confidence = confidence
    return {**state, "confidence": confidence}
def generate_answer(state: GraphState) -> GraphState:
    """Answer from the retrieved documents with ANSWER_PROMPT (non-streaming).

    Each document is rendered as ``[Source: <source_file>]`` followed by its
    text, joined with ``---`` separators; with no documents the prompt gets
    "No relevant documents found." (matching reason_answer).

    Side effects: sets ``session.last_action`` to "retrieve" and
    ``session.last_confidence`` to the state's confidence (default 0.0).

    Returns:
        The state with ``answer``, ``follow_up_question=None`` and
        ``action_taken="retrieve"``.
    """
    llm = get_llm()
    session = state["session"]
    docs = state.get("retrieved_docs", [])
    context_parts = []
    for doc in docs:
        # Tolerate docs without metadata, as _extract_sources does.
        source = (doc.get("metadata") or {}).get("source_file", "unknown")
        context_parts.append(f"[Source: {source}]\n{doc['text']}")
    context = (
        "\n\n---\n\n".join(context_parts)
        if context_parts
        else "No relevant documents found."
    )
    prompt = ANSWER_PROMPT.format(
        context=context,
        history=session.get_history_text(),
        user_message=state["user_message"],
    )
    response = llm.invoke([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=prompt),
    ])
    session.last_action = "retrieve"
    session.last_confidence = state.get("confidence", 0.0)
    return {
        **state,
        "answer": response.content.strip(),
        "follow_up_question": None,
        "action_taken": "retrieve",
    }
def reason_answer(state: GraphState) -> GraphState:
    """Answer with REASON_PROMPT when retrieval is weak or skipped (non-streaming).

    Reached from intent routing (greetings, or no retrieval needed) and from
    confidence routing when the score is below CONFIDENCE_THRESHOLD. Any
    retrieved documents are still passed as partial context; with none, the
    prompt receives "No relevant documents found.".

    Side effects: sets ``session.last_action`` to "reason" and
    ``session.last_confidence`` to the state's confidence (default 0.0).

    Returns:
        The state with ``answer``, ``follow_up_question=None`` and
        ``action_taken="reason"``.
    """
    llm = get_llm()
    session = state["session"]
    docs = state.get("retrieved_docs", [])
    context_parts = []
    for doc in docs:
        # Tolerate docs without metadata, as _extract_sources does.
        source = (doc.get("metadata") or {}).get("source_file", "unknown")
        context_parts.append(f"[Source: {source}]\n{doc['text']}")
    context = (
        "\n\n---\n\n".join(context_parts)
        if context_parts
        else "No relevant documents found."
    )
    prompt = REASON_PROMPT.format(
        context=context,
        history=session.get_history_text(),
        user_message=state["user_message"],
    )
    response = llm.invoke([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=prompt),
    ])
    session.last_action = "reason"
    session.last_confidence = state.get("confidence", 0.0)
    return {
        **state,
        "answer": response.content.strip(),
        "follow_up_question": None,
        "action_taken": "reason",
    }
def clarify(state: GraphState) -> GraphState:
    """Ask the user a follow-up question instead of answering.

    The fast model generates the question from ``CLARIFY_PROMPT``. It is
    stored on the session as ``pending_clarification`` and returned as
    ``follow_up_question`` next to a fixed lead-in ``answer``.

    Side effects: sets ``session.last_action`` to "clarify" and
    ``session.pending_clarification`` to the generated question.
    """
    fast_model = get_llm_fast()  # Clarifying questions are short; favour latency.
    session = state["session"]
    messages = [
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(
            content=CLARIFY_PROMPT.format(
                history=session.get_history_text(),
                user_message=state["user_message"],
            )
        ),
    ]
    question = fast_model.invoke(messages).content.strip()

    session.last_action = "clarify"
    session.pending_clarification = question

    updates = {
        "answer": "I'd like to help you better. Let me ask a quick question:",
        "follow_up_question": question,
        "action_taken": "clarify",
    }
    return {**state, **updates}
# --- Routing Functions ---
def route_after_intent(state: GraphState) -> str:
    """Pick the next node after intent analysis.

    * intent "unclear"                 -> "clarify"
    * intent "greeting"                -> "reason_answer" (no search)
    * otherwise, needs_retrieval truthy -> "retrieve_docs"
    * otherwise                        -> "reason_answer"

    Missing keys default to intent "general" and needs_retrieval True.
    """
    intent = state.get("intent", "general")
    if intent == "unclear":
        return "clarify"
    wants_docs = intent != "greeting" and state.get("needs_retrieval", True)
    return "retrieve_docs" if wants_docs else "reason_answer"
def route_after_confidence(state: GraphState) -> str:
    """Pick the answering node from the retrieval confidence score.

    * confidence >= CONFIDENCE_THRESHOLD            -> "generate_answer"
    * confidence not above 0.2 and intent "unclear" -> "clarify"
    * anything else                                 -> "reason_answer"

    NOTE: with build_graph's wiring, "unclear" intents are sent to clarify
    before retrieval, so the clarify branch here only fires if this router is
    used on a state that reached retrieval with an unclear intent.
    """
    score = state.get("confidence", 0.0)
    if score >= CONFIDENCE_THRESHOLD:
        return "generate_answer"
    # Written as "not >" (rather than "<=") so NaN scores keep the original routing.
    very_low = not score > 0.2
    if very_low and state.get("intent", "general") == "unclear":
        return "clarify"
    return "reason_answer"
# --- Build the Graph ---
def build_graph() -> StateGraph:
    """Wire the decision-flow nodes into a LangGraph and compile it.

    Flow: analyze_intent -> (retrieve_docs | reason_answer | clarify);
    retrieve_docs -> evaluate_confidence -> (generate_answer | reason_answer |
    clarify); the three answering nodes end the run.

    NOTE: despite the ``StateGraph`` annotation, the value returned is the
    result of ``workflow.compile()``, i.e. the compiled, invokable graph.
    """
    workflow = StateGraph(GraphState)

    nodes = {
        "analyze_intent": analyze_intent,
        "retrieve_docs": retrieve_docs,
        "evaluate_confidence": evaluate_confidence,
        "generate_answer": generate_answer,
        "reason_answer": reason_answer,
        "clarify": clarify,
    }
    for node_name, node_fn in nodes.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("analyze_intent")

    # Router return values are used verbatim as node names.
    workflow.add_conditional_edges(
        "analyze_intent",
        route_after_intent,
        {target: target for target in ("retrieve_docs", "reason_answer", "clarify")},
    )
    workflow.add_edge("retrieve_docs", "evaluate_confidence")
    workflow.add_conditional_edges(
        "evaluate_confidence",
        route_after_confidence,
        {target: target for target in ("generate_answer", "reason_answer", "clarify")},
    )

    for terminal in ("generate_answer", "reason_answer", "clarify"):
        workflow.add_edge(terminal, END)

    return workflow.compile()
# Module-level compiled graph, built once at import time by build_graph() and
# reused by every run_chat() call.
chatbot_graph = build_graph()
# --- Non-streaming entry point (kept for backward compat) ---
def run_chat(session: Session, user_message: str) -> dict:
    """Run the full LangGraph flow for one user turn (non-streaming).

    Records the user message on ``session``, clears any pending
    clarification, invokes ``chatbot_graph`` and then records the assistant
    reply (the answer, plus the follow-up question when there is one).

    Returns:
        ``{"answer": str, "follow_up_question": str | None}``. An empty answer
        is replaced by a generic apology.
    """
    session.add_user_message(user_message)
    if session.pending_clarification:
        session.pending_clarification = None

    initial_state: GraphState = {
        "user_message": user_message,
        "session": session,
        "intent": "",
        "topic": "",
        "needs_retrieval": True,
        "retrieved_docs": [],
        "confidence": 0.0,
        "answer": "",
        "follow_up_question": None,
        "action_taken": "",
    }
    result = chatbot_graph.invoke(initial_state)

    # The initial state always carries "answer": "", so a .get() default could
    # never apply; "or" makes the fallback cover an empty reply too.
    answer = result.get("answer") or "I'm sorry, I couldn't process your request."
    follow_up = result.get("follow_up_question")

    full_response = f"{answer}\n\n{follow_up}" if follow_up else answer
    session.add_assistant_message(full_response)
    return {
        "answer": answer,
        "follow_up_question": follow_up,
    }
# =============================================================================
# STREAMING ENTRY POINT
# =============================================================================
def _extract_sources(docs: list[dict]) -> list[dict]:
"""Extract unique source metadata from retrieved docs."""
seen = set()
sources = []
for doc in docs:
meta = doc.get("metadata", {})
source_file = meta.get("source_file", "")
if source_file and source_file not in seen:
seen.add(source_file)
sources.append({
"file": source_file,
"folder": meta.get("folder", ""),
"department": meta.get("department", ""),
"score": round(doc.get("score", 0), 3),
})
return sources
def _classify_intent_fast(user_message: str, session: Session) -> tuple[str, bool]:
"""
Ultra-fast local intent classification (no LLM call).
Returns (intent, needs_retrieval).
"""
msg = user_message.lower().strip()
words = msg.split()
# Greetings
greetings = {"hi", "hello", "hey", "howdy", "greetings", "good morning",
"good afternoon", "good evening", "thanks", "thank you", "bye", "goodbye"}
if msg in greetings or (len(words) <= 2 and words[0] in greetings):
return "greeting", False
# Too vague (under 3 words, no real nouns)
if len(words) <= 2 and not any(w in msg for w in [
"stacklogix", "feature", "report", "dashboard", "ai", "ml",
"purchase", "jewellery", "jewelry", "gold", "diamond", "master",
"retail", "wholesale", "ecommerce", "manufacturer", "supply",
"price", "inventory", "model", "train", "monitoring"
]):
return "unclear", False
# Everything else → retrieve
return "factual", True
def _normalize_query(session: Session, user_message: str) -> str:
    """Rewrite ``user_message`` into a retrieval query with the fast LLM.

    Best-effort: any exception, or an empty rewrite, falls back to the
    original message. Outcomes are reported with print().
    """
    try:
        fast_model = get_llm_fast()
        prompt = QUERY_NORMALIZE_PROMPT.format(
            history=session.get_history_text(),
            user_message=user_message,
        )
        reply = fast_model.invoke([HumanMessage(content=prompt)])
        # Drop whitespace and any quotes the model wrapped around the query.
        rewritten = reply.content.strip().strip('"').strip("'")
        if not rewritten:
            return user_message
        print(f" Query normalized: '{user_message}' → '{rewritten}'")
        return rewritten
    except Exception as e:
        print(f" Query normalization failed: {e}, using original")
        return user_message


def _run_decision_phase(session: Session, user_message: str) -> dict:
    """Decide how to answer a streaming turn before any tokens are produced.

    Steps:
      1. record the user message and clear any pending clarification;
      2. classify intent locally with _classify_intent_fast;
      3. for retrieval intents, rewrite the query with the fast LLM (this is
         the one LLM call in this phase) and run the vector search;
      4. choose an action: "clarify" for unclear intents, "retrieve" when the
         mean top-3 score reaches CONFIDENCE_THRESHOLD, otherwise "reason".

    Side effects: sets ``session.current_topic``; on the retrieval path also
    sets ``session.last_confidence``.

    Returns:
        Dict with user_message, intent, topic, retrieved_docs, confidence,
        action and session.
    """
    session.add_user_message(user_message)
    if session.pending_clarification:
        session.pending_clarification = None

    intent, needs_retrieval = _classify_intent_fast(user_message, session)
    topic = user_message
    session.current_topic = topic

    search_query = (
        _normalize_query(session, user_message) if needs_retrieval else user_message
    )

    docs: list[dict] = []
    confidence = 0.0
    if intent == "unclear":
        action = "clarify"
    elif intent == "greeting" or not needs_retrieval:
        action = "reason"
    else:
        docs = get_retriever().search(search_query)
        if docs:
            leading = [d["score"] for d in docs[:3]]
            confidence = sum(leading) / len(leading)
        session.last_confidence = confidence
        action = "retrieve" if confidence >= CONFIDENCE_THRESHOLD else "reason"

    return {
        "user_message": user_message,
        "intent": intent,
        "topic": topic,
        "retrieved_docs": docs,
        "confidence": confidence,
        "action": action,
        "session": session,
    }
def _sse(payload: dict) -> str:
    """Serialise ``payload`` as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def run_chat_streaming(session: Session, user_message: str) -> Generator[str, None, None]:
    """Streaming chatbot flow. Yields SSE-formatted events:

        data: {"type": "token", "content": "..."}
        data: {"type": "sources", "sources": [...]}
        data: {"type": "follow_up", "content": "..."}
        data: {"type": "done"}

    Phase 1 (_run_decision_phase) picks an action without streaming. For
    "clarify" a short question is generated with the fast model and sent as a
    lead-in token plus a follow_up event. Otherwise the main model's answer is
    streamed token by token (ANSWER_PROMPT for "retrieve", REASON_PROMPT for
    "reason"), followed by a sources event when documents were retrieved.
    The assistant reply is recorded on the session before "done" is sent.
    """
    decision = _run_decision_phase(session, user_message)
    action = decision["action"]
    docs = decision["retrieved_docs"]
    session_obj = decision["session"]

    if action == "clarify":
        # Short reply: generate it in one call with the fast model.
        llm_fast = get_llm_fast()
        prompt = CLARIFY_PROMPT.format(
            history=session_obj.get_history_text(),
            user_message=user_message,
        )
        response = llm_fast.invoke([
            SystemMessage(content=SYSTEM_PROMPT),
            HumanMessage(content=prompt),
        ])
        follow_up = response.content.strip()
        session_obj.last_action = "clarify"
        session_obj.pending_clarification = follow_up
        clarify_msg = "I'd like to help you better. Let me ask a quick question:"
        yield _sse({"type": "token", "content": clarify_msg})
        yield _sse({"type": "follow_up", "content": follow_up})
        session_obj.add_assistant_message(f"{clarify_msg}\n\n{follow_up}")
        yield _sse({"type": "done"})
        return

    context_parts = []
    for doc in docs:
        # Tolerate docs without metadata, as _extract_sources does below.
        source = (doc.get("metadata") or {}).get("source_file", "unknown")
        context_parts.append(f"[Source: {source}]\n{doc['text']}")
    context = (
        "\n\n---\n\n".join(context_parts)
        if context_parts
        else "No relevant documents found."
    )

    if action == "retrieve":
        template = ANSWER_PROMPT
        session_obj.last_action = "retrieve"
    else:
        template = REASON_PROMPT
        session_obj.last_action = "reason"
    prompt = template.format(
        context=context,
        history=session_obj.get_history_text(),
        user_message=user_message,
    )
    session_obj.last_confidence = decision["confidence"]

    llm = get_llm()
    pieces: list[str] = []  # joined once at the end instead of repeated str +=
    for chunk in llm.stream([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=prompt),
    ]):
        token = chunk.content
        if token:
            pieces.append(token)
            yield _sse({"type": "token", "content": token})

    if docs:
        yield _sse({"type": "sources", "sources": _extract_sources(docs)})

    session_obj.add_assistant_message("".join(pieces).strip())
    yield _sse({"type": "done"})