mabelwang21 commited on
Commit
348c1c6
·
1 Parent(s): b1fdd32

improve rag_search

Browse files
Files changed (1) hide show
  1. agent.py +50 -19
agent.py CHANGED
@@ -21,7 +21,7 @@ from langchain.agents import initialize_agent, AgentType
21
  from langchain_community.retrievers import BM25Retriever
22
  from langchain.schema import BaseMessage, SystemMessage, HumanMessage
23
  from langgraph.graph.message import add_messages
24
- from langgraph.graph import START, StateGraph
25
  from langgraph.prebuilt import ToolNode, tools_condition
26
  from langchain_core.documents import Document
27
 
@@ -33,6 +33,7 @@ from dotenv import load_dotenv
33
  from contextlib import redirect_stdout
34
  from langchain_community.tools import TavilySearchResults
35
  from tavily import TavilyClient
 
36
 
37
  # Load environment variables from .env file
38
  # in HF Spaces, the .env file is saved in Variables and secrets in settings
@@ -274,12 +275,29 @@ def download_file(url_or_path: str, save_dir: str = "./downloads") -> str:
274
  return f"Error downloading/copying file: {e}"
275
 
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  # Update tools list
278
  tools: List[StructuredTool] = [
279
- calculate, web_search, wikipedia_search, image_recognition,
280
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
281
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
282
- python_interpreter, download_file, tavily_search # Add tavily_search here
283
  ]
284
 
285
  class AgentState(TypedDict):
@@ -352,6 +370,15 @@ class MyAgent:
352
  except Exception as e:
353
  print(f"Error loading {path}: {e}")
354
  continue
 
 
 
 
 
 
 
 
 
355
 
356
  def build_retriever(self):
357
  """
@@ -368,18 +395,13 @@ class MyAgent:
368
  @tool(name="rag_search")
369
  def rag_search(query: str) -> str:
370
  """Search loaded documents for relevant information."""
371
- try:
372
- if not self.retriever:
373
- return "No documents have been loaded for search."
374
-
375
- res = self.retriever.get_relevant_documents(query)
376
- if res:
377
- return "\n\n".join(f"Document {i+1}:\n{doc.page_content}"
378
- for i, doc in enumerate(res[:3]))
379
- return "No relevant information found in loaded documents."
380
- except Exception as e:
381
- return f"Error searching documents: {e}"
382
-
383
  # Remove existing rag_search if present to prevent duplicates
384
  self.tools = [t for t in self.tools if t.name != "rag_search"]
385
  self.tools.append(rag_search)
@@ -426,10 +448,17 @@ class MyAgent:
426
  builder.add_edge(START, "assistant")
427
 
428
  # Fix conditional edges with better check
 
 
 
 
 
 
 
429
  builder.add_conditional_edges(
430
  "assistant",
431
- tools_condition, # Use built-in tools_condition
432
- "tools"
433
  )
434
  builder.add_edge("tools", "assistant")
435
 
@@ -441,8 +470,10 @@ class MyAgent:
441
  last_message = out["messages"][-1].content
442
 
443
  # Extract only the FINAL ANSWER part
444
- if "FINAL ANSWER:" in last_message:
445
- return last_message.split("FINAL ANSWER:")[-1].strip()
 
 
446
  return last_message.strip()
447
  except Exception as e:
448
  return f"Error processing question: {e}"
 
21
  from langchain_community.retrievers import BM25Retriever
22
  from langchain.schema import BaseMessage, SystemMessage, HumanMessage
23
  from langgraph.graph.message import add_messages
24
+ from langgraph.graph import START, END, StateGraph
25
  from langgraph.prebuilt import ToolNode, tools_condition
26
  from langchain_core.documents import Document
27
 
 
33
  from contextlib import redirect_stdout
34
  from langchain_community.tools import TavilySearchResults
35
  from tavily import TavilyClient
36
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
37
 
38
  # Load environment variables from .env file
39
  # in HF Spaces, the .env file is saved in Variables and secrets in settings
 
275
  return f"Error downloading/copying file: {e}"
276
 
277
 
278
+ @tool
279
+ def extract_table(file_path: str, query: str = "") -> str:
280
+ """Extract relevant rows from a CSV or Excel file based on a query."""
281
+ import pandas as pd
282
+ ext = Path(file_path).suffix.lower()
283
+ if ext in [".csv"]:
284
+ df = pd.read_csv(file_path)
285
+ elif ext in [".xlsx", ".xls"]:
286
+ df = pd.read_excel(file_path)
287
+ else:
288
+ return "Unsupported file type."
289
+ # Simple filter: return all if no query, else filter columns containing query
290
+ if query:
291
+ mask = df.apply(lambda row: row.astype(str).str.contains(query, case=False).any(), axis=1)
292
+ df = df[mask]
293
+ return df.head(10).to_csv(index=False)
294
+
295
  # Update tools list
296
  tools: List[StructuredTool] = [
297
+ calculate, tavily_search, wikipedia_search, image_recognition,
298
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
299
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
300
+ python_interpreter, download_file, extract_table # Add tavily_search here
301
  ]
302
 
303
  class AgentState(TypedDict):
 
370
  except Exception as e:
371
  print(f"Error loading {path}: {e}")
372
  continue
373
+ # After loading each doc:
374
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
375
+ for doc in loaded_docs:
376
+ chunks = text_splitter.split_text(doc.page_content)
377
+ for i, chunk in enumerate(chunks):
378
+ self.docs.append(Document(
379
+ page_content=chunk,
380
+ metadata={**doc.metadata, "chunk": i, "source": path}
381
+ ))
382
 
383
  def build_retriever(self):
384
  """
 
395
  @tool(name="rag_search")
396
  def rag_search(query: str) -> str:
397
  """Search loaded documents for relevant information."""
398
+ if not self.retriever:
399
+ return "No documents loaded."
400
+ docs = self.retriever.get_relevant_documents(query)
401
+ if not docs:
402
+ return "No relevant information found."
403
+ return "\n\n".join(f"{doc.metadata.get('source', '')}: {doc.page_content[:500]}" for doc in docs[:3])
404
+
 
 
 
 
 
405
  # Remove existing rag_search if present to prevent duplicates
406
  self.tools = [t for t in self.tools if t.name != "rag_search"]
407
  self.tools.append(rag_search)
 
448
  builder.add_edge(START, "assistant")
449
 
450
  # Fix conditional edges with better check
451
+ def _should_use_tools(state):
452
+ # If there are loaded docs, always use rag_search first
453
+ if state.get("input_file"):
454
+ return "tools"
455
+ # Otherwise, let the assistant try to answer
456
+ return "assistant"
457
+
458
  builder.add_conditional_edges(
459
  "assistant",
460
+ _should_use_tools,
461
+ {"tools": "tools", "assistant": END}
462
  )
463
  builder.add_edge("tools", "assistant")
464
 
 
470
  last_message = out["messages"][-1].content
471
 
472
  # Extract only the FINAL ANSWER part
473
+ import re
474
+ match = re.search(r"FINAL ANSWER[:\s]*([^\n]*)", last_message, re.IGNORECASE)
475
+ if match:
476
+ return match.group(1).strip()
477
  return last_message.strip()
478
  except Exception as e:
479
  return f"Error processing question: {e}"