Spaces:
Sleeping
Sleeping
| import os | |
| import ast | |
| import re | |
| import json | |
| import operator as op | |
| from pathlib import Path | |
| from typing import List, TypedDict, Annotated, Optional | |
| import requests | |
| from urllib.parse import urlparse | |
| import shutil | |
| import io | |
| from typing import Dict, Any | |
| from langchain.tools import tool, StructuredTool | |
| from langchain_community.document_loaders import ( | |
| CSVLoader, PyPDFLoader, YoutubeLoader | |
| ) | |
| from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader | |
| from langchain.chat_models import init_chat_model | |
| from langchain.agents import initialize_agent, AgentType | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.schema import BaseMessage, SystemMessage, HumanMessage, AIMessage | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph import START, END, StateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from langchain_core.documents import Document | |
| from youtube_transcript_api import YouTubeTranscriptApi | |
| from PIL import Image | |
| import pytesseract | |
| import fitz # PyMuPDF | |
| from dotenv import load_dotenv | |
| from contextlib import redirect_stdout | |
| from langchain_community.tools import TavilySearchResults | |
| from tavily import TavilyClient | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# Pull configuration from a local .env file. On Hugging Face Spaces the same
# values come from "Variables and secrets" in the Space settings.
load_dotenv()

# The Tavily client is created only after the environment has been loaded so
# that TAVILY_API_KEY is visible to it.
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
print(tavily_client)

# === System Prompt ===
# Instructs the model to reason, then end with a "FINAL ANSWER:" line.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
""".strip()
def calculate(expr: str) -> str:
    """Evaluate a math expression and return the result as text.

    The expression is evaluated with every public name from ``math`` in scope,
    plus ``np`` (numpy), the numpy helpers ``array``, ``mean``, ``median``,
    ``std`` and ``sum``, and the builtins ``abs``, ``round``, ``max``, ``min``.

    Args:
        expr: Expression to evaluate, e.g. ``"sqrt(16) + 2"``.

    Returns:
        The result as a string. Plain ints and floats are rendered as floats
        (``"4.0"``); integers too large to convert to a float are rendered
        exactly; booleans are rendered as ``"True"``/``"False"``. Any failure
        is reported as ``"Error calculating expression: ..."`` instead of
        being raised.
    """
    try:
        import math
        import numpy as np

        # Start from the public math namespace, then layer the extras on top
        # (so e.g. ``sum`` is the numpy version).
        namespace = {
            name: value
            for name, value in vars(math).items()
            if not name.startswith('_')
        }
        namespace.update({
            'np': np,
            'array': np.array,
            'mean': np.mean,
            'median': np.median,
            'std': np.std,
            'sum': np.sum,
            'abs': abs,
            'round': round,
            'max': max,
            'min': min,
        })

        # SECURITY NOTE: eval() runs model-/user-supplied text. Emptying
        # __builtins__ is not a sandbox; only use with trusted input.
        result = eval(expr, {"__builtins__": {}}, namespace)

        # bool is a subclass of int; without this check "1 < 2" would be
        # reported as "1.0".
        if isinstance(result, bool):
            return str(result)
        if isinstance(result, (int, float)):
            try:
                return str(float(result))
            except OverflowError:
                # Integer too large for a float (e.g. factorial(200)):
                # return the exact value instead of failing.
                return str(result)
        return str(result)
    except Exception as e:
        return f"Error calculating expression: {e}"
def web_search(query: str) -> str:
    """Search the web with DuckDuckGo and, when configured, SerpAPI.

    Each engine is queried independently so that one failing does not discard
    the other's results. SerpAPI is only used when the ``SERPAPI_API_KEY``
    environment variable is set.

    Args:
        query: Free-text search query.

    Returns:
        The engines' results joined by a blank line and truncated to 1000
        characters, or ``"Error performing web search: ..."`` when no engine
        produced a result.
    """
    snippets = []
    errors = []

    try:
        # DuckDuckGoSearchRun is a tool and lives in langchain_community.tools;
        # it is not exported by langchain.utilities.
        from langchain_community.tools import DuckDuckGoSearchRun
        snippets.append(DuckDuckGoSearchRun().run(query))
    except Exception as e:
        errors.append(f"DuckDuckGo: {e}")

    if os.getenv("SERPAPI_API_KEY"):
        try:
            from langchain_community.utilities import SerpAPIWrapper
            snippets.append(SerpAPIWrapper().run(query))
        except Exception as e:
            errors.append(f"SerpAPI: {e}")

    if not snippets:
        detail = "; ".join(errors) or "no results"
        return f"Error performing web search: {detail}"

    # Limit length for better handling downstream.
    return "\n\n".join(snippets)[:1000]
def wikipedia_search(query: str) -> str:
    """Search Wikipedia for a general-topic query.

    Args:
        query: Topic or question to look up.

    Returns:
        The text produced by ``WikipediaAPIWrapper.run``, or
        ``"Error searching Wikipedia: ..."`` on failure.
    """
    try:
        # Use the langchain_community location, matching the other wrappers
        # imported in this file; langchain.utilities is a deprecated shim.
        from langchain_community.utilities import WikipediaAPIWrapper
        return WikipediaAPIWrapper().run(query)
    except Exception as e:
        return f"Error searching Wikipedia: {e}"
def tavily_search(query: str) -> str:
    """Search the web with Tavily and return the content of the top hit.

    Args:
        query: Free-text search query.

    Returns:
        The ``content`` field of the first result (``"No content found."``
        when that field is missing), ``"No results found."`` when the search
        returned no hits, the stringified response when its shape is not
        recognised, or ``"Error performing Tavily search: ..."`` on failure.
    """
    try:
        response = tavily_client.search(query)

        if isinstance(response, dict) and "results" in response:
            hits = response["results"]
        elif isinstance(response, list):
            hits = response
        else:
            return str(response)

        # An empty hit list used to raise IndexError and surface as an error.
        if not hits:
            return "No results found."
        return hits[0].get("content", "No content found.")
    except Exception as e:
        return f"Error performing Tavily search: {e}"
def advanced_search(query: str, max_results: int = 5) -> str:
    """Run an "advanced"-depth Tavily search restricted to academic/government domains.

    The search is limited to arxiv.org, usgs.gov, nih.gov and
    pubmed.ncbi.nlm.nih.gov, and asks Tavily for a direct answer as well.

    Args:
        query: Free-text search query.
        max_results: Maximum number of hits to request from Tavily.

    Returns:
        A text report: an optional ``DIRECT ANSWER`` line followed by one
        ``RESULT n`` section per hit (title, URL, truncated content).
        ``"No results found"`` when there is nothing to report, or
        ``"Search error: ..."`` on failure.
    """
    # Hits whose URL contains one of these get a longer excerpt.
    academic_domains = ("arxiv.org", "usgs.gov", "nih.gov")
    try:
        response = tavily_client.search(
            query,
            search_depth="advanced",
            max_results=max_results,
            include_answer=True,
            include_raw_content=True,
            include_domains=["arxiv.org", "usgs.gov", "nih.gov", "pubmed.ncbi.nlm.nih.gov"],
        )

        sections = []
        if isinstance(response, dict):
            if response.get("answer"):
                sections.append(f"DIRECT ANSWER: {response['answer']}")

            for i, hit in enumerate(response.get("results") or [], 1):
                title = hit.get("title", "")
                # "or ''" guards against explicit None values in the payload.
                url = hit.get("url") or ""
                content = hit.get("content") or ""
                limit = 1000 if any(domain in url for domain in academic_domains) else 500
                sections.append(
                    f"RESULT {i}:\nTitle: {title}\nURL: {url}\nContent: {content[:limit]}\n"
                )

        # Previously an empty string was returned here, which gives the caller
        # nothing to act on.
        return "\n".join(sections) if sections else "No results found"
    except Exception as e:
        return f"Search error: {e}"
def arxiv_search(query: str, date_filter: str = "") -> str:
    """Search Tavily for arXiv papers, optionally appending a date filter.

    Args:
        query: Search terms; prefixed with ``site:arxiv.org``.
        date_filter: Extra text appended to the query when non-empty.

    Returns:
        One ``ArXiv Paper`` section per hit whose URL contains ``arxiv.org``,
        ``"No arXiv papers found"`` when hits exist but none are from arXiv,
        ``"No results found"`` when the search returned nothing, or
        ``"ArXiv search error: ..."`` on failure.
    """
    try:
        search_text = f"site:arxiv.org {query}"
        if date_filter:
            search_text = f"{search_text} {date_filter}"

        response = tavily_client.search(
            search_text,
            search_depth="advanced",
            max_results=8,
            include_raw_content=True
        )

        # Guard clause: nothing usable came back.
        if not (isinstance(response, dict) and response.get("results")):
            return "No results found"

        papers = [
            f"ArXiv Paper:\nTitle: {hit.get('title', '')}\nURL: {hit.get('url', '')}\nContent: {hit.get('content', '')[:800]}\n"
            for hit in response["results"]
            if "arxiv.org" in hit.get("url", "")
        ]
        if papers:
            return "\n".join(papers)
        return "No arXiv papers found"
    except Exception as e:
        return f"ArXiv search error: {e}"
def targeted_search(base_query: str, additional_terms: List[str]) -> str:
    """Run one Tavily search per term combination and concatenate the hits.

    Args:
        base_query: Text common to every search.
        additional_terms: Extra terms; each entry is appended to
            ``base_query`` to form one search. A single string is treated as
            one entry. When empty, ``base_query`` is searched on its own.

    Returns:
        For each search that produced hits, a ``=== Search: ... ===`` header
        followed by title/URL/content (first 400 characters) of each hit.
        ``"No results found"`` when no search produced hits, or
        ``"Targeted search error: ..."`` on failure.
    """
    try:
        # A bare string would otherwise be iterated character by character.
        if isinstance(additional_terms, str):
            additional_terms = [additional_terms]
        term_list = list(additional_terms or []) or [""]

        lines = []
        for terms in term_list:
            query = f"{base_query} {terms}".strip()
            response = tavily_client.search(query, max_results=3)
            if not (isinstance(response, dict) and response.get("results")):
                continue
            lines.append(f"=== Search: {query} ===")
            for hit in response["results"]:
                lines.append(f"Title: {hit.get('title', '')}")
                lines.append(f"URL: {hit.get('url', '')}")
                lines.append(f"Content: {(hit.get('content') or '')[:400]}\n")

        return "\n".join(lines) if lines else "No results found"
    except Exception as e:
        return f"Targeted search error: {e}"
def extract_zip_codes(text: str) -> str:
    """Find the distinct standalone 5-digit numbers in *text*.

    Args:
        text: Text to scan.

    Returns:
        ``"Found zip codes: ..."`` listing the distinct matches in sorted
        order, ``"No 5-digit zip codes found in text"`` when there are none,
        or ``"Zip code extraction error: ..."`` on failure.
    """
    try:
        # \b on both sides keeps longer digit runs from matching.
        matches = re.findall(r'\b\d{5}\b', text)
        if not matches:
            return "No 5-digit zip codes found in text"

        ordered = sorted(set(matches))
        joined = ', '.join(ordered)
        return f"Found zip codes: {joined}"
    except Exception as e:
        return f"Zip code extraction error: {e}"
def academic_citation_search(paper_info: str) -> str:
    """Search for pages that cite, or are cited by, the given paper.

    Several query phrasings are tried in turn through Tavily; hits are
    de-duplicated by URL and the first five distinct ones are reported.

    Args:
        paper_info: Title or other identifying text of the paper.

    Returns:
        Up to five ``Citation Source`` sections (title, URL, first 500
        characters of content), ``"No citation results found"`` when no query
        produced a hit, or ``"Citation search error: ..."`` on failure.
    """
    max_hits = 5
    try:
        citation_queries = [
            f'"{paper_info}" citations references',
            f'{paper_info} "cited by"',
            f'{paper_info} bibliography references',
            f'site:scholar.google.com {paper_info}'
        ]

        hits = []
        seen_urls = set()
        for query in citation_queries:
            # Only the first max_hits distinct hits are reported, so further
            # queries cannot change the output once enough are collected.
            if len(hits) >= max_hits:
                break
            response = tavily_client.search(query, max_results=3)
            if not (isinstance(response, dict) and response.get("results")):
                continue
            for hit in response["results"]:
                url = hit.get("url", "")
                # The different phrasings often return the same page.
                if url and url in seen_urls:
                    continue
                seen_urls.add(url)
                hits.append(hit)

        if not hits:
            return "No citation results found"

        formatted = [
            f"Citation Source: {hit.get('title', '')}\n"
            f"URL: {hit.get('url', '')}\n"
            f"Content: {(hit.get('content') or '')[:500]}\n"
            for hit in hits[:max_hits]
        ]
        return "\n".join(formatted)
    except Exception as e:
        return f"Citation search error: {e}"
def image_recognition(image_path: str) -> str:
    """Extract text from an image using Tesseract OCR.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The text returned by ``pytesseract.image_to_string``, or
        ``"Error processing image: ..."`` on failure.
    """
    try:
        # The context manager closes the underlying file handle, which the
        # bare Image.open() call left open.
        with Image.open(image_path) as img:
            return pytesseract.image_to_string(img)
    except Exception as e:
        return f"Error processing image: {e}"
def read_pdf(pdf_path: str) -> str:
    """Extract the text of every page of a PDF with PyMuPDF.

    Args:
        pdf_path: Path of the PDF file.

    Returns:
        The page texts concatenated in page order, or
        ``"Error reading PDF: ..."`` on failure.
    """
    try:
        # Close the document when done; it was previously left open.
        with fitz.open(pdf_path) as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        return f"Error reading PDF: {e}"
def read_csv(csv_path: str) -> str:
    """Load a UTF-8 CSV file with ``CSVLoader`` and return its rows as text.

    Args:
        csv_path: Path of the CSV file.

    Returns:
        The ``page_content`` of every loaded document joined by newlines, or
        ``"Error reading CSV: ..."`` on failure.
    """
    try:
        documents = CSVLoader(csv_path, encoding='utf-8').load()
        row_texts = [document.page_content for document in documents]
        return "\n".join(row_texts)
    except Exception as e:
        return f"Error reading CSV: {e}"
def read_spreadsheet(spreadsheet_path: str) -> str:
    """Load a spreadsheet with ``pandas.read_excel`` and return it as CSV text.

    Args:
        spreadsheet_path: Path of the spreadsheet file.

    Returns:
        The DataFrame rendered with ``to_csv(index=False)``, or
        ``"Error reading spreadsheet: ..."`` on failure.
    """
    try:
        import pandas as pd

        frame = pd.read_excel(spreadsheet_path)
        csv_text = frame.to_csv(index=False)
        return csv_text
    except Exception as e:
        return f"Error reading spreadsheet: {e}"
def youtube_transcript_tool(video_url: str) -> str:
    """Fetch a YouTube video's transcript through LangChain's ``YoutubeLoader``.

    Args:
        video_url: URL of the video.

    Returns:
        The ``page_content`` of every loaded document joined by newlines, or
        ``"Error fetching YouTube transcript: ..."`` on failure.
    """
    try:
        documents = YoutubeLoader.from_youtube_url(video_url).load()
        pieces = [document.page_content for document in documents]
        return "\n".join(pieces)
    except Exception as e:
        return f"Error fetching YouTube transcript: {e}"
def youtube_transcript_api(video_url_or_id: str) -> str:
    """Fetch a YouTube transcript with the ``youtube-transcript-api`` package.

    Args:
        video_url_or_id: Either a bare 11-character video id or a URL of the
            form ``...watch?v=ID``, ``youtu.be/ID``, ``/embed/ID``,
            ``/shorts/ID`` or ``/live/ID``.

    Returns:
        The transcript segments joined by single spaces, or
        ``"Error fetching transcript via API: ..."`` on failure.
    """
    try:
        match = re.search(
            r"(?:v=|youtu\.be/|/embed/|/shorts/|/live/)([A-Za-z0-9_-]{11})",
            video_url_or_id,
        )
        # No recognisable URL shape: assume the caller passed a bare id.
        vid = match.group(1) if match else video_url_or_id.strip()

        # Older releases expose the static get_transcript(); newer ones
        # replaced it with an instance-level fetch() whose snippets carry the
        # text as an attribute.
        if hasattr(YouTubeTranscriptApi, "get_transcript"):
            entries = YouTubeTranscriptApi.get_transcript(vid)
            return " ".join(segment["text"] for segment in entries)
        fetched = YouTubeTranscriptApi().fetch(vid)
        return " ".join(snippet.text for snippet in fetched)
    except Exception as e:
        return f"Error fetching transcript via API: {e}"
def transcribe_audio(audio_path: str) -> str:
    """Transcribe an audio file (e.g. MP3) through the AssemblyAI loader.

    Args:
        audio_path: Path of the audio file.

    Returns:
        The ``page_content`` of every loaded document joined by newlines, or
        ``"Error transcribing audio: ..."`` on failure.
    """
    try:
        documents = AssemblyAIAudioTranscriptLoader(file_path=audio_path).load()
        pieces = [document.page_content for document in documents]
        return "\n".join(pieces)
    except Exception as e:
        return f"Error transcribing audio: {e}"
def read_jsonl(jsonl_path: str) -> str:
    """Read a JSON Lines file and return its records as one JSON array.

    Args:
        jsonl_path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        The parsed records re-serialised with ``json.dumps(..., indent=2)``,
        or ``"Error reading JSONL file: ..."`` on failure (missing file,
        malformed line, ...).
    """
    try:
        with open(jsonl_path, 'r', encoding='utf-8') as file:
            # Skip blank lines: a trailing or separating empty line is common
            # in JSONL files and used to abort the whole read.
            records = [json.loads(line) for line in file if line.strip()]
        return json.dumps(records, indent=2)
    except Exception as e:
        return f"Error reading JSONL file: {e}"
def python_interpreter(code: str) -> str:
    """Execute Python code and return whatever it printed.

    The code runs with these names pre-bound when the corresponding library is
    installed: ``pd`` (pandas), ``np`` (numpy), ``plt`` (matplotlib.pyplot),
    ``json``, ``re`` and ``math``.

    Args:
        code: Python source to execute.

    Returns:
        Captured standard output, ``"Code executed successfully (no output)"``
        when nothing was printed, or ``"Error executing Python code: ..."``
        when the code raised.
    """
    import importlib

    preloaded = (
        ('pd', 'pandas'),
        ('np', 'numpy'),
        ('plt', 'matplotlib.pyplot'),
        ('json', 'json'),
        ('re', 're'),
        ('math', 'math'),
    )
    namespace = {}
    for alias, module_name in preloaded:
        try:
            # import_module returns the submodule itself;
            # __import__('matplotlib.pyplot') returns the top-level
            # ``matplotlib`` package, so ``plt`` was bound to the wrong object.
            namespace[alias] = importlib.import_module(module_name)
        except ImportError:
            # A missing optional library should not disable the whole tool.
            continue

    captured = io.StringIO()
    try:
        # SECURITY NOTE: exec() runs arbitrary code with full builtins; this
        # namespace is a convenience, not a sandbox.
        with redirect_stdout(captured):
            exec(code, namespace)
    except Exception as e:
        return f"Error executing Python code: {e}"
    return captured.getvalue() or "Code executed successfully (no output)"
def download_file(url_or_path: str, save_dir: str = "./downloads") -> str:
    """Download a URL, or copy a local file, into *save_dir*.

    Args:
        url_or_path: An ``http://``/``https://`` URL, or a local file path.
        save_dir: Destination directory; created if it does not exist.

    Returns:
        ``"File successfully saved to <path>"`` on success,
        ``"Error: Source file <path> not found"`` when a local source does not
        exist, or ``"Error downloading/copying file: ..."`` on any other
        failure.
    """
    try:
        target_dir = Path(save_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        if url_or_path.startswith(('http://', 'https://')):
            # A timeout keeps a stalled server from hanging the agent; the
            # context manager releases the connection.
            with requests.get(url_or_path, stream=True, timeout=30) as response:
                response.raise_for_status()

                # Prefer the server-supplied name, dropping any trailing
                # header parameters and surrounding quotes.
                filename = ""
                disposition = response.headers.get('Content-Disposition') or ""
                if 'filename=' in disposition:
                    filename = disposition.split('filename=')[1].split(';')[0].strip().strip('"\'')
                if not filename:
                    filename = Path(urlparse(url_or_path).path).name

                # Keep only the final path component so a hostile header
                # cannot write outside save_dir, and fall back to a fixed
                # name when none can be derived.
                filename = Path(filename).name
                if filename in ("", ".", ".."):
                    filename = "downloaded_file"
                save_path = target_dir / filename

                # iter_content applies content decoding (gzip/deflate), which
                # copying response.raw does not.
                with open(save_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
        else:
            src_path = Path(url_or_path)
            if not src_path.exists():
                return f"Error: Source file {url_or_path} not found"
            save_path = target_dir / src_path.name
            shutil.copy2(src_path, save_path)

        return f"File successfully saved to {save_path}"
    except Exception as e:
        return f"Error downloading/copying file: {e}"
def extract_table(file_path: str, query: str = "") -> str:
    """Return up to ten rows of a CSV/Excel file, optionally filtered by *query*.

    Args:
        file_path: Path of a ``.csv``, ``.xlsx`` or ``.xls`` file.
        query: When non-empty, only rows where at least one cell contains this
            text (case-insensitive, matched literally) are kept.

    Returns:
        The first ten (matching) rows as CSV text without the index, or
        ``"Unsupported file type."`` for any other extension.

    Raises:
        Whatever ``pandas.read_csv`` / ``pandas.read_excel`` raise (e.g.
        ``FileNotFoundError``); errors are not converted to strings here.
    """
    import pandas as pd

    ext = Path(file_path).suffix.lower()
    if ext == ".csv":
        df = pd.read_csv(file_path)
    elif ext in (".xlsx", ".xls"):
        df = pd.read_excel(file_path)
    else:
        return "Unsupported file type."

    if query:
        # regex=False: the query is plain text, so characters such as "(" must
        # not be interpreted as (possibly invalid) regular expressions.
        # Building the mask column-wise also yields a proper boolean Series
        # for an empty table.
        mask = df.astype(str).apply(
            lambda column: column.str.contains(query, case=False, regex=False)
        ).any(axis=1)
        df = df[mask]
    return df.head(10).to_csv(index=False)
def summarize(text: str, llm=None) -> str:
    """Ask *llm* to summarize *text*.

    Args:
        text: Text to summarize.
        llm: Chat model exposing ``invoke``; when ``None`` no call is made.

    Returns:
        The ``content`` of the model's reply, or
        ``"No LLM provided for summarization."`` when *llm* is ``None``.
    """
    if llm is not None:
        prompt = [
            SystemMessage(content="Summarize the following:"),
            HumanMessage(content=text),
        ]
        reply = llm.invoke(prompt)
        return reply.content
    return "No LLM provided for summarization."
# Module-level tool registry.
# Each function is wrapped in a StructuredTool so that entries match the
# declared element type and expose the ``name``/``description`` attributes the
# agent reads; plain functions only have ``__name__``/``__doc__``.
# ``summarize`` is not listed here: the agent wraps it itself so it can inject
# its own LLM at runtime.
tools: List[StructuredTool] = [
    StructuredTool.from_function(func)
    for func in (
        calculate, tavily_search, advanced_search, arxiv_search, targeted_search,
        academic_citation_search, extract_zip_codes, wikipedia_search, image_recognition,
        read_pdf, read_csv, read_spreadsheet, transcribe_audio,
        youtube_transcript_tool, youtube_transcript_api, read_jsonl,
        python_interpreter, download_file, extract_table,
    )
]
class AgentState(TypedDict):
    """Typed description of the state passed between graph nodes."""

    # Paths of the files supplied with the question, if any.
    input_file: Optional[List[str]]
    # Conversation history; ``add_messages`` is attached as the reducer
    # annotation for this key.
    messages: Annotated[List[BaseMessage], add_messages]
# === Agent Class ===
class MyAgent:
    """Question-answering agent that alternates an LLM node with a tool node.

    The agent keeps a list of tools (the module-level ``tools`` plus a
    ``summarize`` tool bound to its own LLM). Files passed with a question are
    chunked and indexed with BM25, and a ``rag_search`` tool over that index
    is registered. Calling the agent builds a small LangGraph
    (assistant -> tools -> assistant ...) and returns the text following
    ``FINAL ANSWER`` in the last message.

    Attributes:
        llm: Chat model created by ``init_chat_model``.
        tools: Tool objects offered to the model.
        docs: Chunked documents indexed so far.
        retriever: BM25 retriever over ``docs``; ``None`` until built.
    """

    # Extra system-prompt text used when documents have been indexed.
    _RAG_PROMPT = (
        "If the question seems to be about any loaded documents, ALWAYS:\n"
        "1. Use the rag_search tool first to find relevant information\n"
        "2. Base your answer on the retrieved content\n"
        "3. If no relevant content is found, say so"
    )

    def __init__(
        self,
        model_name: str = "anthropic:claude-3-5-sonnet-latest",
        temperature: float = 0.0
    ):
        """Create the chat model and the base tool list.

        Args:
            model_name: Model identifier passed to ``init_chat_model``.
            temperature: Sampling temperature passed to ``init_chat_model``.

        Raises:
            Exception: Whatever initialisation raises, after printing it.
        """
        try:
            self.llm = init_chat_model(
                model_name,
                temperature=temperature
            )

            def summarize_text(text: str) -> str:
                """Summarize a long text chunk."""
                return summarize(text, llm=self.llm)

            # Normalise the shared registry so every entry exposes
            # ``name``/``description`` even if it holds plain functions.
            self.tools = [self._as_tool(t) for t in tools] + [
                StructuredTool.from_function(
                    summarize_text,
                    name="summarize",
                    description="Summarize a long text chunk.",
                )
            ]
            # RAG components
            self.docs: List[Any] = []
            self.retriever: Optional[BM25Retriever] = None
            # Paths already indexed, so repeated questions about the same
            # file do not duplicate its chunks in the index.
            self._indexed_paths: set = set()
        except Exception as e:
            print(f"Error initializing LLM: {e}")
            raise

    @staticmethod
    def _as_tool(candidate):
        """Return *candidate* as an object exposing ``name`` and ``description``."""
        if hasattr(candidate, "name") and hasattr(candidate, "description"):
            return candidate
        return StructuredTool.from_function(candidate)

    @staticmethod
    def _message_text(message) -> str:
        """Return the textual content of *message*.

        Message content may be a plain string or a list of content blocks;
        for a list, string items and ``{"type": "text"}`` blocks are joined.
        """
        content = getattr(message, "content", "")
        if isinstance(content, str):
            return content
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text", ""))
        return "".join(parts)

    def _load_documents(self, path: str) -> Optional[List[Any]]:
        """Load *path* into documents, or return ``None`` if the type is unsupported."""
        ext = Path(path).suffix.lower()
        if ext == ".csv":
            return CSVLoader(path).load()
        if ext == ".pdf":
            return PyPDFLoader(path).load()
        if ext in (".xlsx", ".xls"):
            import pandas as pd
            return [Document(page_content=pd.read_excel(path).to_string())]
        if ext == ".jsonl":
            with open(path, 'r', encoding='utf-8') as file:
                # Blank lines are not valid JSON; skip them.
                content = [json.loads(line) for line in file if line.strip()]
            return [Document(page_content=json.dumps(content, indent=2))]
        if ext in (".png", ".jpg", ".jpeg"):
            with Image.open(path) as img:
                text = pytesseract.image_to_string(img)
            return [Document(page_content=text)] if text.strip() else []
        if ext in (".mp3", ".wav"):
            return AssemblyAIAudioTranscriptLoader(file_path=path).load()
        if "youtube" in path or "youtu.be" in path:
            return YoutubeLoader.from_youtube_url(path).load()
        return None

    def add_files(self, file_paths: List[str]):
        """Load, chunk and store documents for retrieval.

        Supports CSV, PDF, Excel, JSONL, images (OCR), audio (mp3/wav) and
        YouTube URLs. Paths that fail to load or have an unsupported type are
        reported with ``print`` and skipped. A path that was already indexed
        by this agent is not indexed again.

        Args:
            file_paths: File paths or YouTube URLs to index.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        for path in file_paths:
            if path in self._indexed_paths:
                continue
            try:
                loaded_docs = self._load_documents(path)
            except Exception as e:
                print(f"Error loading {path}: {e}")
                continue
            if loaded_docs is None:
                print(f"Unsupported file type: {Path(path).suffix.lower()}")
                continue

            # Chunk every loaded doc
            for doc in loaded_docs:
                for i, chunk in enumerate(text_splitter.split_text(doc.page_content)):
                    self.docs.append(Document(
                        page_content=chunk,
                        metadata={**getattr(doc, 'metadata', {}), "chunk": i, "source": path}
                    ))
            self._indexed_paths.add(path)

    def build_retriever(self):
        """Build a BM25 retriever over ``self.docs`` and register ``rag_search``.

        Does nothing when no documents are loaded. Failures are printed, not
        raised.
        """
        if not self.docs:
            return
        try:
            self.retriever = BM25Retriever.from_documents(self.docs)

            def rag_search(query: str) -> str:
                """Search loaded documents for relevant information."""
                if not self.retriever:
                    return "No documents loaded."
                docs = self.retriever.invoke(query)
                if not docs:
                    return "No relevant information found."
                return "\n\n".join(
                    f"{doc.metadata.get('source', '')}: {doc.page_content[:500]}"
                    for doc in docs[:3]
                )

            # Replace any earlier rag_search so the tool is registered once,
            # and wrap the closure so it has ``name``/``description``.
            self.tools = [t for t in self.tools if getattr(t, "name", None) != "rag_search"]
            self.tools.append(StructuredTool.from_function(rag_search))
        except Exception as e:
            print(f"Error building retriever: {e}")

    def _llm_with_tools(self):
        """Return the chat model with the current tools bound, if supported."""
        try:
            return self.llm.bind_tools(self.tools)
        except (AttributeError, NotImplementedError):
            # Model without tool-calling support: fall back to plain chat.
            return self.llm

    def _tools_node(self, state: dict) -> dict:
        """Run the tool calls of the last AI message and append their results."""
        last_msg = state["messages"][-1] if state.get("messages") else None
        result = ToolNode(self.tools).invoke(state)
        new_messages = result.get("messages", []) if isinstance(result, dict) else list(result)
        # The graph state is a plain dict with no reducer, so the tool output
        # must be merged into the history here rather than replacing it.
        state["messages"] = [*state.get("messages", []), *new_messages]
        if any(
            call.get("name") == "rag_search"
            for call in (getattr(last_msg, "tool_calls", None) or [])
        ):
            state["rag_used"] = True
        return state

    @staticmethod
    def _route(state: dict):
        """Go to the tools node only when the last AI message requests tools."""
        messages = state.get("messages") or []
        last_msg = messages[-1] if messages else None
        if isinstance(last_msg, AIMessage) and getattr(last_msg, "tool_calls", None):
            return "tools"
        return END

    def __call__(
        self,
        question: str,
        file_paths: Optional[List[str]] = None
    ) -> str:
        """Answer *question*, optionally using the given files.

        Args:
            question: The user's question.
            file_paths: Optional file paths / YouTube URLs to index first.

        Returns:
            The text after ``FINAL ANSWER`` in the last message when present,
            otherwise the whole last message; ``"Error processing question:
            ..."`` if anything raises.
        """
        try:
            # Index files first so rag_search is registered before the tool
            # descriptions are rendered into the system prompt.
            if file_paths and all(isinstance(p, str) for p in file_paths):
                try:
                    self.add_files(file_paths)
                    self.build_retriever()
                except Exception as file_err:
                    print(f"Warning: Error loading files: {file_err}")

            tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
            rag_prompt = self._RAG_PROMPT if (file_paths and self.retriever is not None) else ''
            sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt}\n\nTools:\n{tool_desc}")

            state: Dict[str, Any] = {
                "messages": [sys_msg, HumanMessage(content=question)],
                "input_file": file_paths or None,
                "rag_used": False,
            }

            builder = StateGraph(dict)
            builder.add_node("assistant", self._assistant_node)
            builder.add_node("tools", self._tools_node)
            builder.add_edge(START, "assistant")
            builder.add_conditional_edges("assistant", self._route, {"tools": "tools", END: END})
            builder.add_edge("tools", "assistant")
            graph = builder.compile()

            out = graph.invoke(state, {"recursion_limit": 10})
            last_message = self._message_text(out["messages"][-1]) if out.get("messages") else ""
            match = re.search(r"FINAL ANSWER[:\s]*([^\n]*)", last_message, re.IGNORECASE)
            if match:
                return match.group(1).strip()
            return last_message.strip()
        except Exception as e:
            return f"Error processing question: {e}"

    def run(self, question: str, file_paths: Optional[List[str]] = None) -> str:
        """Alias for calling the agent directly."""
        return self(question, file_paths)

    def _assistant_node(self, state: dict) -> dict:
        """Invoke the LLM on the conversation and append its reply.

        Ensures the history contains a system message and a human message
        before the call. If the call fails, the error is printed and appended
        as an AI message so the run ends with the error text instead of
        echoing the user's question.
        """
        if not state.get("messages"):
            state["messages"] = [SystemMessage(content=SYSTEM_PROMPT)]
        messages = state["messages"]
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages.insert(0, SystemMessage(content=SYSTEM_PROMPT))
        if not any(isinstance(m, HumanMessage) for m in messages):
            messages.append(HumanMessage(content="Hello"))

        try:
            # Tools are bound at call time because rag_search may have been
            # added since the agent was created.
            resp = self._llm_with_tools().invoke(messages)
        except Exception as e:
            error_msg = f"Error calling LLM: {str(e)}"
            print(error_msg)
            print(f"Message count: {len(messages)}")
            print(f"Message types: {[type(m).__name__ for m in messages]}")
            resp = AIMessage(content=error_msg)
        messages.append(resp)
        return state